Compare commits

27 Commits

Author SHA1 Message Date
copilot-swe-agent[bot]
5fd393c507 Rename test to TestPingAllByIPRandomClientPort and add to GitHub workflow
Co-authored-by: kradalby <98431+kradalby@users.noreply.github.com>
2025-11-01 08:31:41 +00:00
copilot-swe-agent[bot]
aafe727cb9 Refine direct connection validation logic in TestPingAllByIPDirectConnections
Co-authored-by: kradalby <98431+kradalby@users.noreply.github.com>
2025-11-01 08:22:28 +00:00
copilot-swe-agent[bot]
2ac534dd30 Add TestPingAllByIPDirectConnections test to validate direct connections with randomize_client_port
Co-authored-by: kradalby <98431+kradalby@users.noreply.github.com>
2025-11-01 08:19:37 +00:00
copilot-swe-agent[bot]
c09556dd22 Initial plan 2025-11-01 08:08:53 +00:00
Andrey
f9bb88ad24 expire nodes with a custom timestamp (#2828) 2025-11-01 08:09:13 +01:00
Kristoffer Dalby
456a5d5cce db: ignore _litestream tables when validating (#2843) 2025-11-01 07:08:22 +00:00
Kristoffer Dalby
ddbd3e14ba db: remove all old, unused tables (#2844) 2025-11-01 08:03:37 +01:00
Florian Preinstorfer
0a43aab8f5 Use Debian 12 as minimum version for the deb package 2025-10-28 05:55:26 +01:00
Florian Preinstorfer
4bd614a559 Use current stable base images for Debian and Alpine 2025-10-28 05:55:26 +01:00
Kristoffer Dalby
19a33394f6 changelog: set 0.27 date (#2823) 2025-10-27 12:14:02 +01:00
Kristoffer Dalby
84fe3de251 integration: reduce TestAutoApproveMultiNetwork matrix to 3 tests (#2815) 2025-10-27 11:08:52 +00:00
Paarth Shah
450a7b15ec #2796: Add creation_time and ko_data_creation_time to goreleaser.yml kos 2025-10-27 11:18:57 +01:00
Kristoffer Dalby
64b7142e22 .goreleaser: add upgrade section (#2820) 2025-10-27 10:41:52 +01:00
Kristoffer Dalby
52d27d58f0 hscontrol: add /version HTTP endpoint (#2821) 2025-10-27 10:41:34 +01:00
Kristoffer Dalby
e68e2288f7 gen: test-integration (#2814) 2025-10-24 17:22:53 +02:00
Kristoffer Dalby
c808587de0 cli: do not show new pre-releases on stable (#2813) 2025-10-24 13:15:53 +02:00
Kristoffer Dalby
2bf1200483 policy: fix autogroup:self propagation and optimize cache invalidation (#2807) 2025-10-23 17:57:41 +02:00
Kristoffer Dalby
66826232ff integration: add tests for api bypass (#2811) 2025-10-22 16:30:25 +02:00
Kristoffer Dalby
1cdea7ed9b stricter hostname validation and replace (#2383) 2025-10-22 13:50:39 +02:00
Elyas Asmad
2c9e98d3f5 fix: guard every error statement with early return (#2810) 2025-10-22 13:48:07 +02:00
Florian Preinstorfer
8becb7e54a Mention explicitly that @ is only required in policy 2025-10-21 14:28:03 +02:00
Florian Preinstorfer
ed38d00aaa Fix autogroup:self alternative example
Also indent and split the comment into two lines to avoid horizontal
scrolling.
2025-10-21 14:28:03 +02:00
Florian Preinstorfer
8010cc574e Remove outdated hint about an empty config file 2025-10-19 17:14:15 +02:00
Juanjo Presa
c97d0ff23d Fix fatal error on missing config file by handling viper.ConfigFileNotFoundError
Correctly identify Viper's ConfigFileNotFoundError in LoadConfig to log a warning and use defaults, unifying behavior with empty config files. Fixes fatal error when no config file is present for CLI commands relying on environment variables.
2025-10-19 15:29:47 +02:00
Florian Preinstorfer
047dbda136 Add FAQ on how to disable log submission
Fixes: #2793
2025-10-19 08:24:23 +02:00
Florian Preinstorfer
2a1392fb5b Add healthcheck to container docs 2025-10-19 08:22:30 +02:00
Florian Preinstorfer
46477b8021 Downgrade completed broadcast message to debug 2025-10-18 07:56:59 +02:00
77 changed files with 5746 additions and 2072 deletions

View File

@@ -62,6 +62,7 @@ jobs:
'**/flake.lock') }}
restore-prefixes-first-match: nix-${{ runner.os }}-${{ runner.arch }}
- name: Run Integration Test
if: always() && steps.changed-files.outputs.files == 'true'
run:
nix develop --command -- hi run --stats --ts-memory-limit=300 --hs-memory-limit=1500 "^${{ inputs.test }}$" \
--timeout=120m \

View File

@@ -24,6 +24,11 @@ jobs:
- TestACLAutogroupMember
- TestACLAutogroupTagged
- TestACLAutogroupSelf
- TestACLPolicyPropagationOverTime
- TestAPIAuthenticationBypass
- TestAPIAuthenticationBypassCurl
- TestGRPCAuthenticationBypass
- TestCLIWithConfigAuthenticationBypass
- TestAuthKeyLogoutAndReloginSameUser
- TestAuthKeyLogoutAndReloginNewUser
- TestAuthKeyLogoutAndReloginSameUserExpiredKey
@@ -32,8 +37,8 @@ jobs:
- TestOIDC024UserCreation
- TestOIDCAuthenticationWithPKCE
- TestOIDCReloginSameNodeNewUser
- TestOIDCReloginSameNodeSameUser
- TestOIDCFollowUpUrl
- TestOIDCReloginSameNodeSameUser
- TestAuthWebFlowAuthenticationPingAll
- TestAuthWebFlowLogoutAndReloginSameUser
- TestAuthWebFlowLogoutAndReloginNewUser
@@ -57,6 +62,7 @@ jobs:
- TestDERPServerScenario
- TestDERPServerWebsocketScenario
- TestPingAllByIP
- TestPingAllByIPRandomClientPort
- TestPingAllByIPPublicDERP
- TestEphemeral
- TestEphemeralInAlternateTimezone
@@ -65,6 +71,7 @@ jobs:
- TestTaildrop
- TestUpdateHostnameFromClient
- TestExpireNode
- TestSetNodeExpiryInFuture
- TestNodeOnlineStatus
- TestPingAllByIPManyUpDown
- Test2118DeletingOnlineNodePanics

View File

@@ -8,6 +8,33 @@ before:
release:
prerelease: auto
draft: true
header: |
## Upgrade
Please follow the steps outlined in the [upgrade guide](https://headscale.net/stable/setup/upgrade/) to update your existing Headscale installation.
**It's best to update from one stable version to the next** (e.g., 0.24.0 → 0.25.1 → 0.26.1) in case you are multiple releases behind. You should always pick the latest available patch release.
Be sure to check the changelog above for version-specific upgrade instructions and breaking changes.
### Backup Your Database
**Always backup your database before upgrading.** Here's how to backup a SQLite database:
```bash
# Stop headscale
systemctl stop headscale
# Backup sqlite database
cp /var/lib/headscale/db.sqlite /var/lib/headscale/db.sqlite.backup
# Backup sqlite WAL/SHM files (if they exist)
cp /var/lib/headscale/db.sqlite-wal /var/lib/headscale/db.sqlite-wal.backup
cp /var/lib/headscale/db.sqlite-shm /var/lib/headscale/db.sqlite-shm.backup
# Start headscale (migration will run automatically)
systemctl start headscale
```
builds:
- id: headscale
@@ -118,6 +145,8 @@ kos:
- "{{ .Tag }}"
- '{{ trimprefix .Tag "v" }}'
- "sha-{{ .ShortCommit }}"
creation_time: "{{.CommitTimestamp}}"
ko_data_creation_time: "{{.CommitTimestamp}}"
- id: ghcr-debug
repositories:

View File

@@ -2,7 +2,12 @@
## Next
## 0.27.0 (2025-xx-xx)
### Changes
- Expire nodes with a custom timestamp
[#2828](https://github.com/juanfont/headscale/pull/2828)
## 0.27.0 (2025-10-27)
**Minimum supported Tailscale client version: v1.64.0**
@@ -84,6 +89,20 @@ the code base over time and make it more correct and efficient.
[#2692](https://github.com/juanfont/headscale/pull/2692)
- Policy: Zero or empty destination port is no longer allowed
[#2606](https://github.com/juanfont/headscale/pull/2606)
- Stricter hostname validation [#2383](https://github.com/juanfont/headscale/pull/2383)
- Hostnames must be valid DNS labels (2-63 characters, alphanumeric and
hyphens only, cannot start/end with hyphen)
- **Client Registration (New Nodes)**: Invalid hostnames are automatically
renamed to `invalid-XXXXXX` format
- `my-laptop` → accepted as-is
- `My-Laptop` → `my-laptop` (lowercased)
- `my_laptop` → `invalid-a1b2c3` (underscore not allowed)
- `test@host` → `invalid-d4e5f6` (@ not allowed)
- `laptop-🚀` → `invalid-j1k2l3` (emoji not allowed)
- **Hostinfo Updates / CLI**: Invalid hostnames are rejected with an error
- Valid names are accepted or lowercased
- Names with invalid characters, too short (<2), too long (>63), or
starting/ending with hyphen are rejected (a sketch of these rules follows below)
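A minimal sketch of these rules, assuming a hypothetical `validateHostname` helper (the real check lives in headscale's `util` package and may differ in detail):
```go
package main

import (
	"fmt"
	"regexp"
	"strings"
)

// validDNSLabel mirrors the rules above: 2-63 characters, lowercase
// alphanumerics and hyphens, no leading/trailing hyphen.
var validDNSLabel = regexp.MustCompile(`^[a-z0-9]([a-z0-9-]{0,61}[a-z0-9])$`)

func validateHostname(name string) error {
	if name != strings.ToLower(name) {
		return fmt.Errorf("hostname %q must be lowercase", name)
	}
	if len(name) < 2 {
		return fmt.Errorf("hostname %q must be at least 2 characters", name)
	}
	if len(name) > 63 {
		return fmt.Errorf("hostname %q must not exceed 63 characters", name)
	}
	if strings.HasPrefix(name, "-") || strings.HasSuffix(name, "-") {
		return fmt.Errorf("hostname %q cannot start or end with a hyphen", name)
	}
	if !validDNSLabel.MatchString(name) {
		return fmt.Errorf("hostname %q contains invalid characters", name)
	}
	return nil
}

func main() {
	for _, h := range []string{"my-laptop", "My-Laptop", "my_laptop", "-bad", "a"} {
		fmt.Println(h, "→", validateHostname(h))
	}
}
```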
### Changes
@@ -192,7 +211,7 @@ new policy code passes all of our tests.
- Error messages should be more descriptive and informative.
- There is still work to be done here, but it is already improved with "typing"
(e.g. only Users can be put in Groups)
- All users must contain an `@` character.
- All users in the policy must contain an `@` character.
- If your user naturally contains an `@`, like an email, this will just work.
- If it's based on usernames, or other identifiers not containing an `@`, an
`@` should be appended at the end. For example, if your user is `john`, it

View File

@@ -528,3 +528,4 @@ assert.EventuallyWithT(t, func(c *assert.CollectT) {
- **Integration Tests**: Require Docker and can consume significant disk space - use headscale-integration-tester agent
- **Performance**: NodeStore optimizations are critical for scale - be careful with changes to state management
- **Quality Assurance**: Always use appropriate specialized agents for testing and validation tasks
- **NEVER create gists in the user's name**: Do not use the `create_gist` tool - present information directly in the response instead

View File

@@ -12,7 +12,7 @@ WORKDIR /go/src/tailscale
ARG TARGETARCH
RUN GOARCH=$TARGETARCH go install -v ./cmd/derper
FROM alpine:3.18
FROM alpine:3.22
RUN apk add --no-cache ca-certificates iptables iproute2 ip6tables curl
COPY --from=build-env /go/bin/* /usr/local/bin/

View File

@@ -2,13 +2,12 @@
# and are in no way endorsed by Headscale's maintainers as an
# official nor supported release or distribution.
FROM docker.io/golang:1.25-bookworm
FROM docker.io/golang:1.25-trixie
ARG VERSION=dev
ENV GOPATH /go
WORKDIR /go/src/headscale
RUN apt-get update \
&& apt-get install --no-install-recommends --yes less jq sqlite3 dnsutils \
RUN apt-get --update install --no-install-recommends --yes less jq sqlite3 dnsutils \
&& rm -rf /var/lib/apt/lists/* \
&& apt-get clean
RUN mkdir -p /var/run/headscale

View File

@@ -36,7 +36,7 @@ RUN GOARCH=$TARGETARCH go install -tags="${BUILD_TAGS}" -ldflags="\
-X tailscale.com/version.gitCommitStamp=$VERSION_GIT_HASH" \
-v ./cmd/tailscale ./cmd/tailscaled ./cmd/containerboot
FROM alpine:3.18
FROM alpine:3.22
RUN apk add --no-cache ca-certificates iptables iproute2 ip6tables curl
COPY --from=build-env /go/bin/* /usr/local/bin/

View File

@@ -15,6 +15,7 @@ import (
"github.com/samber/lo"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/timestamppb"
"tailscale.com/types/key"
)
@@ -51,6 +52,7 @@ func init() {
nodeCmd.AddCommand(registerNodeCmd)
expireNodeCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)")
expireNodeCmd.Flags().StringP("expiry", "e", "", "Set expire to (RFC3339 format, e.g. 2025-08-27T10:00:00Z), or leave empty to expire immediately.")
err = expireNodeCmd.MarkFlagRequired("identifier")
if err != nil {
log.Fatal(err.Error())
@@ -289,12 +291,37 @@ var expireNodeCmd = &cobra.Command{
)
}
expiry, err := cmd.Flags().GetString("expiry")
if err != nil {
ErrorOutput(
err,
fmt.Sprintf("Error converting expiry to string: %s", err),
output,
)
return
}
expiryTime := time.Now()
if expiry != "" {
expiryTime, err = time.Parse(time.RFC3339, expiry)
if err != nil {
ErrorOutput(
err,
fmt.Sprintf("Error converting expiry to string: %s", err),
output,
)
return
}
}
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
request := &v1.ExpireNodeRequest{
NodeId: identifier,
Expiry: timestamppb.New(expiryTime),
}
response, err := client.ExpireNode(ctx, request)
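For illustration, a minimal sketch of driving the same RPC directly; the endpoint address and connection setup here are assumptions (the CLI uses `newHeadscaleCLIWithConfig` instead, and real deployments authenticate with an API key), but the request shape matches the generated `v1.ExpireNodeRequest` above:
```go
package main

import (
	"context"
	"fmt"
	"time"

	v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
	"google.golang.org/protobuf/types/known/timestamppb"
)

func main() {
	// Hypothetical local gRPC endpoint, unauthenticated for brevity.
	conn, err := grpc.NewClient("127.0.0.1:50443",
		grpc.WithTransportCredentials(insecure.NewCredentials()))
	if err != nil {
		panic(err)
	}
	defer conn.Close()

	client := v1.NewHeadscaleServiceClient(conn)

	// Expire node 1 one hour from now instead of immediately.
	resp, err := client.ExpireNode(context.Background(), &v1.ExpireNodeRequest{
		NodeId: 1,
		Expiry: timestamppb.New(time.Now().Add(time.Hour)),
	})
	if err != nil {
		panic(err)
	}
	fmt.Println("expired node:", resp.GetNode())
}
```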

View File

@@ -5,6 +5,7 @@ import (
"os"
"runtime"
"slices"
"strings"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/rs/zerolog"
@@ -75,8 +76,9 @@ func initConfig() {
if (runtime.GOOS == "linux" || runtime.GOOS == "darwin") &&
!versionInfo.Dirty {
githubTag := &latest.GithubTag{
Owner: "juanfont",
Repository: "headscale",
Owner: "juanfont",
Repository: "headscale",
TagFilterFunc: filterPreReleasesIfStable(func() string { return versionInfo.Version }),
}
res, err := latest.Check(githubTag, versionInfo.Version)
if err == nil && res.Outdated {
@@ -91,6 +93,43 @@ func initConfig() {
}
}
var prereleases = []string{"alpha", "beta", "rc", "dev"}
func isPreReleaseVersion(version string) bool {
for _, unstable := range prereleases {
if strings.Contains(version, unstable) {
return true
}
}
return false
}
// filterPreReleasesIfStable returns a function that filters out
// pre-release tags if the current version is stable.
// If the current version is a pre-release, it does not filter anything.
// versionFunc is a function that returns the current version string, it is
// a func for testability.
func filterPreReleasesIfStable(versionFunc func() string) func(string) bool {
return func(tag string) bool {
version := versionFunc()
// If we are on a pre-release version, then we do not filter anything
// as we want to recommend the user the latest pre-release.
if isPreReleaseVersion(version) {
return false
}
// If we are on a stable release, filter out pre-releases.
for _, ignore := range prereleases {
if strings.Contains(tag, ignore) {
return true
}
}
return false
}
}
var rootCmd = &cobra.Command{
Use: "headscale",
Short: "headscale - a Tailscale control server",

View File

@@ -0,0 +1,293 @@
package cli
import (
"testing"
)
func TestFilterPreReleasesIfStable(t *testing.T) {
tests := []struct {
name string
currentVersion string
tag string
expectedFilter bool
description string
}{
{
name: "stable version filters alpha tag",
currentVersion: "0.23.0",
tag: "v0.24.0-alpha.1",
expectedFilter: true,
description: "When on stable release, alpha tags should be filtered",
},
{
name: "stable version filters beta tag",
currentVersion: "0.23.0",
tag: "v0.24.0-beta.2",
expectedFilter: true,
description: "When on stable release, beta tags should be filtered",
},
{
name: "stable version filters rc tag",
currentVersion: "0.23.0",
tag: "v0.24.0-rc.1",
expectedFilter: true,
description: "When on stable release, rc tags should be filtered",
},
{
name: "stable version allows stable tag",
currentVersion: "0.23.0",
tag: "v0.24.0",
expectedFilter: false,
description: "When on stable release, stable tags should not be filtered",
},
{
name: "alpha version allows alpha tag",
currentVersion: "0.23.0-alpha.1",
tag: "v0.24.0-alpha.2",
expectedFilter: false,
description: "When on alpha release, alpha tags should not be filtered",
},
{
name: "alpha version allows beta tag",
currentVersion: "0.23.0-alpha.1",
tag: "v0.24.0-beta.1",
expectedFilter: false,
description: "When on alpha release, beta tags should not be filtered",
},
{
name: "alpha version allows rc tag",
currentVersion: "0.23.0-alpha.1",
tag: "v0.24.0-rc.1",
expectedFilter: false,
description: "When on alpha release, rc tags should not be filtered",
},
{
name: "alpha version allows stable tag",
currentVersion: "0.23.0-alpha.1",
tag: "v0.24.0",
expectedFilter: false,
description: "When on alpha release, stable tags should not be filtered",
},
{
name: "beta version allows alpha tag",
currentVersion: "0.23.0-beta.1",
tag: "v0.24.0-alpha.1",
expectedFilter: false,
description: "When on beta release, alpha tags should not be filtered",
},
{
name: "beta version allows beta tag",
currentVersion: "0.23.0-beta.2",
tag: "v0.24.0-beta.3",
expectedFilter: false,
description: "When on beta release, beta tags should not be filtered",
},
{
name: "beta version allows rc tag",
currentVersion: "0.23.0-beta.1",
tag: "v0.24.0-rc.1",
expectedFilter: false,
description: "When on beta release, rc tags should not be filtered",
},
{
name: "beta version allows stable tag",
currentVersion: "0.23.0-beta.1",
tag: "v0.24.0",
expectedFilter: false,
description: "When on beta release, stable tags should not be filtered",
},
{
name: "rc version allows alpha tag",
currentVersion: "0.23.0-rc.1",
tag: "v0.24.0-alpha.1",
expectedFilter: false,
description: "When on rc release, alpha tags should not be filtered",
},
{
name: "rc version allows beta tag",
currentVersion: "0.23.0-rc.1",
tag: "v0.24.0-beta.1",
expectedFilter: false,
description: "When on rc release, beta tags should not be filtered",
},
{
name: "rc version allows rc tag",
currentVersion: "0.23.0-rc.2",
tag: "v0.24.0-rc.3",
expectedFilter: false,
description: "When on rc release, rc tags should not be filtered",
},
{
name: "rc version allows stable tag",
currentVersion: "0.23.0-rc.1",
tag: "v0.24.0",
expectedFilter: false,
description: "When on rc release, stable tags should not be filtered",
},
{
name: "stable version with patch filters alpha",
currentVersion: "0.23.1",
tag: "v0.24.0-alpha.1",
expectedFilter: true,
description: "Stable version with patch number should filter alpha tags",
},
{
name: "stable version with patch allows stable",
currentVersion: "0.23.1",
tag: "v0.24.0",
expectedFilter: false,
description: "Stable version with patch number should allow stable tags",
},
{
name: "tag with alpha substring in version number",
currentVersion: "0.23.0",
tag: "v1.0.0-alpha.1",
expectedFilter: true,
description: "Tags with alpha in version string should be filtered on stable",
},
{
name: "tag with beta substring in version number",
currentVersion: "0.23.0",
tag: "v1.0.0-beta.1",
expectedFilter: true,
description: "Tags with beta in version string should be filtered on stable",
},
{
name: "tag with rc substring in version number",
currentVersion: "0.23.0",
tag: "v1.0.0-rc.1",
expectedFilter: true,
description: "Tags with rc in version string should be filtered on stable",
},
{
name: "empty tag on stable version",
currentVersion: "0.23.0",
tag: "",
expectedFilter: false,
description: "Empty tags should not be filtered",
},
{
name: "dev version allows all tags",
currentVersion: "0.23.0-dev",
tag: "v0.24.0-alpha.1",
expectedFilter: false,
description: "Dev versions should not filter any tags (pre-release allows all)",
},
{
name: "stable version filters dev tag",
currentVersion: "0.23.0",
tag: "v0.24.0-dev",
expectedFilter: true,
description: "When on stable release, dev tags should be filtered",
},
{
name: "dev version allows dev tag",
currentVersion: "0.23.0-dev",
tag: "v0.24.0-dev.1",
expectedFilter: false,
description: "When on dev release, dev tags should not be filtered",
},
{
name: "dev version allows stable tag",
currentVersion: "0.23.0-dev",
tag: "v0.24.0",
expectedFilter: false,
description: "When on dev release, stable tags should not be filtered",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := filterPreReleasesIfStable(func() string { return tt.currentVersion })(tt.tag)
if result != tt.expectedFilter {
t.Errorf("%s: got %v, want %v\nDescription: %s\nCurrent version: %s, Tag: %s",
tt.name,
result,
tt.expectedFilter,
tt.description,
tt.currentVersion,
tt.tag,
)
}
})
}
}
func TestIsPreReleaseVersion(t *testing.T) {
tests := []struct {
name string
version string
expected bool
description string
}{
{
name: "stable version",
version: "0.23.0",
expected: false,
description: "Stable version should not be pre-release",
},
{
name: "alpha version",
version: "0.23.0-alpha.1",
expected: true,
description: "Alpha version should be pre-release",
},
{
name: "beta version",
version: "0.23.0-beta.1",
expected: true,
description: "Beta version should be pre-release",
},
{
name: "rc version",
version: "0.23.0-rc.1",
expected: true,
description: "RC version should be pre-release",
},
{
name: "version with alpha substring",
version: "0.23.0-alphabetical",
expected: true,
description: "Version containing 'alpha' should be pre-release",
},
{
name: "version with beta substring",
version: "0.23.0-betamax",
expected: true,
description: "Version containing 'beta' should be pre-release",
},
{
name: "dev version",
version: "0.23.0-dev",
expected: true,
description: "Dev version should be pre-release",
},
{
name: "empty version",
version: "",
expected: false,
description: "Empty version should not be pre-release",
},
{
name: "version with patch number",
version: "0.23.1",
expected: false,
description: "Stable version with patch should not be pre-release",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := isPreReleaseVersion(tt.version)
if result != tt.expected {
t.Errorf("%s: got %v, want %v\nDescription: %s\nVersion: %s",
tt.name,
result,
tt.expected,
tt.description,
tt.version,
)
}
})
}
}

View File

@@ -81,7 +81,7 @@ func extractDirectoryFromTar(tarReader io.Reader, targetDir string) error {
if err := os.MkdirAll(filepath.Dir(targetPath), 0o755); err != nil {
return fmt.Errorf("failed to create parent directories for %s: %w", targetPath, err)
}
// Create file
outFile, err := os.Create(targetPath)
if err != nil {

View File

@@ -393,11 +393,13 @@ unix_socket_permission: "0770"
# method: S256
# Logtail configuration
# Logtail is Tailscales logging and auditing infrastructure, it allows the control panel
# to instruct tailscale nodes to log their activity to a remote server.
# Logtail is Tailscale's logging and auditing infrastructure; it allows the
# control panel to instruct tailscale nodes to log their activity to a remote
# server. To disable logging on the client side, please refer to:
# https://tailscale.com/kb/1011/log-mesh-traffic#opting-out-of-client-logging
logtail:
# Enable logtail for this headscales clients.
# As there is currently no support for overriding the log server in headscale, this is
# Enable logtail for tailscale nodes of this Headscale instance.
# As there is currently no support for overriding the log server in Headscale, this is
# disabled by default. Enabling this will make your clients send logs to Tailscale Inc.
enabled: false

View File

@@ -159,3 +159,19 @@ indicates which part of the policy is invalid. Follow these steps to fix your po
The above commands to get/set the policy require a complete server configuration file including database settings. A
minimal config to [control Headscale via remote CLI](../ref/remote-cli.md) is not sufficient. You may use `headscale
-c /path/to/config.yaml` to specify the path to an alternative configuration file.
## How can I avoid sending logs to Tailscale Inc?
A Tailscale client [collects logs about its operation and connection attempts with other
clients](https://tailscale.com/kb/1011/log-mesh-traffic#client-logs) and sends them to a central log service operated by
Tailscale Inc.
Headscale, by default, instructs clients to disable log submission to the central log service. This configuration is
applied by a client once it has successfully connected to Headscale. See the configuration option `logtail.enabled` in the
[configuration file](../ref/configuration.md) for details.
Alternatively, logging can also be disabled on the client side. This is independent of Headscale and opting out of
client logging disables log submission early during client startup. The configuration is operating system specific and
is usually achieved by setting the environment variable `TS_NO_LOGS_NO_SUPPORT=true` or by passing the flag
`--no-logs-no-support` to `tailscaled`. See
<https://tailscale.com/kb/1011/log-mesh-traffic#opting-out-of-client-logging> for details.

View File

@@ -210,7 +210,7 @@ Headscale supports several autogroups that automatically include users, destinat
### `autogroup:internet`
Allows access to the internet through [exit nodes](routes.md#exit-node). Can only be used in ACL destinations.
```json
{
@@ -244,10 +244,10 @@ Includes all devices that have at least one tag.
}
```
### `autogroup:self`
**(EXPERIMENTAL)**
!!! warning "The current implementation of `autogroup:self` is inefficient"
Includes devices where the same user is authenticated on both the source and destination. Does not include tagged devices. Can only be used in ACL destinations.
@@ -260,15 +260,16 @@ Includes devices where the same user is authenticated on both the source and des
```
*Using `autogroup:self` may cause performance degradation on the Headscale coordinator server in large deployments, as filter rules must be compiled per-node rather than globally and the current implementation is not very efficient.*
If you experience performance issues, consider using more specific ACL rules or limiting the use of `autogroup:self`.
```json
{
// To allow internal users communications to their own nodes we can do following rules to allow access in case autogroup:self is causing performance issues.
{ "action": "accept", "src": ["boss@"], "dst": ["boss@:"] },
{ "action": "accept", "src": ["dev1@"], "dst": ["dev1@:*"] },
{ "action": "accept", "src": ["dev2@"], "dst": ["dev2@:"] },
{ "action": "accept", "src": ["admin1@"], "dst": ["admin1@:"] },
{ "action": "accept", "src": ["intern1@"], "dst": ["intern1@:"] }
// The following rules allow internal users to communicate with their
// own nodes in case autogroup:self is causing performance issues.
{ "action": "accept", "src": ["boss@"], "dst": ["boss@:*"] },
{ "action": "accept", "src": ["dev1@"], "dst": ["dev1@:*"] },
{ "action": "accept", "src": ["dev2@"], "dst": ["dev2@:*"] },
{ "action": "accept", "src": ["admin1@"], "dst": ["admin1@:*"] },
{ "action": "accept", "src": ["intern1@"], "dst": ["intern1@:*"] }
}
```

View File

@@ -67,12 +67,6 @@ headscale apikeys expire --prefix "<PREFIX>"
export HEADSCALE_CLI_API_KEY="<API_KEY_FROM_PREVIOUS_STEP>"
```
!!! bug
Headscale currently requires at least an empty configuration file when environment variables are used to
specify connection details. See [issue 2193](https://github.com/juanfont/headscale/issues/2193) for more
information.
This instructs the `headscale` binary to connect to a remote instance at `<HEADSCALE_ADDRESS>:<PORT>`, instead of
connecting to the local instance.

View File

@@ -39,6 +39,7 @@ Registry](https://github.com/juanfont/headscale/pkgs/container/headscale). The c
--volume "$(pwd)/run:/var/run/headscale" \
--publish 127.0.0.1:8080:8080 \
--publish 127.0.0.1:9090:9090 \
--health-cmd "CMD headscale health" \
docker.io/headscale/headscale:<VERSION> \
serve
```
@@ -66,6 +67,8 @@ Registry](https://github.com/juanfont/headscale/pkgs/container/headscale). The c
- <HEADSCALE_PATH>/lib:/var/lib/headscale
- <HEADSCALE_PATH>/run:/var/run/headscale
command: serve
healthcheck:
test: ["CMD", "headscale", "health"]
```
1. Verify headscale is running:

View File

@@ -7,7 +7,7 @@ Both are available on the [GitHub releases page](https://github.com/juanfont/hea
It is recommended to use our DEB packages to install headscale on a Debian based system as those packages configure a
local user to run headscale, provide a default configuration and ship with a systemd service file. Supported
distributions are Ubuntu 22.04 or newer, Debian 11 or newer.
distributions are Ubuntu 22.04 or newer, Debian 12 or newer.
1. Download the [latest headscale package](https://github.com/juanfont/headscale/releases/latest) for your platform (`.deb` for Ubuntu and Debian).

View File

@@ -19,7 +19,7 @@
overlay = _: prev: let
pkgs = nixpkgs.legacyPackages.${prev.system};
buildGo = pkgs.buildGo125Module;
vendorHash = "sha256-GUIzlPRsyEq1uSTzRNds9p1uVu4pTeH5PAxrJ5Njhis=";
vendorHash = "sha256-VOi4PGZ8I+2MiwtzxpKc/4smsL5KcH/pHVkjJfAFPJ0=";
in {
headscale = buildGo {
pname = "headscale";

View File

@@ -1,6 +1,6 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.36.8
// protoc-gen-go v1.36.10
// protoc (unknown)
// source: headscale/v1/apikey.proto

View File

@@ -1,6 +1,6 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.36.8
// protoc-gen-go v1.36.10
// protoc (unknown)
// source: headscale/v1/device.proto

View File

@@ -1,6 +1,6 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.36.8
// protoc-gen-go v1.36.10
// protoc (unknown)
// source: headscale/v1/headscale.proto

View File

@@ -471,6 +471,8 @@ func local_request_HeadscaleService_DeleteNode_0(ctx context.Context, marshaler
return msg, metadata, err
}
var filter_HeadscaleService_ExpireNode_0 = &utilities.DoubleArray{Encoding: map[string]int{"node_id": 0}, Base: []int{1, 1, 0}, Check: []int{0, 1, 2}}
func request_HeadscaleService_ExpireNode_0(ctx context.Context, marshaler runtime.Marshaler, client HeadscaleServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) {
var (
protoReq ExpireNodeRequest
@@ -485,6 +487,12 @@ func request_HeadscaleService_ExpireNode_0(ctx context.Context, marshaler runtim
if err != nil {
return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", "node_id", err)
}
if err := req.ParseForm(); err != nil {
return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err)
}
if err := runtime.PopulateQueryParameters(&protoReq, req.Form, filter_HeadscaleService_ExpireNode_0); err != nil {
return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err)
}
msg, err := client.ExpireNode(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD))
return msg, metadata, err
}
@@ -503,6 +511,12 @@ func local_request_HeadscaleService_ExpireNode_0(ctx context.Context, marshaler
if err != nil {
return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", "node_id", err)
}
if err := req.ParseForm(); err != nil {
return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err)
}
if err := runtime.PopulateQueryParameters(&protoReq, req.Form, filter_HeadscaleService_ExpireNode_0); err != nil {
return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err)
}
msg, err := server.ExpireNode(ctx, &protoReq)
return msg, metadata, err
}

View File

@@ -1,6 +1,6 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.36.8
// protoc-gen-go v1.36.10
// protoc (unknown)
// source: headscale/v1/node.proto
@@ -729,6 +729,7 @@ func (*DeleteNodeResponse) Descriptor() ([]byte, []int) {
type ExpireNodeRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
NodeId uint64 `protobuf:"varint,1,opt,name=node_id,json=nodeId,proto3" json:"node_id,omitempty"`
Expiry *timestamppb.Timestamp `protobuf:"bytes,2,opt,name=expiry,proto3" json:"expiry,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -770,6 +771,13 @@ func (x *ExpireNodeRequest) GetNodeId() uint64 {
return 0
}
func (x *ExpireNodeRequest) GetExpiry() *timestamppb.Timestamp {
if x != nil {
return x.Expiry
}
return nil
}
type ExpireNodeResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
Node *Node `protobuf:"bytes,1,opt,name=node,proto3" json:"node,omitempty"`
@@ -1349,9 +1357,10 @@ const file_headscale_v1_node_proto_rawDesc = "" +
"\x04node\x18\x01 \x01(\v2\x12.headscale.v1.NodeR\x04node\",\n" +
"\x11DeleteNodeRequest\x12\x17\n" +
"\anode_id\x18\x01 \x01(\x04R\x06nodeId\"\x14\n" +
"\x12DeleteNodeResponse\",\n" +
"\x12DeleteNodeResponse\"`\n" +
"\x11ExpireNodeRequest\x12\x17\n" +
"\anode_id\x18\x01 \x01(\x04R\x06nodeId\"<\n" +
"\anode_id\x18\x01 \x01(\x04R\x06nodeId\x122\n" +
"\x06expiry\x18\x02 \x01(\v2\x1a.google.protobuf.TimestampR\x06expiry\"<\n" +
"\x12ExpireNodeResponse\x12&\n" +
"\x04node\x18\x01 \x01(\v2\x12.headscale.v1.NodeR\x04node\"G\n" +
"\x11RenameNodeRequest\x12\x17\n" +
@@ -1439,16 +1448,17 @@ var file_headscale_v1_node_proto_depIdxs = []int32{
1, // 7: headscale.v1.GetNodeResponse.node:type_name -> headscale.v1.Node
1, // 8: headscale.v1.SetTagsResponse.node:type_name -> headscale.v1.Node
1, // 9: headscale.v1.SetApprovedRoutesResponse.node:type_name -> headscale.v1.Node
1, // 10: headscale.v1.ExpireNodeResponse.node:type_name -> headscale.v1.Node
1, // 11: headscale.v1.RenameNodeResponse.node:type_name -> headscale.v1.Node
1, // 12: headscale.v1.ListNodesResponse.nodes:type_name -> headscale.v1.Node
1, // 13: headscale.v1.MoveNodeResponse.node:type_name -> headscale.v1.Node
1, // 14: headscale.v1.DebugCreateNodeResponse.node:type_name -> headscale.v1.Node
15, // [15:15] is the sub-list for method output_type
15, // [15:15] is the sub-list for method input_type
15, // [15:15] is the sub-list for extension type_name
15, // [15:15] is the sub-list for extension extendee
0, // [0:15] is the sub-list for field type_name
25, // 10: headscale.v1.ExpireNodeRequest.expiry:type_name -> google.protobuf.Timestamp
1, // 11: headscale.v1.ExpireNodeResponse.node:type_name -> headscale.v1.Node
1, // 12: headscale.v1.RenameNodeResponse.node:type_name -> headscale.v1.Node
1, // 13: headscale.v1.ListNodesResponse.nodes:type_name -> headscale.v1.Node
1, // 14: headscale.v1.MoveNodeResponse.node:type_name -> headscale.v1.Node
1, // 15: headscale.v1.DebugCreateNodeResponse.node:type_name -> headscale.v1.Node
16, // [16:16] is the sub-list for method output_type
16, // [16:16] is the sub-list for method input_type
16, // [16:16] is the sub-list for extension type_name
16, // [16:16] is the sub-list for extension extendee
0, // [0:16] is the sub-list for field type_name
}
func init() { file_headscale_v1_node_proto_init() }

View File

@@ -1,6 +1,6 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.36.8
// protoc-gen-go v1.36.10
// protoc (unknown)
// source: headscale/v1/policy.proto

View File

@@ -1,6 +1,6 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.36.8
// protoc-gen-go v1.36.10
// protoc (unknown)
// source: headscale/v1/preauthkey.proto

View File

@@ -1,6 +1,6 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.36.8
// protoc-gen-go v1.36.10
// protoc (unknown)
// source: headscale/v1/user.proto

View File

@@ -406,6 +406,13 @@
"required": true,
"type": "string",
"format": "uint64"
},
{
"name": "expiry",
"in": "query",
"required": false,
"type": "string",
"format": "date-time"
}
],
"tags": [

go.mod
View File

@@ -36,7 +36,7 @@ require (
github.com/spf13/viper v1.21.0
github.com/stretchr/testify v1.11.1
github.com/tailscale/hujson v0.0.0-20250226034555-ec1d1c113d33
github.com/tailscale/squibble v0.0.0-20250108170732-a4ca58afa694
github.com/tailscale/squibble v0.0.0-20251030164342-4d5df9caa993
github.com/tailscale/tailsql v0.0.0-20250421235516-02f85f087b97
github.com/tcnksm/go-latest v0.0.0-20170313132115-e3007ae9052e
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba
@@ -115,7 +115,7 @@ require (
github.com/containerd/errdefs v0.3.0 // indirect
github.com/containerd/errdefs/pkg v0.3.0 // indirect
github.com/coreos/go-iptables v0.7.1-0.20240112124308-65c67c9f46e6 // indirect
github.com/creachadair/mds v0.25.2 // indirect
github.com/creachadair/mds v0.25.10 // indirect
github.com/dblohm7/wingoes v0.0.0-20240123200102-b75a8a7d7eb0 // indirect
github.com/digitalocean/go-smbios v0.0.0-20180907143718-390a4f403a8e // indirect
github.com/distribution/reference v0.6.0 // indirect
@@ -159,7 +159,7 @@ require (
github.com/jinzhu/now v1.1.5 // indirect
github.com/jmespath/go-jmespath v0.4.0 // indirect
github.com/jsimonetti/rtnetlink v1.4.1 // indirect
github.com/klauspost/compress v1.18.0 // indirect
github.com/klauspost/compress v1.18.1 // indirect
github.com/kr/pretty v0.3.1 // indirect
github.com/kr/text v0.2.0 // indirect
github.com/lib/pq v1.10.9 // indirect

go.sum
View File

@@ -126,6 +126,8 @@ github.com/creachadair/flax v0.0.5 h1:zt+CRuXQASxwQ68e9GHAOnEgAU29nF0zYMHOCrL5wz
github.com/creachadair/flax v0.0.5/go.mod h1:F1PML0JZLXSNDMNiRGK2yjm5f+L9QCHchyHBldFymj8=
github.com/creachadair/mds v0.25.2 h1:xc0S0AfDq5GX9KUR5sLvi5XjA61/P6S5e0xFs1vA18Q=
github.com/creachadair/mds v0.25.2/go.mod h1:+s4CFteFRj4eq2KcGHW8Wei3u9NyzSPzNV32EvjyK/Q=
github.com/creachadair/mds v0.25.10 h1:9k9JB35D1xhOCFl0liBhagBBp8fWWkKZrA7UXsfoHtA=
github.com/creachadair/mds v0.25.10/go.mod h1:4hatI3hRM+qhzuAmqPRFvaBM8mONkS7nsLxkcuTYUIs=
github.com/creachadair/taskgroup v0.13.2 h1:3KyqakBuFsm3KkXi/9XIb0QcA8tEzLHLgaoidf0MdVc=
github.com/creachadair/taskgroup v0.13.2/go.mod h1:i3V1Zx7H8RjwljUEeUWYT30Lmb9poewSb2XI1yTwD0g=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
@@ -278,6 +280,8 @@ github.com/jsimonetti/rtnetlink v1.4.1 h1:JfD4jthWBqZMEffc5RjgmlzpYttAVw1sdnmiNa
github.com/jsimonetti/rtnetlink v1.4.1/go.mod h1:xJjT7t59UIZ62GLZbv6PLLo8VFrostJMPBAheR6OM8w=
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
github.com/klauspost/compress v1.18.1 h1:bcSGx7UbpBqMChDtsF28Lw6v/G94LPrrbMbdC3JH2co=
github.com/klauspost/compress v1.18.1/go.mod h1:ZQFFVG+MdnR0P+l6wpXgIL4NTtwiKIdBnrBd8Nrxr+0=
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/klauspost/cpuid/v2 v2.0.10/go.mod h1:g2LTdtYhdyuGPqyWyv7qRAmj1WBqxuObKfj5c0PQa7c=
github.com/klauspost/cpuid/v2 v2.0.12/go.mod h1:g2LTdtYhdyuGPqyWyv7qRAmj1WBqxuObKfj5c0PQa7c=
@@ -461,6 +465,8 @@ github.com/tailscale/setec v0.0.0-20250305161714-445cadbbca3d h1:mnqtPWYyvNiPU9l
github.com/tailscale/setec v0.0.0-20250305161714-445cadbbca3d/go.mod h1:9BzmlFc3OLqLzLTF/5AY+BMs+clxMqyhSGzgXIm8mNI=
github.com/tailscale/squibble v0.0.0-20250108170732-a4ca58afa694 h1:95eIP97c88cqAFU/8nURjgI9xxPbD+Ci6mY/a79BI/w=
github.com/tailscale/squibble v0.0.0-20250108170732-a4ca58afa694/go.mod h1:veguaG8tVg1H/JG5RfpoUW41I+O8ClPElo/fTYr8mMk=
github.com/tailscale/squibble v0.0.0-20251030164342-4d5df9caa993 h1:FyiiAvDAxpB0DrW2GW3KOVfi3YFOtsQUEeFWbf55JJU=
github.com/tailscale/squibble v0.0.0-20251030164342-4d5df9caa993/go.mod h1:xJkMmR3t+thnUQhA3Q4m2VSlS5pcOq+CIjmU/xfKKx4=
github.com/tailscale/tailsql v0.0.0-20250421235516-02f85f087b97 h1:JJkDnrAhHvOCttk8z9xeZzcDlzzkRA7+Duxj9cwOyxk=
github.com/tailscale/tailsql v0.0.0-20250421235516-02f85f087b97/go.mod h1:9jS8HxwsP2fU4ESZ7DZL+fpH/U66EVlVMzdgznH12RM=
github.com/tailscale/web-client-prebuilt v0.0.0-20250124233751-d4cd19a26976 h1:UBPHPtv8+nEAy2PD8RyAhOYvau1ek0HDJqLS/Pysi14=

View File

@@ -380,53 +380,45 @@ func (h *Headscale) httpAuthenticationMiddleware(next http.Handler) http.Handler
writer http.ResponseWriter,
req *http.Request,
) {
if err := func() error {
log.Trace().
Caller().
Str("client_address", req.RemoteAddr).
Msg("HTTP authentication invoked")
log.Trace().
Caller().
Str("client_address", req.RemoteAddr).
Msg("HTTP authentication invoked")
authHeader := req.Header.Get("Authorization")
authHeader := req.Header.Get("Authorization")
if !strings.HasPrefix(authHeader, AuthPrefix) {
log.Error().
Caller().
Str("client_address", req.RemoteAddr).
Msg(`missing "Bearer " prefix in "Authorization" header`)
writer.WriteHeader(http.StatusUnauthorized)
_, err := writer.Write([]byte("Unauthorized"))
return err
writeUnauthorized := func(statusCode int) {
writer.WriteHeader(statusCode)
if _, err := writer.Write([]byte("Unauthorized")); err != nil {
log.Error().Err(err).Msg("writing HTTP response failed")
}
}
valid, err := h.state.ValidateAPIKey(strings.TrimPrefix(authHeader, AuthPrefix))
if err != nil {
log.Error().
Caller().
Err(err).
Str("client_address", req.RemoteAddr).
Msg("failed to validate token")
writer.WriteHeader(http.StatusInternalServerError)
_, err := writer.Write([]byte("Unauthorized"))
return err
}
if !valid {
log.Info().
Str("client_address", req.RemoteAddr).
Msg("invalid token")
writer.WriteHeader(http.StatusUnauthorized)
_, err := writer.Write([]byte("Unauthorized"))
return err
}
return nil
}(); err != nil {
if !strings.HasPrefix(authHeader, AuthPrefix) {
log.Error().
Caller().
Str("client_address", req.RemoteAddr).
Msg(`missing "Bearer " prefix in "Authorization" header`)
writeUnauthorized(http.StatusUnauthorized)
return
}
valid, err := h.state.ValidateAPIKey(strings.TrimPrefix(authHeader, AuthPrefix))
if err != nil {
log.Info().
Caller().
Err(err).
Msg("Failed to write HTTP response")
Str("client_address", req.RemoteAddr).
Msg("failed to validate token")
writeUnauthorized(http.StatusUnauthorized)
return
}
if !valid {
log.Info().
Str("client_address", req.RemoteAddr).
Msg("invalid token")
writeUnauthorized(http.StatusUnauthorized)
return
}
@@ -454,6 +446,7 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {
router.HandleFunc("/robots.txt", h.RobotsHandler).Methods(http.MethodGet)
router.HandleFunc("/health", h.HealthHandler).Methods(http.MethodGet)
router.HandleFunc("/version", h.VersionHandler).Methods(http.MethodGet)
router.HandleFunc("/key", h.KeyHandler).Methods(http.MethodGet)
router.HandleFunc("/register/{registration_id}", h.authProvider.RegisterHandler).
Methods(http.MethodGet)

View File

@@ -1,6 +1,7 @@
package hscontrol
import (
"cmp"
"context"
"errors"
"fmt"
@@ -283,19 +284,23 @@ func (h *Headscale) reqToNewRegisterResponse(
return nil, NewHTTPError(http.StatusInternalServerError, "failed to generate registration ID", err)
}
// Ensure we have valid hostinfo and hostname
validHostinfo, hostname := util.EnsureValidHostinfo(
// Ensure we have a valid hostname
hostname := util.EnsureHostname(
req.Hostinfo,
machineKey.String(),
req.NodeKey.String(),
)
// Ensure we have valid hostinfo
hostinfo := cmp.Or(req.Hostinfo, &tailcfg.Hostinfo{})
hostinfo.Hostname = hostname
nodeToRegister := types.NewRegisterNode(
types.Node{
Hostname: hostname,
MachineKey: machineKey,
NodeKey: req.NodeKey,
Hostinfo: validHostinfo,
Hostinfo: hostinfo,
LastSeen: ptr.To(time.Now()),
},
)
@@ -396,13 +401,15 @@ func (h *Headscale) handleRegisterInteractive(
return nil, fmt.Errorf("generating registration ID: %w", err)
}
// Ensure we have valid hostinfo and hostname
validHostinfo, hostname := util.EnsureValidHostinfo(
// Ensure we have a valid hostname
hostname := util.EnsureHostname(
req.Hostinfo,
machineKey.String(),
req.NodeKey.String(),
)
// Ensure we have valid hostinfo
hostinfo := cmp.Or(req.Hostinfo, &tailcfg.Hostinfo{})
if req.Hostinfo == nil {
log.Warn().
Str("machine.key", machineKey.ShortString()).
@@ -416,13 +423,14 @@ func (h *Headscale) handleRegisterInteractive(
Str("generated.hostname", hostname).
Msg("Received registration request with empty hostname, generated default")
}
hostinfo.Hostname = hostname
nodeToRegister := types.NewRegisterNode(
types.Node{
Hostname: hostname,
MachineKey: machineKey,
NodeKey: req.NodeKey,
Hostinfo: validHostinfo,
Hostinfo: hostinfo,
LastSeen: ptr.To(time.Now()),
},
)
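The refactor leans on `cmp.Or` from the Go standard library (Go 1.22+), which returns the first non-zero of its arguments; a tiny standalone sketch of the fallback used above, with a stand-in `Hostinfo` type:
```go
package main

import (
	"cmp"
	"fmt"
)

// Hostinfo stands in for tailscale.com/tailcfg.Hostinfo here.
type Hostinfo struct{ Hostname string }

func main() {
	var fromReq *Hostinfo // nil when the register request carries no hostinfo

	// For pointer types the zero value is nil, so cmp.Or falls back
	// to a fresh empty Hostinfo.
	hostinfo := cmp.Or(fromReq, &Hostinfo{})
	hostinfo.Hostname = "generated-hostname"

	fmt.Println(hostinfo.Hostname) // safe: no nil dereference
}
```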

View File

@@ -1,6 +1,6 @@
package capver
//Generated DO NOT EDIT
// Generated DO NOT EDIT
import "tailscale.com/tailcfg"
@@ -37,16 +37,15 @@ var tailscaleToCapVer = map[string]tailcfg.CapabilityVersion{
"v1.84.2": 116,
}
var capVerToTailscaleVer = map[tailcfg.CapabilityVersion]string{
90: "v1.64.0",
95: "v1.66.0",
97: "v1.68.0",
102: "v1.70.0",
104: "v1.72.0",
106: "v1.74.0",
109: "v1.78.0",
113: "v1.80.0",
115: "v1.82.0",
116: "v1.84.0",
90: "v1.64.0",
95: "v1.66.0",
97: "v1.68.0",
102: "v1.70.0",
104: "v1.72.0",
106: "v1.74.0",
109: "v1.78.0",
113: "v1.80.0",
115: "v1.82.0",
116: "v1.84.0",
}

View File

@@ -932,6 +932,26 @@ AND auth_key_id NOT IN (
},
Rollback: func(db *gorm.DB) error { return nil },
},
{
// Drop all tables that are no longer in use but may still exist.
// They are potentially still present from broken migrations in the past.
ID: "202510311551",
Migrate: func(tx *gorm.DB) error {
for _, oldTable := range []string{"namespaces", "machines", "shared_machines", "kvs", "pre_auth_key_acl_tags", "routes"} {
err := tx.Migrator().DropTable(oldTable)
if err != nil {
log.Trace().Str("table", oldTable).
Err(err).
Msg("Error dropping old table, continuing...")
}
}
return nil
},
Rollback: func(tx *gorm.DB) error {
return nil
},
},
// From this point, the following rules must be followed:
// - NEVER use gorm.AutoMigrate, write the exact migration steps needed
// - AutoMigrate depends on the struct staying exactly the same, which it won't over time.
@@ -962,7 +982,17 @@ AND auth_key_id NOT IN (
ctx, cancel := context.WithTimeout(context.Background(), contextTimeoutSecs*time.Second)
defer cancel()
if err := squibble.Validate(ctx, sqlConn, dbSchema); err != nil {
opts := squibble.DigestOptions{
IgnoreTables: []string{
// Litestream tables, these are inserted by
// litestream and not part of our schema
// https://litestream.io/how-it-works
"_litestream_lock",
"_litestream_seq",
},
}
if err := squibble.Validate(ctx, sqlConn, dbSchema, &opts); err != nil {
return nil, fmt.Errorf("validating schema: %w", err)
}
}

View File

@@ -5,9 +5,11 @@ import (
"errors"
"fmt"
"net/netip"
"regexp"
"slices"
"sort"
"strconv"
"strings"
"sync"
"testing"
"time"
@@ -25,6 +27,8 @@ const (
NodeGivenNameTrimSize = 2
)
var invalidDNSRegex = regexp.MustCompile("[^a-z0-9-.]+")
var (
ErrNodeNotFound = errors.New("node not found")
ErrNodeRouteIsNotAvailable = errors.New("route is not available on node")
@@ -259,6 +263,10 @@ func SetLastSeen(tx *gorm.DB, nodeID types.NodeID, lastSeen time.Time) error {
func RenameNode(tx *gorm.DB,
nodeID types.NodeID, newName string,
) error {
if err := util.ValidateHostname(newName); err != nil {
return fmt.Errorf("renaming node: %w", err)
}
// Check if the new name is unique
var count int64
if err := tx.Model(&types.Node{}).Where("given_name = ? AND id != ?", newName, nodeID).Count(&count).Error; err != nil {
@@ -376,6 +384,14 @@ func RegisterNodeForTest(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *n
node.IPv4 = ipv4
node.IPv6 = ipv6
var err error
node.Hostname, err = util.NormaliseHostname(node.Hostname)
if err != nil {
newHostname := util.InvalidString()
log.Info().Err(err).Str("invalid-hostname", node.Hostname).Str("new-hostname", newHostname).Msgf("Invalid hostname, replacing")
node.Hostname = newHostname
}
if node.GivenName == "" {
givenName, err := EnsureUniqueGivenName(tx, node.Hostname)
if err != nil {
@@ -432,7 +448,10 @@ func NodeSave(tx *gorm.DB, node *types.Node) error {
}
func generateGivenName(suppliedName string, randomSuffix bool) (string, error) {
suppliedName = util.ConvertWithFQDNRules(suppliedName)
// Strip invalid DNS characters for givenName
suppliedName = strings.ToLower(suppliedName)
suppliedName = invalidDNSRegex.ReplaceAllString(suppliedName, "")
if len(suppliedName) > util.LabelHostnameLength {
return "", types.ErrHostnameTooLong
}

View File

@@ -640,7 +640,7 @@ func TestListEphemeralNodes(t *testing.T) {
assert.Equal(t, nodeEph.Hostname, ephemeralNodes[0].Hostname)
}
func TestRenameNode(t *testing.T) {
func TestNodeNaming(t *testing.T) {
db, err := newSQLiteTestDB()
if err != nil {
t.Fatalf("creating db: %s", err)
@@ -672,6 +672,26 @@ func TestRenameNode(t *testing.T) {
Hostinfo: &tailcfg.Hostinfo{},
}
// Using non-ASCII characters in the hostname can
// break your network, so they should be replaced when registering
// a node.
// https://github.com/juanfont/headscale/issues/2343
nodeInvalidHostname := types.Node{
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "我的电脑",
UserID: user2.ID,
RegisterMethod: util.RegisterMethodAuthKey,
}
nodeShortHostname := types.Node{
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "a",
UserID: user2.ID,
RegisterMethod: util.RegisterMethodAuthKey,
}
err = db.DB.Save(&node).Error
require.NoError(t, err)
@@ -684,7 +704,11 @@ func TestRenameNode(t *testing.T) {
return err
}
_, err = RegisterNodeForTest(tx, node2, nil, nil)
if err != nil {
return err
}
_, err = RegisterNodeForTest(tx, nodeInvalidHostname, ptr.To(mpp("100.64.0.66/32").Addr()), nil)
_, err = RegisterNodeForTest(tx, nodeShortHostname, ptr.To(mpp("100.64.0.67/32").Addr()), nil)
return err
})
require.NoError(t, err)
@@ -692,10 +716,12 @@ func TestRenameNode(t *testing.T) {
nodes, err := db.ListNodes()
require.NoError(t, err)
assert.Len(t, nodes, 2)
assert.Len(t, nodes, 4)
t.Logf("node1 %s %s", nodes[0].Hostname, nodes[0].GivenName)
t.Logf("node2 %s %s", nodes[1].Hostname, nodes[1].GivenName)
t.Logf("node3 %s %s", nodes[2].Hostname, nodes[2].GivenName)
t.Logf("node4 %s %s", nodes[3].Hostname, nodes[3].GivenName)
assert.Equal(t, nodes[0].Hostname, nodes[0].GivenName)
assert.NotEqual(t, nodes[1].Hostname, nodes[1].GivenName)
@@ -707,6 +733,10 @@ func TestRenameNode(t *testing.T) {
assert.Len(t, nodes[1].Hostname, 4)
assert.Len(t, nodes[0].GivenName, 4)
assert.Len(t, nodes[1].GivenName, 13)
assert.Contains(t, nodes[2].Hostname, "invalid-") // invalid chars
assert.Contains(t, nodes[2].GivenName, "invalid-")
assert.Contains(t, nodes[3].Hostname, "invalid-") // too short
assert.Contains(t, nodes[3].GivenName, "invalid-")
// Nodes can be renamed to a unique name
err = db.Write(func(tx *gorm.DB) error {
@@ -716,7 +746,7 @@ func TestRenameNode(t *testing.T) {
nodes, err = db.ListNodes()
require.NoError(t, err)
assert.Len(t, nodes, 2)
assert.Len(t, nodes, 4)
assert.Equal(t, "test", nodes[0].Hostname)
assert.Equal(t, "newname", nodes[0].GivenName)
@@ -728,7 +758,7 @@ func TestRenameNode(t *testing.T) {
nodes, err = db.ListNodes()
require.NoError(t, err)
assert.Len(t, nodes, 2)
assert.Len(t, nodes, 4)
assert.Equal(t, "test", nodes[0].Hostname)
assert.Equal(t, "newname", nodes[0].GivenName)
assert.Equal(t, "test", nodes[1].GivenName)
@@ -738,6 +768,149 @@ func TestRenameNode(t *testing.T) {
return RenameNode(tx, nodes[0].ID, "test")
})
assert.ErrorContains(t, err, "name is not unique")
// Rename invalid chars
err = db.Write(func(tx *gorm.DB) error {
return RenameNode(tx, nodes[2].ID, "我的电脑")
})
assert.ErrorContains(t, err, "invalid characters")
// Rename too short
err = db.Write(func(tx *gorm.DB) error {
return RenameNode(tx, nodes[3].ID, "a")
})
assert.ErrorContains(t, err, "at least 2 characters")
// Rename with emoji
err = db.Write(func(tx *gorm.DB) error {
return RenameNode(tx, nodes[0].ID, "hostname-with-💩")
})
assert.ErrorContains(t, err, "invalid characters")
// Rename with only emoji
err = db.Write(func(tx *gorm.DB) error {
return RenameNode(tx, nodes[0].ID, "🚀")
})
assert.ErrorContains(t, err, "invalid characters")
}
func TestRenameNodeComprehensive(t *testing.T) {
db, err := newSQLiteTestDB()
if err != nil {
t.Fatalf("creating db: %s", err)
}
user, err := db.CreateUser(types.User{Name: "test"})
require.NoError(t, err)
node := types.Node{
ID: 0,
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "testnode",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
Hostinfo: &tailcfg.Hostinfo{},
}
err = db.DB.Save(&node).Error
require.NoError(t, err)
err = db.DB.Transaction(func(tx *gorm.DB) error {
_, err := RegisterNodeForTest(tx, node, nil, nil)
return err
})
require.NoError(t, err)
nodes, err := db.ListNodes()
require.NoError(t, err)
assert.Len(t, nodes, 1)
tests := []struct {
name string
newName string
wantErr string
}{
{
name: "uppercase_rejected",
newName: "User2-Host",
wantErr: "must be lowercase",
},
{
name: "underscore_rejected",
newName: "test_node",
wantErr: "invalid characters",
},
{
name: "at_sign_uppercase_rejected",
newName: "Test@Host",
wantErr: "must be lowercase",
},
{
name: "at_sign_rejected",
newName: "test@host",
wantErr: "invalid characters",
},
{
name: "chinese_chars_with_dash_rejected",
newName: "server-北京-01",
wantErr: "invalid characters",
},
{
name: "chinese_only_rejected",
newName: "我的电脑",
wantErr: "invalid characters",
},
{
name: "emoji_with_text_rejected",
newName: "laptop-🚀",
wantErr: "invalid characters",
},
{
name: "mixed_chinese_emoji_rejected",
newName: "测试💻机器",
wantErr: "invalid characters",
},
{
name: "only_emojis_rejected",
newName: "🎉🎊",
wantErr: "invalid characters",
},
{
name: "only_at_signs_rejected",
newName: "@@@",
wantErr: "invalid characters",
},
{
name: "starts_with_dash_rejected",
newName: "-test",
wantErr: "cannot start or end with a hyphen",
},
{
name: "ends_with_dash_rejected",
newName: "test-",
wantErr: "cannot start or end with a hyphen",
},
{
name: "too_long_hostname_rejected",
newName: "this-is-a-very-long-hostname-that-exceeds-sixty-three-characters-limit",
wantErr: "must not exceed 63 characters",
},
{
name: "too_short_hostname_rejected",
newName: "a",
wantErr: "at least 2 characters",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := db.Write(func(tx *gorm.DB) error {
return RenameNode(tx, nodes[0].ID, tt.newName)
})
assert.ErrorContains(t, err, tt.wantErr)
})
}
}
func TestListPeers(t *testing.T) {

View File

@@ -0,0 +1,40 @@
PRAGMA foreign_keys=OFF;
BEGIN TRANSACTION;
CREATE TABLE `migrations` (`id` text,PRIMARY KEY (`id`));
INSERT INTO migrations VALUES('202312101416');
INSERT INTO migrations VALUES('202312101430');
INSERT INTO migrations VALUES('202402151347');
INSERT INTO migrations VALUES('2024041121742');
INSERT INTO migrations VALUES('202406021630');
INSERT INTO migrations VALUES('202409271400');
INSERT INTO migrations VALUES('202407191627');
INSERT INTO migrations VALUES('202408181235');
INSERT INTO migrations VALUES('202501221827');
INSERT INTO migrations VALUES('202501311657');
INSERT INTO migrations VALUES('202502070949');
INSERT INTO migrations VALUES('202502131714');
INSERT INTO migrations VALUES('202502171819');
INSERT INTO migrations VALUES('202505091439');
INSERT INTO migrations VALUES('202505141324');
CREATE TABLE `users` (`id` integer PRIMARY KEY AUTOINCREMENT,`created_at` datetime,`updated_at` datetime,`deleted_at` datetime,`name` text,`display_name` text,`email` text,`provider_identifier` text,`provider` text,`profile_pic_url` text);
CREATE TABLE `pre_auth_keys` (`id` integer PRIMARY KEY AUTOINCREMENT,`key` text,`user_id` integer,`reusable` numeric,`ephemeral` numeric DEFAULT false,`used` numeric DEFAULT false,`tags` text,`created_at` datetime,`expiration` datetime,CONSTRAINT `fk_pre_auth_keys_user` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`) ON DELETE SET NULL);
CREATE TABLE `api_keys` (`id` integer PRIMARY KEY AUTOINCREMENT,`prefix` text,`hash` blob,`created_at` datetime,`expiration` datetime,`last_seen` datetime);
CREATE TABLE IF NOT EXISTS "nodes" (`id` integer PRIMARY KEY AUTOINCREMENT,`machine_key` text,`node_key` text,`disco_key` text,`endpoints` text,`host_info` text,`ipv4` text,`ipv6` text,`hostname` text,`given_name` varchar(63),`user_id` integer,`register_method` text,`forced_tags` text,`auth_key_id` integer,`expiry` datetime,`last_seen` datetime,`approved_routes` text,`created_at` datetime,`updated_at` datetime,`deleted_at` datetime,CONSTRAINT `fk_nodes_user` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`) ON DELETE CASCADE,CONSTRAINT `fk_nodes_auth_key` FOREIGN KEY (`auth_key_id`) REFERENCES `pre_auth_keys`(`id`));
CREATE TABLE `policies` (`id` integer PRIMARY KEY AUTOINCREMENT,`created_at` datetime,`updated_at` datetime,`deleted_at` datetime,`data` text);
DELETE FROM sqlite_sequence;
INSERT INTO sqlite_sequence VALUES('nodes',0);
CREATE INDEX `idx_users_deleted_at` ON `users`(`deleted_at`);
CREATE UNIQUE INDEX `idx_api_keys_prefix` ON `api_keys`(`prefix`);
CREATE INDEX `idx_policies_deleted_at` ON `policies`(`deleted_at`);
CREATE UNIQUE INDEX idx_provider_identifier ON users (provider_identifier) WHERE provider_identifier IS NOT NULL;
CREATE UNIQUE INDEX idx_name_provider_identifier ON users (name,provider_identifier);
CREATE UNIQUE INDEX idx_name_no_provider_identifier ON users (name) WHERE provider_identifier IS NULL;
-- Create all the old tables we have ever had and ensure they are cleaned up.
CREATE TABLE `namespaces` (`id` text,PRIMARY KEY (`id`));
CREATE TABLE `machines` (`id` text,PRIMARY KEY (`id`));
CREATE TABLE `kvs` (`id` text,PRIMARY KEY (`id`));
CREATE TABLE `shared_machines` (`id` text,PRIMARY KEY (`id`));
CREATE TABLE `pre_auth_key_acl_tags` (`id` text,PRIMARY KEY (`id`));
CREATE TABLE `routes` (`id` text,PRIMARY KEY (`id`));
COMMIT;

View File

@@ -0,0 +1,14 @@
CREATE TABLE `migrations` (`id` text,PRIMARY KEY (`id`));
CREATE TABLE `users` (`id` integer PRIMARY KEY AUTOINCREMENT,`created_at` datetime,`updated_at` datetime,`deleted_at` datetime,`name` text,`display_name` text,`email` text,`provider_identifier` text,`provider` text,`profile_pic_url` text);
CREATE INDEX `idx_users_deleted_at` ON `users`(`deleted_at`);
CREATE TABLE `pre_auth_keys` (`id` integer PRIMARY KEY AUTOINCREMENT,`key` text,`user_id` integer,`reusable` numeric,`ephemeral` numeric DEFAULT false,`used` numeric DEFAULT false,`tags` text,`created_at` datetime,`expiration` datetime,CONSTRAINT `fk_pre_auth_keys_user` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`) ON DELETE SET NULL);
CREATE TABLE `api_keys` (`id` integer PRIMARY KEY AUTOINCREMENT,`prefix` text,`hash` blob,`created_at` datetime,`expiration` datetime,`last_seen` datetime);
CREATE UNIQUE INDEX `idx_api_keys_prefix` ON `api_keys`(`prefix`);
CREATE TABLE IF NOT EXISTS "nodes" (`id` integer PRIMARY KEY AUTOINCREMENT,`machine_key` text,`node_key` text,`disco_key` text,`endpoints` text,`host_info` text,`ipv4` text,`ipv6` text,`hostname` text,`given_name` varchar(63),`user_id` integer,`register_method` text,`forced_tags` text,`auth_key_id` integer,`expiry` datetime,`last_seen` datetime,`approved_routes` text,`created_at` datetime,`updated_at` datetime,`deleted_at` datetime,CONSTRAINT `fk_nodes_user` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`) ON DELETE CASCADE,CONSTRAINT `fk_nodes_auth_key` FOREIGN KEY (`auth_key_id`) REFERENCES `pre_auth_keys`(`id`));
CREATE TABLE `policies` (`id` integer PRIMARY KEY AUTOINCREMENT,`created_at` datetime,`updated_at` datetime,`deleted_at` datetime,`data` text);
CREATE INDEX `idx_policies_deleted_at` ON `policies`(`deleted_at`);
CREATE UNIQUE INDEX idx_provider_identifier ON users (provider_identifier) WHERE provider_identifier IS NOT NULL;
CREATE UNIQUE INDEX idx_name_provider_identifier ON users (name,provider_identifier);
CREATE UNIQUE INDEX idx_name_no_provider_identifier ON users (name) WHERE provider_identifier IS NULL;
CREATE TABLE _litestream_seq (id INTEGER PRIMARY KEY, seq INTEGER);
CREATE TABLE _litestream_lock (id INTEGER);
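
The _litestream_seq and _litestream_lock tables at the end of this fixture are Litestream's replication bookkeeping, not application schema. A sketch of enumerating user tables while ignoring them (the driver choice and function shape are assumptions, not headscale's actual validation code):

package main

import (
	"database/sql"
	"fmt"
	"strings"

	_ "modernc.org/sqlite" // assumed driver; headscale's actual choice may differ
)

// userTables lists application tables, skipping Litestream's bookkeeping
// tables and SQLite's own internals.
func userTables(db *sql.DB) ([]string, error) {
	rows, err := db.Query(`SELECT name FROM sqlite_master WHERE type = 'table' ORDER BY name`)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var tables []string
	for rows.Next() {
		var name string
		if err := rows.Scan(&name); err != nil {
			return nil, err
		}
		if strings.HasPrefix(name, "_litestream") || strings.HasPrefix(name, "sqlite_") {
			continue
		}
		tables = append(tables, name)
	}
	return tables, rows.Err()
}

func main() {
	db, err := sql.Open("sqlite", "file:headscale.db?mode=ro")
	if err != nil {
		panic(err)
	}
	defer db.Close()

	names, err := userTables(db)
	if err != nil {
		panic(err)
	}
	fmt.Println(names)
}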

View File

@@ -26,8 +26,7 @@ func (hsdb *HSDatabase) CreateUser(user types.User) (*types.User, error) {
// CreateUser creates a new User. Returns an error if the user could not be
// created or another user with that name already exists.
func CreateUser(tx *gorm.DB, user types.User) (*types.User, error) {
- err := util.ValidateUsername(user.Name)
- if err != nil {
if err := util.ValidateHostname(user.Name); err != nil {
return nil, err
}
if err := tx.Create(&user).Error; err != nil {
@@ -93,8 +92,7 @@ func RenameUser(tx *gorm.DB, uid types.UserID, newName string) error {
if err != nil {
return err
}
- err = util.ValidateUsername(newName)
- if err != nil {
if err = util.ValidateHostname(newName); err != nil {
return err
}

View File

@@ -185,7 +185,6 @@ func TestShuffleDERPMapDeterministic(t *testing.T) {
}
})
}
}
func TestShuffleDERPMapEdgeCases(t *testing.T) {

View File

@@ -416,9 +416,12 @@ func (api headscaleV1APIServer) ExpireNode(
ctx context.Context,
request *v1.ExpireNodeRequest,
) (*v1.ExpireNodeResponse, error) {
- now := time.Now()
expiry := time.Now()
if request.GetExpiry() != nil {
expiry = request.GetExpiry().AsTime()
}
- node, nodeChange, err := api.h.state.SetNodeExpiry(types.NodeID(request.GetNodeId()), now)
node, nodeChange, err := api.h.state.SetNodeExpiry(types.NodeID(request.GetNodeId()), expiry)
if err != nil {
return nil, err
}
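
Client-side, a custom expiry is supplied as a protobuf timestamp; omitting it keeps the default-to-now behaviour shown above. A sketch of building such a request (the NodeId and Expiry field names are inferred from the GetNodeId/GetExpiry getters, so treat them as assumptions):

package main

import (
	"time"

	v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
	"google.golang.org/protobuf/types/known/timestamppb"
)

// expireAt builds an ExpireNodeRequest carrying a custom expiry timestamp.
func expireAt(nodeID uint64, at time.Time) *v1.ExpireNodeRequest {
	return &v1.ExpireNodeRequest{
		NodeId: nodeID,
		Expiry: timestamppb.New(at),
	}
}

func main() {
	// Expire node 1 as of one minute ago; pass req to the gRPC ExpireNode call.
	req := expireAt(1, time.Now().Add(-time.Minute))
	_ = req
}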

View File

@@ -201,6 +201,24 @@ func (h *Headscale) RobotsHandler(
}
}
// VersionHandler returns version information about the Headscale server.
// It listens on /version.
func (h *Headscale) VersionHandler(
writer http.ResponseWriter,
req *http.Request,
) {
writer.Header().Set("Content-Type", "application/json")
writer.WriteHeader(http.StatusOK)
versionInfo := types.GetVersionInfo()
if err := json.NewEncoder(writer).Encode(versionInfo); err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write version response")
}
}
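
For a quick check of a JSON endpoint like this, net/http/httptest is enough. A minimal, self-contained sketch (the inline handler is a stand-in for VersionHandler, whose server wiring is not shown here):

package main

import (
	"encoding/json"
	"fmt"
	"net/http"
	"net/http/httptest"
)

func main() {
	// Stand-in for VersionHandler: write version info as JSON.
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusOK)
		json.NewEncoder(w).Encode(map[string]string{"version": "dev"})
	})

	req := httptest.NewRequest(http.MethodGet, "/version", nil)
	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, req)

	var body map[string]string
	if err := json.NewDecoder(rec.Body).Decode(&body); err != nil {
		panic(err)
	}
	fmt.Println(rec.Code, body["version"]) // 200 dev
}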
var codeStyleRegisterWebAPI = styles.Props{
styles.Display: "block",
styles.Padding: "20px",

View File

@@ -73,7 +73,6 @@ func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse
// Use the worker pool for controlled concurrency instead of direct generation
initialMap, err := b.MapResponseFromChange(id, change.FullSelf(id))
if err != nil {
log.Error().Uint64("node.id", id.Uint64()).Err(err).Msg("Initial map generation failed")
nodeConn.removeConnectionByChannel(c)
@@ -602,7 +601,7 @@ func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error {
mc.updateCount.Add(1)
- log.Info().Uint64("node.id", mc.id.Uint64()).
log.Debug().Uint64("node.id", mc.id.Uint64()).
Int("successful_sends", successCount).
Int("failed_connections", len(failedConnections)).
Int("remaining_connections", len(mc.connections)).

View File

@@ -7,7 +7,6 @@ import (
"time"
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/policy/matcher"
"github.com/juanfont/headscale/hscontrol/types"
"tailscale.com/tailcfg"
"tailscale.com/types/views"
@@ -181,6 +180,9 @@ func (b *MapResponseBuilder) WithPacketFilters() *MapResponseBuilder {
return b
}
// FilterForNode returns rules already reduced to only those relevant for this node.
// For autogroup:self policies, it returns per-node compiled rules.
// For global policies, it returns the global filter reduced for this node.
filter, err := b.mapper.state.FilterForNode(node)
if err != nil {
b.addError(err)
@@ -192,7 +194,7 @@ func (b *MapResponseBuilder) WithPacketFilters() *MapResponseBuilder {
// new PacketFilters field and "base" allows us to send a full update when we
// have to send an empty list, avoiding the hack in the else block.
b.resp.PacketFilters = map[string][]tailcfg.FilterRule{
"base": policy.ReduceFilterRules(node, filter),
"base": filter,
}
return b
@@ -231,18 +233,19 @@ func (b *MapResponseBuilder) buildTailPeers(peers views.Slice[types.NodeView]) (
return nil, errors.New("node not found")
}
- // Use per-node filter to handle autogroup:self
- filter, err := b.mapper.state.FilterForNode(node)
// Get unreduced matchers for peer relationship determination.
// MatchersForNode returns unreduced matchers that include all rules where the node
// could be either source or destination. This is different from FilterForNode which
// returns reduced rules for packet filtering (only rules where node is destination).
matchers, err := b.mapper.state.MatchersForNode(node)
if err != nil {
return nil, err
}
- matchers := matcher.MatchesFromFilterRules(filter)
// If there are filter rules present, see if there are any nodes that cannot
// access each-other at all and remove them from the peers.
var changedViews views.Slice[types.NodeView]
- if len(filter) > 0 {
if len(matchers) > 0 {
changedViews = policy.ReduceNodes(node, peers, matchers)
} else {
changedViews = peers

View File

@@ -15,6 +15,10 @@ type PolicyManager interface {
Filter() ([]tailcfg.FilterRule, []matcher.Match)
// FilterForNode returns filter rules for a specific node, handling autogroup:self
FilterForNode(node types.NodeView) ([]tailcfg.FilterRule, error)
// MatchersForNode returns matchers for peer relationship determination (unreduced)
MatchersForNode(node types.NodeView) ([]matcher.Match, error)
// BuildPeerMap constructs peer relationship maps for the given nodes
BuildPeerMap(nodes views.Slice[types.NodeView]) map[types.NodeID][]types.NodeView
SSHPolicy(types.NodeView) (*tailcfg.SSHPolicy, error)
SetPolicy([]byte) (bool, error)
SetUsers(users []types.User) (bool, error)

View File

@@ -10,7 +10,6 @@ import (
"github.com/rs/zerolog/log"
"github.com/samber/lo"
"tailscale.com/net/tsaddr"
"tailscale.com/tailcfg"
"tailscale.com/types/views"
)
@@ -79,66 +78,6 @@ func BuildPeerMap(
return ret
}
// ReduceFilterRules takes a node and a set of rules and removes all rules and destinations
// that are not relevant to that particular node.
func ReduceFilterRules(node types.NodeView, rules []tailcfg.FilterRule) []tailcfg.FilterRule {
ret := []tailcfg.FilterRule{}
for _, rule := range rules {
// record if the rule is actually relevant for the given node.
var dests []tailcfg.NetPortRange
DEST_LOOP:
for _, dest := range rule.DstPorts {
expanded, err := util.ParseIPSet(dest.IP, nil)
// Fail closed, if we can't parse it, then we should not allow
// access.
if err != nil {
continue DEST_LOOP
}
if node.InIPSet(expanded) {
dests = append(dests, dest)
continue DEST_LOOP
}
// If the node exposes routes, ensure they are not removed
// when the filters are reduced.
if node.Hostinfo().Valid() {
routableIPs := node.Hostinfo().RoutableIPs()
if routableIPs.Len() > 0 {
for _, routableIP := range routableIPs.All() {
if expanded.OverlapsPrefix(routableIP) {
dests = append(dests, dest)
continue DEST_LOOP
}
}
}
}
// Also check approved subnet routes - nodes should have access
// to subnets they're approved to route traffic for.
subnetRoutes := node.SubnetRoutes()
for _, subnetRoute := range subnetRoutes {
if expanded.OverlapsPrefix(subnetRoute) {
dests = append(dests, dest)
continue DEST_LOOP
}
}
}
if len(dests) > 0 {
ret = append(ret, tailcfg.FilterRule{
SrcIPs: rule.SrcIPs,
DstPorts: dests,
IPProto: rule.IPProto,
})
}
}
return ret
}
// ApproveRoutesWithPolicy checks if the node can approve the announced routes
// and returns the new list of approved routes.
// The approved routes will include:

View File

@@ -1,7 +1,6 @@
package policy
import (
"encoding/json"
"fmt"
"net/netip"
"testing"
@@ -11,12 +10,9 @@ import (
"github.com/juanfont/headscale/hscontrol/policy/matcher"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
"github.com/stretchr/testify/require"
"gorm.io/gorm"
"tailscale.com/net/tsaddr"
"tailscale.com/tailcfg"
"tailscale.com/util/must"
)
var ap = func(ipStr string) *netip.Addr {
@@ -29,817 +25,6 @@ var p = func(prefStr string) netip.Prefix {
return ip
}
// hsExitNodeDestForTest is the list of destination IP ranges that are allowed when
// we use headscale "autogroup:internet".
var hsExitNodeDestForTest = []tailcfg.NetPortRange{
{IP: "0.0.0.0/5", Ports: tailcfg.PortRangeAny},
{IP: "8.0.0.0/7", Ports: tailcfg.PortRangeAny},
{IP: "11.0.0.0/8", Ports: tailcfg.PortRangeAny},
{IP: "12.0.0.0/6", Ports: tailcfg.PortRangeAny},
{IP: "16.0.0.0/4", Ports: tailcfg.PortRangeAny},
{IP: "32.0.0.0/3", Ports: tailcfg.PortRangeAny},
{IP: "64.0.0.0/3", Ports: tailcfg.PortRangeAny},
{IP: "96.0.0.0/6", Ports: tailcfg.PortRangeAny},
{IP: "100.0.0.0/10", Ports: tailcfg.PortRangeAny},
{IP: "100.128.0.0/9", Ports: tailcfg.PortRangeAny},
{IP: "101.0.0.0/8", Ports: tailcfg.PortRangeAny},
{IP: "102.0.0.0/7", Ports: tailcfg.PortRangeAny},
{IP: "104.0.0.0/5", Ports: tailcfg.PortRangeAny},
{IP: "112.0.0.0/4", Ports: tailcfg.PortRangeAny},
{IP: "128.0.0.0/3", Ports: tailcfg.PortRangeAny},
{IP: "160.0.0.0/5", Ports: tailcfg.PortRangeAny},
{IP: "168.0.0.0/8", Ports: tailcfg.PortRangeAny},
{IP: "169.0.0.0/9", Ports: tailcfg.PortRangeAny},
{IP: "169.128.0.0/10", Ports: tailcfg.PortRangeAny},
{IP: "169.192.0.0/11", Ports: tailcfg.PortRangeAny},
{IP: "169.224.0.0/12", Ports: tailcfg.PortRangeAny},
{IP: "169.240.0.0/13", Ports: tailcfg.PortRangeAny},
{IP: "169.248.0.0/14", Ports: tailcfg.PortRangeAny},
{IP: "169.252.0.0/15", Ports: tailcfg.PortRangeAny},
{IP: "169.255.0.0/16", Ports: tailcfg.PortRangeAny},
{IP: "170.0.0.0/7", Ports: tailcfg.PortRangeAny},
{IP: "172.0.0.0/12", Ports: tailcfg.PortRangeAny},
{IP: "172.32.0.0/11", Ports: tailcfg.PortRangeAny},
{IP: "172.64.0.0/10", Ports: tailcfg.PortRangeAny},
{IP: "172.128.0.0/9", Ports: tailcfg.PortRangeAny},
{IP: "173.0.0.0/8", Ports: tailcfg.PortRangeAny},
{IP: "174.0.0.0/7", Ports: tailcfg.PortRangeAny},
{IP: "176.0.0.0/4", Ports: tailcfg.PortRangeAny},
{IP: "192.0.0.0/9", Ports: tailcfg.PortRangeAny},
{IP: "192.128.0.0/11", Ports: tailcfg.PortRangeAny},
{IP: "192.160.0.0/13", Ports: tailcfg.PortRangeAny},
{IP: "192.169.0.0/16", Ports: tailcfg.PortRangeAny},
{IP: "192.170.0.0/15", Ports: tailcfg.PortRangeAny},
{IP: "192.172.0.0/14", Ports: tailcfg.PortRangeAny},
{IP: "192.176.0.0/12", Ports: tailcfg.PortRangeAny},
{IP: "192.192.0.0/10", Ports: tailcfg.PortRangeAny},
{IP: "193.0.0.0/8", Ports: tailcfg.PortRangeAny},
{IP: "194.0.0.0/7", Ports: tailcfg.PortRangeAny},
{IP: "196.0.0.0/6", Ports: tailcfg.PortRangeAny},
{IP: "200.0.0.0/5", Ports: tailcfg.PortRangeAny},
{IP: "208.0.0.0/4", Ports: tailcfg.PortRangeAny},
{IP: "224.0.0.0/3", Ports: tailcfg.PortRangeAny},
{IP: "2000::/3", Ports: tailcfg.PortRangeAny},
}
func TestTheInternet(t *testing.T) {
internetSet := util.TheInternet()
internetPrefs := internetSet.Prefixes()
for i := range internetPrefs {
if internetPrefs[i].String() != hsExitNodeDestForTest[i].IP {
t.Errorf(
"prefix from internet set %q != hsExit list %q",
internetPrefs[i].String(),
hsExitNodeDestForTest[i].IP,
)
}
}
if len(internetPrefs) != len(hsExitNodeDestForTest) {
t.Fatalf(
"expected same length of prefixes, internet: %d, hsExit: %d",
len(internetPrefs),
len(hsExitNodeDestForTest),
)
}
}
func TestReduceFilterRules(t *testing.T) {
users := types.Users{
types.User{Model: gorm.Model{ID: 1}, Name: "mickael"},
types.User{Model: gorm.Model{ID: 2}, Name: "user1"},
types.User{Model: gorm.Model{ID: 3}, Name: "user2"},
types.User{Model: gorm.Model{ID: 4}, Name: "user100"},
types.User{Model: gorm.Model{ID: 5}, Name: "user3"},
}
tests := []struct {
name string
node *types.Node
peers types.Nodes
pol string
want []tailcfg.FilterRule
}{
{
name: "host1-can-reach-host2-no-rules",
pol: `
{
"acls": [
{
"action": "accept",
"proto": "",
"src": [
"100.64.0.1"
],
"dst": [
"100.64.0.2:*"
]
}
],
}
`,
node: &types.Node{
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"),
User: users[0],
},
peers: types.Nodes{
&types.Node{
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0:ab12:4843:2222:6273:2222"),
User: users[0],
},
},
want: []tailcfg.FilterRule{},
},
{
name: "1604-subnet-routers-are-preserved",
pol: `
{
"groups": {
"group:admins": [
"user1@"
]
},
"acls": [
{
"action": "accept",
"proto": "",
"src": [
"group:admins"
],
"dst": [
"group:admins:*"
]
},
{
"action": "accept",
"proto": "",
"src": [
"group:admins"
],
"dst": [
"10.33.0.0/16:*"
]
}
],
}
`,
node: &types.Node{
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: users[1],
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{
netip.MustParsePrefix("10.33.0.0/16"),
},
},
},
peers: types.Nodes{
&types.Node{
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: users[1],
},
},
want: []tailcfg.FilterRule{
{
SrcIPs: []string{
"100.64.0.1/32",
"100.64.0.2/32",
"fd7a:115c:a1e0::1/128",
"fd7a:115c:a1e0::2/128",
},
DstPorts: []tailcfg.NetPortRange{
{
IP: "100.64.0.1/32",
Ports: tailcfg.PortRangeAny,
},
{
IP: "fd7a:115c:a1e0::1/128",
Ports: tailcfg.PortRangeAny,
},
},
IPProto: []int{6, 17},
},
{
SrcIPs: []string{
"100.64.0.1/32",
"100.64.0.2/32",
"fd7a:115c:a1e0::1/128",
"fd7a:115c:a1e0::2/128",
},
DstPorts: []tailcfg.NetPortRange{
{
IP: "10.33.0.0/16",
Ports: tailcfg.PortRangeAny,
},
},
IPProto: []int{6, 17},
},
},
},
{
name: "1786-reducing-breaks-exit-nodes-the-client",
pol: `
{
"groups": {
"group:team": [
"user3@",
"user2@",
"user1@"
]
},
"hosts": {
"internal": "100.64.0.100/32"
},
"acls": [
{
"action": "accept",
"proto": "",
"src": [
"group:team"
],
"dst": [
"internal:*"
]
},
{
"action": "accept",
"proto": "",
"src": [
"group:team"
],
"dst": [
"autogroup:internet:*"
]
}
],
}
`,
node: &types.Node{
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: users[1],
},
peers: types.Nodes{
&types.Node{
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: users[2],
},
// "internal" exit node
&types.Node{
IPv4: ap("100.64.0.100"),
IPv6: ap("fd7a:115c:a1e0::100"),
User: users[3],
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: tsaddr.ExitRoutes(),
},
},
},
want: []tailcfg.FilterRule{},
},
{
name: "1786-reducing-breaks-exit-nodes-the-exit",
pol: `
{
"groups": {
"group:team": [
"user3@",
"user2@",
"user1@"
]
},
"hosts": {
"internal": "100.64.0.100/32"
},
"acls": [
{
"action": "accept",
"proto": "",
"src": [
"group:team"
],
"dst": [
"internal:*"
]
},
{
"action": "accept",
"proto": "",
"src": [
"group:team"
],
"dst": [
"autogroup:internet:*"
]
}
],
}
`,
node: &types.Node{
IPv4: ap("100.64.0.100"),
IPv6: ap("fd7a:115c:a1e0::100"),
User: users[3],
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: tsaddr.ExitRoutes(),
},
},
peers: types.Nodes{
&types.Node{
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: users[2],
},
&types.Node{
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: users[1],
},
},
want: []tailcfg.FilterRule{
{
SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"},
DstPorts: []tailcfg.NetPortRange{
{
IP: "100.64.0.100/32",
Ports: tailcfg.PortRangeAny,
},
{
IP: "fd7a:115c:a1e0::100/128",
Ports: tailcfg.PortRangeAny,
},
},
IPProto: []int{6, 17},
},
{
SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"},
DstPorts: hsExitNodeDestForTest,
IPProto: []int{6, 17},
},
},
},
{
name: "1786-reducing-breaks-exit-nodes-the-example-from-issue",
pol: `
{
"groups": {
"group:team": [
"user3@",
"user2@",
"user1@"
]
},
"hosts": {
"internal": "100.64.0.100/32"
},
"acls": [
{
"action": "accept",
"proto": "",
"src": [
"group:team"
],
"dst": [
"internal:*"
]
},
{
"action": "accept",
"proto": "",
"src": [
"group:team"
],
"dst": [
"0.0.0.0/5:*",
"8.0.0.0/7:*",
"11.0.0.0/8:*",
"12.0.0.0/6:*",
"16.0.0.0/4:*",
"32.0.0.0/3:*",
"64.0.0.0/2:*",
"128.0.0.0/3:*",
"160.0.0.0/5:*",
"168.0.0.0/6:*",
"172.0.0.0/12:*",
"172.32.0.0/11:*",
"172.64.0.0/10:*",
"172.128.0.0/9:*",
"173.0.0.0/8:*",
"174.0.0.0/7:*",
"176.0.0.0/4:*",
"192.0.0.0/9:*",
"192.128.0.0/11:*",
"192.160.0.0/13:*",
"192.169.0.0/16:*",
"192.170.0.0/15:*",
"192.172.0.0/14:*",
"192.176.0.0/12:*",
"192.192.0.0/10:*",
"193.0.0.0/8:*",
"194.0.0.0/7:*",
"196.0.0.0/6:*",
"200.0.0.0/5:*",
"208.0.0.0/4:*"
]
}
],
}
`,
node: &types.Node{
IPv4: ap("100.64.0.100"),
IPv6: ap("fd7a:115c:a1e0::100"),
User: users[3],
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: tsaddr.ExitRoutes(),
},
},
peers: types.Nodes{
&types.Node{
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: users[2],
},
&types.Node{
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: users[1],
},
},
want: []tailcfg.FilterRule{
{
SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"},
DstPorts: []tailcfg.NetPortRange{
{
IP: "100.64.0.100/32",
Ports: tailcfg.PortRangeAny,
},
{
IP: "fd7a:115c:a1e0::100/128",
Ports: tailcfg.PortRangeAny,
},
},
IPProto: []int{6, 17},
},
{
SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"},
DstPorts: []tailcfg.NetPortRange{
{IP: "0.0.0.0/5", Ports: tailcfg.PortRangeAny},
{IP: "8.0.0.0/7", Ports: tailcfg.PortRangeAny},
{IP: "11.0.0.0/8", Ports: tailcfg.PortRangeAny},
{IP: "12.0.0.0/6", Ports: tailcfg.PortRangeAny},
{IP: "16.0.0.0/4", Ports: tailcfg.PortRangeAny},
{IP: "32.0.0.0/3", Ports: tailcfg.PortRangeAny},
{IP: "64.0.0.0/2", Ports: tailcfg.PortRangeAny},
{IP: "128.0.0.0/3", Ports: tailcfg.PortRangeAny},
{IP: "160.0.0.0/5", Ports: tailcfg.PortRangeAny},
{IP: "168.0.0.0/6", Ports: tailcfg.PortRangeAny},
{IP: "172.0.0.0/12", Ports: tailcfg.PortRangeAny},
{IP: "172.32.0.0/11", Ports: tailcfg.PortRangeAny},
{IP: "172.64.0.0/10", Ports: tailcfg.PortRangeAny},
{IP: "172.128.0.0/9", Ports: tailcfg.PortRangeAny},
{IP: "173.0.0.0/8", Ports: tailcfg.PortRangeAny},
{IP: "174.0.0.0/7", Ports: tailcfg.PortRangeAny},
{IP: "176.0.0.0/4", Ports: tailcfg.PortRangeAny},
{IP: "192.0.0.0/9", Ports: tailcfg.PortRangeAny},
{IP: "192.128.0.0/11", Ports: tailcfg.PortRangeAny},
{IP: "192.160.0.0/13", Ports: tailcfg.PortRangeAny},
{IP: "192.169.0.0/16", Ports: tailcfg.PortRangeAny},
{IP: "192.170.0.0/15", Ports: tailcfg.PortRangeAny},
{IP: "192.172.0.0/14", Ports: tailcfg.PortRangeAny},
{IP: "192.176.0.0/12", Ports: tailcfg.PortRangeAny},
{IP: "192.192.0.0/10", Ports: tailcfg.PortRangeAny},
{IP: "193.0.0.0/8", Ports: tailcfg.PortRangeAny},
{IP: "194.0.0.0/7", Ports: tailcfg.PortRangeAny},
{IP: "196.0.0.0/6", Ports: tailcfg.PortRangeAny},
{IP: "200.0.0.0/5", Ports: tailcfg.PortRangeAny},
{IP: "208.0.0.0/4", Ports: tailcfg.PortRangeAny},
},
IPProto: []int{6, 17},
},
},
},
{
name: "1786-reducing-breaks-exit-nodes-app-connector-like",
pol: `
{
"groups": {
"group:team": [
"user3@",
"user2@",
"user1@"
]
},
"hosts": {
"internal": "100.64.0.100/32"
},
"acls": [
{
"action": "accept",
"proto": "",
"src": [
"group:team"
],
"dst": [
"internal:*"
]
},
{
"action": "accept",
"proto": "",
"src": [
"group:team"
],
"dst": [
"8.0.0.0/8:*",
"16.0.0.0/8:*"
]
}
],
}
`,
node: &types.Node{
IPv4: ap("100.64.0.100"),
IPv6: ap("fd7a:115c:a1e0::100"),
User: users[3],
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/16"), netip.MustParsePrefix("16.0.0.0/16")},
},
},
peers: types.Nodes{
&types.Node{
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: users[2],
},
&types.Node{
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: users[1],
},
},
want: []tailcfg.FilterRule{
{
SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"},
DstPorts: []tailcfg.NetPortRange{
{
IP: "100.64.0.100/32",
Ports: tailcfg.PortRangeAny,
},
{
IP: "fd7a:115c:a1e0::100/128",
Ports: tailcfg.PortRangeAny,
},
},
IPProto: []int{6, 17},
},
{
SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"},
DstPorts: []tailcfg.NetPortRange{
{
IP: "8.0.0.0/8",
Ports: tailcfg.PortRangeAny,
},
{
IP: "16.0.0.0/8",
Ports: tailcfg.PortRangeAny,
},
},
IPProto: []int{6, 17},
},
},
},
{
name: "1786-reducing-breaks-exit-nodes-app-connector-like2",
pol: `
{
"groups": {
"group:team": [
"user3@",
"user2@",
"user1@"
]
},
"hosts": {
"internal": "100.64.0.100/32"
},
"acls": [
{
"action": "accept",
"proto": "",
"src": [
"group:team"
],
"dst": [
"internal:*"
]
},
{
"action": "accept",
"proto": "",
"src": [
"group:team"
],
"dst": [
"8.0.0.0/16:*",
"16.0.0.0/16:*"
]
}
],
}
`,
node: &types.Node{
IPv4: ap("100.64.0.100"),
IPv6: ap("fd7a:115c:a1e0::100"),
User: users[3],
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/8"), netip.MustParsePrefix("16.0.0.0/8")},
},
},
peers: types.Nodes{
&types.Node{
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: users[2],
},
&types.Node{
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: users[1],
},
},
want: []tailcfg.FilterRule{
{
SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"},
DstPorts: []tailcfg.NetPortRange{
{
IP: "100.64.0.100/32",
Ports: tailcfg.PortRangeAny,
},
{
IP: "fd7a:115c:a1e0::100/128",
Ports: tailcfg.PortRangeAny,
},
},
IPProto: []int{6, 17},
},
{
SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"},
DstPorts: []tailcfg.NetPortRange{
{
IP: "8.0.0.0/16",
Ports: tailcfg.PortRangeAny,
},
{
IP: "16.0.0.0/16",
Ports: tailcfg.PortRangeAny,
},
},
IPProto: []int{6, 17},
},
},
},
{
name: "1817-reduce-breaks-32-mask",
pol: `
{
"tagOwners": {
"tag:access-servers": ["user100@"],
},
"groups": {
"group:access": [
"user1@"
]
},
"hosts": {
"dns1": "172.16.0.21/32",
"vlan1": "172.16.0.0/24"
},
"acls": [
{
"action": "accept",
"proto": "",
"src": [
"group:access"
],
"dst": [
"tag:access-servers:*",
"dns1:*"
]
}
],
}
`,
node: &types.Node{
IPv4: ap("100.64.0.100"),
IPv6: ap("fd7a:115c:a1e0::100"),
User: users[3],
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/24")},
},
ForcedTags: []string{"tag:access-servers"},
},
peers: types.Nodes{
&types.Node{
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: users[1],
},
},
want: []tailcfg.FilterRule{
{
SrcIPs: []string{"100.64.0.1/32", "fd7a:115c:a1e0::1/128"},
DstPorts: []tailcfg.NetPortRange{
{
IP: "100.64.0.100/32",
Ports: tailcfg.PortRangeAny,
},
{
IP: "fd7a:115c:a1e0::100/128",
Ports: tailcfg.PortRangeAny,
},
{
IP: "172.16.0.21/32",
Ports: tailcfg.PortRangeAny,
},
},
IPProto: []int{6, 17},
},
},
},
{
name: "2365-only-route-policy",
pol: `
{
"hosts": {
"router": "100.64.0.1/32",
"node": "100.64.0.2/32"
},
"acls": [
{
"action": "accept",
"src": [
"*"
],
"dst": [
"router:8000"
]
},
{
"action": "accept",
"src": [
"node"
],
"dst": [
"172.26.0.0/16:*"
]
}
],
}
`,
node: &types.Node{
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: users[3],
},
peers: types.Nodes{
&types.Node{
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: users[1],
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{p("172.16.0.0/24"), p("10.10.11.0/24"), p("10.10.12.0/24")},
},
ApprovedRoutes: []netip.Prefix{p("172.16.0.0/24"), p("10.10.11.0/24"), p("10.10.12.0/24")},
},
},
want: []tailcfg.FilterRule{},
},
}
for _, tt := range tests {
for idx, pmf := range PolicyManagerFuncsForTest([]byte(tt.pol)) {
t.Run(fmt.Sprintf("%s-index%d", tt.name, idx), func(t *testing.T) {
var pm PolicyManager
var err error
pm, err = pmf(users, append(tt.peers, tt.node).ViewSlice())
require.NoError(t, err)
got, _ := pm.Filter()
t.Logf("full filter:\n%s", must.Get(json.MarshalIndent(got, "", " ")))
got = ReduceFilterRules(tt.node.View(), got)
if diff := cmp.Diff(tt.want, got); diff != "" {
log.Trace().Interface("got", got).Msg("result")
t.Errorf("TestReduceFilterRules() unexpected result (-want +got):\n%s", diff)
}
})
}
}
}
func TestReduceNodes(t *testing.T) {
type args struct {
nodes types.Nodes

View File

@@ -0,0 +1,71 @@
package policyutil
import (
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"tailscale.com/tailcfg"
)
// ReduceFilterRules takes a node and a set of global filter rules and removes all rules
// and destinations that are not relevant to that particular node.
//
// IMPORTANT: This function is designed for global filters only. Per-node filters
// (from autogroup:self policies) are already node-specific and should not be passed
// to this function. Use PolicyManager.FilterForNode() instead, which handles both cases.
func ReduceFilterRules(node types.NodeView, rules []tailcfg.FilterRule) []tailcfg.FilterRule {
ret := []tailcfg.FilterRule{}
for _, rule := range rules {
// record if the rule is actually relevant for the given node.
var dests []tailcfg.NetPortRange
DEST_LOOP:
for _, dest := range rule.DstPorts {
expanded, err := util.ParseIPSet(dest.IP, nil)
// Fail closed, if we can't parse it, then we should not allow
// access.
if err != nil {
continue DEST_LOOP
}
if node.InIPSet(expanded) {
dests = append(dests, dest)
continue DEST_LOOP
}
// If the node exposes routes, ensure they are not removed
// when the filters are reduced.
if node.Hostinfo().Valid() {
routableIPs := node.Hostinfo().RoutableIPs()
if routableIPs.Len() > 0 {
for _, routableIP := range routableIPs.All() {
if expanded.OverlapsPrefix(routableIP) {
dests = append(dests, dest)
continue DEST_LOOP
}
}
}
}
// Also check approved subnet routes - nodes should have access
// to subnets they're approved to route traffic for.
subnetRoutes := node.SubnetRoutes()
for _, subnetRoute := range subnetRoutes {
if expanded.OverlapsPrefix(subnetRoute) {
dests = append(dests, dest)
continue DEST_LOOP
}
}
}
if len(dests) > 0 {
ret = append(ret, tailcfg.FilterRule{
SrcIPs: rule.SrcIPs,
DstPorts: dests,
IPProto: rule.IPProto,
})
}
}
return ret
}
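
Distilled to a minimal case, reduction keeps only the destinations that apply to the node. A sketch using the exported names above (assuming types.Node.View() as used in the tests that follow):

package main

import (
	"fmt"
	"net/netip"

	"github.com/juanfont/headscale/hscontrol/policy/policyutil"
	"github.com/juanfont/headscale/hscontrol/types"
	"tailscale.com/tailcfg"
)

func main() {
	ip := netip.MustParseAddr("100.64.0.1")
	node := &types.Node{IPv4: &ip}

	rules := []tailcfg.FilterRule{{
		SrcIPs: []string{"100.64.0.0/24"},
		DstPorts: []tailcfg.NetPortRange{
			{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny},
			{IP: "100.64.0.2/32", Ports: tailcfg.PortRangeAny},
		},
	}}

	// Only the 100.64.0.1/32 destination survives: the node is not in
	// 100.64.0.2/32 and announces no routes overlapping it.
	reduced := policyutil.ReduceFilterRules(node.View(), rules)
	fmt.Println(len(reduced), reduced[0].DstPorts)
}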

View File

@@ -0,0 +1,841 @@
package policyutil_test
import (
"encoding/json"
"fmt"
"net/netip"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/policy/policyutil"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
"github.com/stretchr/testify/require"
"gorm.io/gorm"
"tailscale.com/net/tsaddr"
"tailscale.com/tailcfg"
"tailscale.com/util/must"
)
var ap = func(ipStr string) *netip.Addr {
ip := netip.MustParseAddr(ipStr)
return &ip
}
var p = func(prefStr string) netip.Prefix {
ip := netip.MustParsePrefix(prefStr)
return ip
}
// hsExitNodeDestForTest is the list of destination IP ranges that are allowed when
// we use headscale "autogroup:internet".
var hsExitNodeDestForTest = []tailcfg.NetPortRange{
{IP: "0.0.0.0/5", Ports: tailcfg.PortRangeAny},
{IP: "8.0.0.0/7", Ports: tailcfg.PortRangeAny},
{IP: "11.0.0.0/8", Ports: tailcfg.PortRangeAny},
{IP: "12.0.0.0/6", Ports: tailcfg.PortRangeAny},
{IP: "16.0.0.0/4", Ports: tailcfg.PortRangeAny},
{IP: "32.0.0.0/3", Ports: tailcfg.PortRangeAny},
{IP: "64.0.0.0/3", Ports: tailcfg.PortRangeAny},
{IP: "96.0.0.0/6", Ports: tailcfg.PortRangeAny},
{IP: "100.0.0.0/10", Ports: tailcfg.PortRangeAny},
{IP: "100.128.0.0/9", Ports: tailcfg.PortRangeAny},
{IP: "101.0.0.0/8", Ports: tailcfg.PortRangeAny},
{IP: "102.0.0.0/7", Ports: tailcfg.PortRangeAny},
{IP: "104.0.0.0/5", Ports: tailcfg.PortRangeAny},
{IP: "112.0.0.0/4", Ports: tailcfg.PortRangeAny},
{IP: "128.0.0.0/3", Ports: tailcfg.PortRangeAny},
{IP: "160.0.0.0/5", Ports: tailcfg.PortRangeAny},
{IP: "168.0.0.0/8", Ports: tailcfg.PortRangeAny},
{IP: "169.0.0.0/9", Ports: tailcfg.PortRangeAny},
{IP: "169.128.0.0/10", Ports: tailcfg.PortRangeAny},
{IP: "169.192.0.0/11", Ports: tailcfg.PortRangeAny},
{IP: "169.224.0.0/12", Ports: tailcfg.PortRangeAny},
{IP: "169.240.0.0/13", Ports: tailcfg.PortRangeAny},
{IP: "169.248.0.0/14", Ports: tailcfg.PortRangeAny},
{IP: "169.252.0.0/15", Ports: tailcfg.PortRangeAny},
{IP: "169.255.0.0/16", Ports: tailcfg.PortRangeAny},
{IP: "170.0.0.0/7", Ports: tailcfg.PortRangeAny},
{IP: "172.0.0.0/12", Ports: tailcfg.PortRangeAny},
{IP: "172.32.0.0/11", Ports: tailcfg.PortRangeAny},
{IP: "172.64.0.0/10", Ports: tailcfg.PortRangeAny},
{IP: "172.128.0.0/9", Ports: tailcfg.PortRangeAny},
{IP: "173.0.0.0/8", Ports: tailcfg.PortRangeAny},
{IP: "174.0.0.0/7", Ports: tailcfg.PortRangeAny},
{IP: "176.0.0.0/4", Ports: tailcfg.PortRangeAny},
{IP: "192.0.0.0/9", Ports: tailcfg.PortRangeAny},
{IP: "192.128.0.0/11", Ports: tailcfg.PortRangeAny},
{IP: "192.160.0.0/13", Ports: tailcfg.PortRangeAny},
{IP: "192.169.0.0/16", Ports: tailcfg.PortRangeAny},
{IP: "192.170.0.0/15", Ports: tailcfg.PortRangeAny},
{IP: "192.172.0.0/14", Ports: tailcfg.PortRangeAny},
{IP: "192.176.0.0/12", Ports: tailcfg.PortRangeAny},
{IP: "192.192.0.0/10", Ports: tailcfg.PortRangeAny},
{IP: "193.0.0.0/8", Ports: tailcfg.PortRangeAny},
{IP: "194.0.0.0/7", Ports: tailcfg.PortRangeAny},
{IP: "196.0.0.0/6", Ports: tailcfg.PortRangeAny},
{IP: "200.0.0.0/5", Ports: tailcfg.PortRangeAny},
{IP: "208.0.0.0/4", Ports: tailcfg.PortRangeAny},
{IP: "224.0.0.0/3", Ports: tailcfg.PortRangeAny},
{IP: "2000::/3", Ports: tailcfg.PortRangeAny},
}
func TestTheInternet(t *testing.T) {
internetSet := util.TheInternet()
internetPrefs := internetSet.Prefixes()
for i := range internetPrefs {
if internetPrefs[i].String() != hsExitNodeDestForTest[i].IP {
t.Errorf(
"prefix from internet set %q != hsExit list %q",
internetPrefs[i].String(),
hsExitNodeDestForTest[i].IP,
)
}
}
if len(internetPrefs) != len(hsExitNodeDestForTest) {
t.Fatalf(
"expected same length of prefixes, internet: %d, hsExit: %d",
len(internetPrefs),
len(hsExitNodeDestForTest),
)
}
}
func TestReduceFilterRules(t *testing.T) {
users := types.Users{
types.User{Model: gorm.Model{ID: 1}, Name: "mickael"},
types.User{Model: gorm.Model{ID: 2}, Name: "user1"},
types.User{Model: gorm.Model{ID: 3}, Name: "user2"},
types.User{Model: gorm.Model{ID: 4}, Name: "user100"},
types.User{Model: gorm.Model{ID: 5}, Name: "user3"},
}
tests := []struct {
name string
node *types.Node
peers types.Nodes
pol string
want []tailcfg.FilterRule
}{
{
name: "host1-can-reach-host2-no-rules",
pol: `
{
"acls": [
{
"action": "accept",
"proto": "",
"src": [
"100.64.0.1"
],
"dst": [
"100.64.0.2:*"
]
}
],
}
`,
node: &types.Node{
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"),
User: users[0],
},
peers: types.Nodes{
&types.Node{
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0:ab12:4843:2222:6273:2222"),
User: users[0],
},
},
want: []tailcfg.FilterRule{},
},
{
name: "1604-subnet-routers-are-preserved",
pol: `
{
"groups": {
"group:admins": [
"user1@"
]
},
"acls": [
{
"action": "accept",
"proto": "",
"src": [
"group:admins"
],
"dst": [
"group:admins:*"
]
},
{
"action": "accept",
"proto": "",
"src": [
"group:admins"
],
"dst": [
"10.33.0.0/16:*"
]
}
],
}
`,
node: &types.Node{
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: users[1],
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{
netip.MustParsePrefix("10.33.0.0/16"),
},
},
},
peers: types.Nodes{
&types.Node{
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: users[1],
},
},
want: []tailcfg.FilterRule{
{
SrcIPs: []string{
"100.64.0.1/32",
"100.64.0.2/32",
"fd7a:115c:a1e0::1/128",
"fd7a:115c:a1e0::2/128",
},
DstPorts: []tailcfg.NetPortRange{
{
IP: "100.64.0.1/32",
Ports: tailcfg.PortRangeAny,
},
{
IP: "fd7a:115c:a1e0::1/128",
Ports: tailcfg.PortRangeAny,
},
},
IPProto: []int{6, 17},
},
{
SrcIPs: []string{
"100.64.0.1/32",
"100.64.0.2/32",
"fd7a:115c:a1e0::1/128",
"fd7a:115c:a1e0::2/128",
},
DstPorts: []tailcfg.NetPortRange{
{
IP: "10.33.0.0/16",
Ports: tailcfg.PortRangeAny,
},
},
IPProto: []int{6, 17},
},
},
},
{
name: "1786-reducing-breaks-exit-nodes-the-client",
pol: `
{
"groups": {
"group:team": [
"user3@",
"user2@",
"user1@"
]
},
"hosts": {
"internal": "100.64.0.100/32"
},
"acls": [
{
"action": "accept",
"proto": "",
"src": [
"group:team"
],
"dst": [
"internal:*"
]
},
{
"action": "accept",
"proto": "",
"src": [
"group:team"
],
"dst": [
"autogroup:internet:*"
]
}
],
}
`,
node: &types.Node{
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: users[1],
},
peers: types.Nodes{
&types.Node{
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: users[2],
},
// "internal" exit node
&types.Node{
IPv4: ap("100.64.0.100"),
IPv6: ap("fd7a:115c:a1e0::100"),
User: users[3],
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: tsaddr.ExitRoutes(),
},
},
},
want: []tailcfg.FilterRule{},
},
{
name: "1786-reducing-breaks-exit-nodes-the-exit",
pol: `
{
"groups": {
"group:team": [
"user3@",
"user2@",
"user1@"
]
},
"hosts": {
"internal": "100.64.0.100/32"
},
"acls": [
{
"action": "accept",
"proto": "",
"src": [
"group:team"
],
"dst": [
"internal:*"
]
},
{
"action": "accept",
"proto": "",
"src": [
"group:team"
],
"dst": [
"autogroup:internet:*"
]
}
],
}
`,
node: &types.Node{
IPv4: ap("100.64.0.100"),
IPv6: ap("fd7a:115c:a1e0::100"),
User: users[3],
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: tsaddr.ExitRoutes(),
},
},
peers: types.Nodes{
&types.Node{
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: users[2],
},
&types.Node{
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: users[1],
},
},
want: []tailcfg.FilterRule{
{
SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"},
DstPorts: []tailcfg.NetPortRange{
{
IP: "100.64.0.100/32",
Ports: tailcfg.PortRangeAny,
},
{
IP: "fd7a:115c:a1e0::100/128",
Ports: tailcfg.PortRangeAny,
},
},
IPProto: []int{6, 17},
},
{
SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"},
DstPorts: hsExitNodeDestForTest,
IPProto: []int{6, 17},
},
},
},
{
name: "1786-reducing-breaks-exit-nodes-the-example-from-issue",
pol: `
{
"groups": {
"group:team": [
"user3@",
"user2@",
"user1@"
]
},
"hosts": {
"internal": "100.64.0.100/32"
},
"acls": [
{
"action": "accept",
"proto": "",
"src": [
"group:team"
],
"dst": [
"internal:*"
]
},
{
"action": "accept",
"proto": "",
"src": [
"group:team"
],
"dst": [
"0.0.0.0/5:*",
"8.0.0.0/7:*",
"11.0.0.0/8:*",
"12.0.0.0/6:*",
"16.0.0.0/4:*",
"32.0.0.0/3:*",
"64.0.0.0/2:*",
"128.0.0.0/3:*",
"160.0.0.0/5:*",
"168.0.0.0/6:*",
"172.0.0.0/12:*",
"172.32.0.0/11:*",
"172.64.0.0/10:*",
"172.128.0.0/9:*",
"173.0.0.0/8:*",
"174.0.0.0/7:*",
"176.0.0.0/4:*",
"192.0.0.0/9:*",
"192.128.0.0/11:*",
"192.160.0.0/13:*",
"192.169.0.0/16:*",
"192.170.0.0/15:*",
"192.172.0.0/14:*",
"192.176.0.0/12:*",
"192.192.0.0/10:*",
"193.0.0.0/8:*",
"194.0.0.0/7:*",
"196.0.0.0/6:*",
"200.0.0.0/5:*",
"208.0.0.0/4:*"
]
}
],
}
`,
node: &types.Node{
IPv4: ap("100.64.0.100"),
IPv6: ap("fd7a:115c:a1e0::100"),
User: users[3],
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: tsaddr.ExitRoutes(),
},
},
peers: types.Nodes{
&types.Node{
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: users[2],
},
&types.Node{
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: users[1],
},
},
want: []tailcfg.FilterRule{
{
SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"},
DstPorts: []tailcfg.NetPortRange{
{
IP: "100.64.0.100/32",
Ports: tailcfg.PortRangeAny,
},
{
IP: "fd7a:115c:a1e0::100/128",
Ports: tailcfg.PortRangeAny,
},
},
IPProto: []int{6, 17},
},
{
SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"},
DstPorts: []tailcfg.NetPortRange{
{IP: "0.0.0.0/5", Ports: tailcfg.PortRangeAny},
{IP: "8.0.0.0/7", Ports: tailcfg.PortRangeAny},
{IP: "11.0.0.0/8", Ports: tailcfg.PortRangeAny},
{IP: "12.0.0.0/6", Ports: tailcfg.PortRangeAny},
{IP: "16.0.0.0/4", Ports: tailcfg.PortRangeAny},
{IP: "32.0.0.0/3", Ports: tailcfg.PortRangeAny},
{IP: "64.0.0.0/2", Ports: tailcfg.PortRangeAny},
{IP: "128.0.0.0/3", Ports: tailcfg.PortRangeAny},
{IP: "160.0.0.0/5", Ports: tailcfg.PortRangeAny},
{IP: "168.0.0.0/6", Ports: tailcfg.PortRangeAny},
{IP: "172.0.0.0/12", Ports: tailcfg.PortRangeAny},
{IP: "172.32.0.0/11", Ports: tailcfg.PortRangeAny},
{IP: "172.64.0.0/10", Ports: tailcfg.PortRangeAny},
{IP: "172.128.0.0/9", Ports: tailcfg.PortRangeAny},
{IP: "173.0.0.0/8", Ports: tailcfg.PortRangeAny},
{IP: "174.0.0.0/7", Ports: tailcfg.PortRangeAny},
{IP: "176.0.0.0/4", Ports: tailcfg.PortRangeAny},
{IP: "192.0.0.0/9", Ports: tailcfg.PortRangeAny},
{IP: "192.128.0.0/11", Ports: tailcfg.PortRangeAny},
{IP: "192.160.0.0/13", Ports: tailcfg.PortRangeAny},
{IP: "192.169.0.0/16", Ports: tailcfg.PortRangeAny},
{IP: "192.170.0.0/15", Ports: tailcfg.PortRangeAny},
{IP: "192.172.0.0/14", Ports: tailcfg.PortRangeAny},
{IP: "192.176.0.0/12", Ports: tailcfg.PortRangeAny},
{IP: "192.192.0.0/10", Ports: tailcfg.PortRangeAny},
{IP: "193.0.0.0/8", Ports: tailcfg.PortRangeAny},
{IP: "194.0.0.0/7", Ports: tailcfg.PortRangeAny},
{IP: "196.0.0.0/6", Ports: tailcfg.PortRangeAny},
{IP: "200.0.0.0/5", Ports: tailcfg.PortRangeAny},
{IP: "208.0.0.0/4", Ports: tailcfg.PortRangeAny},
},
IPProto: []int{6, 17},
},
},
},
{
name: "1786-reducing-breaks-exit-nodes-app-connector-like",
pol: `
{
"groups": {
"group:team": [
"user3@",
"user2@",
"user1@"
]
},
"hosts": {
"internal": "100.64.0.100/32"
},
"acls": [
{
"action": "accept",
"proto": "",
"src": [
"group:team"
],
"dst": [
"internal:*"
]
},
{
"action": "accept",
"proto": "",
"src": [
"group:team"
],
"dst": [
"8.0.0.0/8:*",
"16.0.0.0/8:*"
]
}
],
}
`,
node: &types.Node{
IPv4: ap("100.64.0.100"),
IPv6: ap("fd7a:115c:a1e0::100"),
User: users[3],
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/16"), netip.MustParsePrefix("16.0.0.0/16")},
},
},
peers: types.Nodes{
&types.Node{
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: users[2],
},
&types.Node{
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: users[1],
},
},
want: []tailcfg.FilterRule{
{
SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"},
DstPorts: []tailcfg.NetPortRange{
{
IP: "100.64.0.100/32",
Ports: tailcfg.PortRangeAny,
},
{
IP: "fd7a:115c:a1e0::100/128",
Ports: tailcfg.PortRangeAny,
},
},
IPProto: []int{6, 17},
},
{
SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"},
DstPorts: []tailcfg.NetPortRange{
{
IP: "8.0.0.0/8",
Ports: tailcfg.PortRangeAny,
},
{
IP: "16.0.0.0/8",
Ports: tailcfg.PortRangeAny,
},
},
IPProto: []int{6, 17},
},
},
},
{
name: "1786-reducing-breaks-exit-nodes-app-connector-like2",
pol: `
{
"groups": {
"group:team": [
"user3@",
"user2@",
"user1@"
]
},
"hosts": {
"internal": "100.64.0.100/32"
},
"acls": [
{
"action": "accept",
"proto": "",
"src": [
"group:team"
],
"dst": [
"internal:*"
]
},
{
"action": "accept",
"proto": "",
"src": [
"group:team"
],
"dst": [
"8.0.0.0/16:*",
"16.0.0.0/16:*"
]
}
],
}
`,
node: &types.Node{
IPv4: ap("100.64.0.100"),
IPv6: ap("fd7a:115c:a1e0::100"),
User: users[3],
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/8"), netip.MustParsePrefix("16.0.0.0/8")},
},
},
peers: types.Nodes{
&types.Node{
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: users[2],
},
&types.Node{
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: users[1],
},
},
want: []tailcfg.FilterRule{
{
SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"},
DstPorts: []tailcfg.NetPortRange{
{
IP: "100.64.0.100/32",
Ports: tailcfg.PortRangeAny,
},
{
IP: "fd7a:115c:a1e0::100/128",
Ports: tailcfg.PortRangeAny,
},
},
IPProto: []int{6, 17},
},
{
SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"},
DstPorts: []tailcfg.NetPortRange{
{
IP: "8.0.0.0/16",
Ports: tailcfg.PortRangeAny,
},
{
IP: "16.0.0.0/16",
Ports: tailcfg.PortRangeAny,
},
},
IPProto: []int{6, 17},
},
},
},
{
name: "1817-reduce-breaks-32-mask",
pol: `
{
"tagOwners": {
"tag:access-servers": ["user100@"],
},
"groups": {
"group:access": [
"user1@"
]
},
"hosts": {
"dns1": "172.16.0.21/32",
"vlan1": "172.16.0.0/24"
},
"acls": [
{
"action": "accept",
"proto": "",
"src": [
"group:access"
],
"dst": [
"tag:access-servers:*",
"dns1:*"
]
}
],
}
`,
node: &types.Node{
IPv4: ap("100.64.0.100"),
IPv6: ap("fd7a:115c:a1e0::100"),
User: users[3],
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/24")},
},
ForcedTags: []string{"tag:access-servers"},
},
peers: types.Nodes{
&types.Node{
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: users[1],
},
},
want: []tailcfg.FilterRule{
{
SrcIPs: []string{"100.64.0.1/32", "fd7a:115c:a1e0::1/128"},
DstPorts: []tailcfg.NetPortRange{
{
IP: "100.64.0.100/32",
Ports: tailcfg.PortRangeAny,
},
{
IP: "fd7a:115c:a1e0::100/128",
Ports: tailcfg.PortRangeAny,
},
{
IP: "172.16.0.21/32",
Ports: tailcfg.PortRangeAny,
},
},
IPProto: []int{6, 17},
},
},
},
{
name: "2365-only-route-policy",
pol: `
{
"hosts": {
"router": "100.64.0.1/32",
"node": "100.64.0.2/32"
},
"acls": [
{
"action": "accept",
"src": [
"*"
],
"dst": [
"router:8000"
]
},
{
"action": "accept",
"src": [
"node"
],
"dst": [
"172.26.0.0/16:*"
]
}
],
}
`,
node: &types.Node{
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: users[3],
},
peers: types.Nodes{
&types.Node{
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: users[1],
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{p("172.16.0.0/24"), p("10.10.11.0/24"), p("10.10.12.0/24")},
},
ApprovedRoutes: []netip.Prefix{p("172.16.0.0/24"), p("10.10.11.0/24"), p("10.10.12.0/24")},
},
},
want: []tailcfg.FilterRule{},
},
}
for _, tt := range tests {
for idx, pmf := range policy.PolicyManagerFuncsForTest([]byte(tt.pol)) {
t.Run(fmt.Sprintf("%s-index%d", tt.name, idx), func(t *testing.T) {
var pm policy.PolicyManager
var err error
pm, err = pmf(users, append(tt.peers, tt.node).ViewSlice())
require.NoError(t, err)
got, _ := pm.Filter()
t.Logf("full filter:\n%s", must.Get(json.MarshalIndent(got, "", " ")))
got = policyutil.ReduceFilterRules(tt.node.View(), got)
if diff := cmp.Diff(tt.want, got); diff != "" {
log.Trace().Interface("got", got).Msg("result")
t.Errorf("TestReduceFilterRules() unexpected result (-want +got):\n%s", diff)
}
})
}
}
}

View File

@@ -854,7 +854,6 @@ func TestCompileFilterRulesForNodeWithAutogroupSelf(t *testing.T) {
node1 := nodes[0].View()
rules, err := policy2.compileFilterRulesForNode(users, node1, nodes.ViewSlice())
if err != nil {
t.Fatalf("unexpected error: %v", err)
}

View File

@@ -9,6 +9,7 @@ import (
"sync"
"github.com/juanfont/headscale/hscontrol/policy/matcher"
"github.com/juanfont/headscale/hscontrol/policy/policyutil"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/rs/zerolog/log"
"go4.org/netipx"
@@ -39,7 +40,9 @@ type PolicyManager struct {
// Lazy map of SSH policies
sshPolicyMap map[types.NodeID]*tailcfg.SSHPolicy
- // Lazy map of per-node filter rules (when autogroup:self is used)
// Lazy map of per-node compiled filter rules (unreduced, for autogroup:self)
compiledFilterRulesMap map[types.NodeID][]tailcfg.FilterRule
// Lazy map of per-node filter rules (reduced, for packet filters)
filterRulesMap map[types.NodeID][]tailcfg.FilterRule
usesAutogroupSelf bool
}
@@ -54,12 +57,13 @@ func NewPolicyManager(b []byte, users []types.User, nodes views.Slice[types.Node
}
pm := PolicyManager{
- pol: policy,
- users: users,
- nodes: nodes,
- sshPolicyMap: make(map[types.NodeID]*tailcfg.SSHPolicy, nodes.Len()),
- filterRulesMap: make(map[types.NodeID][]tailcfg.FilterRule, nodes.Len()),
- usesAutogroupSelf: policy.usesAutogroupSelf(),
pol: policy,
users: users,
nodes: nodes,
sshPolicyMap: make(map[types.NodeID]*tailcfg.SSHPolicy, nodes.Len()),
compiledFilterRulesMap: make(map[types.NodeID][]tailcfg.FilterRule, nodes.Len()),
filterRulesMap: make(map[types.NodeID][]tailcfg.FilterRule, nodes.Len()),
usesAutogroupSelf: policy.usesAutogroupSelf(),
}
_, err = pm.updateLocked()
@@ -78,6 +82,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
// policies for nodes that have changed. Particularly if the only difference is
// that nodes has been added or removed.
clear(pm.sshPolicyMap)
clear(pm.compiledFilterRulesMap)
clear(pm.filterRulesMap)
// Check if policy uses autogroup:self
@@ -233,9 +238,157 @@ func (pm *PolicyManager) Filter() ([]tailcfg.FilterRule, []matcher.Match) {
return pm.filter, pm.matchers
}
- // FilterForNode returns the filter rules for a specific node.
- // If the policy uses autogroup:self, this returns node-specific rules for security.
- // Otherwise, it returns the global filter rules for efficiency.
// BuildPeerMap constructs peer relationship maps for the given nodes.
// For global filters, it uses the global filter matchers for all nodes.
// For autogroup:self policies (empty global filter), it builds per-node
// peer maps using each node's specific filter rules.
func (pm *PolicyManager) BuildPeerMap(nodes views.Slice[types.NodeView]) map[types.NodeID][]types.NodeView {
if pm == nil {
return nil
}
pm.mu.Lock()
defer pm.mu.Unlock()
// If we have a global filter, use it for all nodes (normal case)
if !pm.usesAutogroupSelf {
ret := make(map[types.NodeID][]types.NodeView, nodes.Len())
// Build the map of all peers according to the matchers.
// Compared to ReduceNodes, which rebuilds the list per node and so does the
// full O(n^2) work for every node, this visits each unordered pair once
// (n^2/2 comparisons) and records the relationship for both nodes as it goes.
for i := range nodes.Len() {
for j := i + 1; j < nodes.Len(); j++ {
if nodes.At(i).ID() == nodes.At(j).ID() {
continue
}
if nodes.At(i).CanAccess(pm.matchers, nodes.At(j)) || nodes.At(j).CanAccess(pm.matchers, nodes.At(i)) {
ret[nodes.At(i).ID()] = append(ret[nodes.At(i).ID()], nodes.At(j))
ret[nodes.At(j).ID()] = append(ret[nodes.At(j).ID()], nodes.At(i))
}
}
}
return ret
}
// For autogroup:self (empty global filter), build per-node peer relationships
ret := make(map[types.NodeID][]types.NodeView, nodes.Len())
// Pre-compute per-node matchers using unreduced compiled rules
// We need unreduced rules to determine peer relationships correctly.
// Reduced rules only show destinations where the node is the target,
// but peer relationships require the full bidirectional access rules.
nodeMatchers := make(map[types.NodeID][]matcher.Match, nodes.Len())
for _, node := range nodes.All() {
filter, err := pm.compileFilterRulesForNodeLocked(node)
if err != nil || len(filter) == 0 {
continue
}
nodeMatchers[node.ID()] = matcher.MatchesFromFilterRules(filter)
}
// Check each node pair for peer relationships.
// Start j at i+1 to avoid checking the same pair twice and creating duplicates.
// We check both directions (i->j and j->i) since ACLs can be asymmetric.
for i := range nodes.Len() {
nodeI := nodes.At(i)
matchersI, hasFilterI := nodeMatchers[nodeI.ID()]
for j := i + 1; j < nodes.Len(); j++ {
nodeJ := nodes.At(j)
matchersJ, hasFilterJ := nodeMatchers[nodeJ.ID()]
// Check if nodeI can access nodeJ
if hasFilterI && nodeI.CanAccess(matchersI, nodeJ) {
ret[nodeI.ID()] = append(ret[nodeI.ID()], nodeJ)
}
// Check if nodeJ can access nodeI
if hasFilterJ && nodeJ.CanAccess(matchersJ, nodeI) {
ret[nodeJ.ID()] = append(ret[nodeJ.ID()], nodeI)
}
}
}
return ret
}
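
Stripped of the headscale types, the pairwise pattern in the global-filter branch above looks like this toy sketch: each unordered pair is visited once, and an edge is recorded for both sides when either direction is allowed (the access check is a stand-in for CanAccess):

package main

import "fmt"

func main() {
	// Toy stand-in for the ACL check; the real code consults matcher.Match rules.
	canAccess := func(src, dst int) bool { return src != 2 || dst != 3 }

	nodes := []int{1, 2, 3}
	peers := make(map[int][]int)
	// Visit each unordered pair once: j starts at i+1, so no pair repeats.
	for i := range nodes {
		for j := i + 1; j < len(nodes); j++ {
			if canAccess(nodes[i], nodes[j]) || canAccess(nodes[j], nodes[i]) {
				// Symmetric insert: one check populates both adjacency lists.
				peers[nodes[i]] = append(peers[nodes[i]], nodes[j])
				peers[nodes[j]] = append(peers[nodes[j]], nodes[i])
			}
		}
	}
	fmt.Println(peers) // map[1:[2 3] 2:[1 3] 3:[1 2]]
}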
// compileFilterRulesForNodeLocked returns the unreduced compiled filter rules for a node
// when using autogroup:self. This is used by BuildPeerMap to determine peer relationships.
// For packet filters sent to nodes, use filterForNodeLocked which returns reduced rules.
func (pm *PolicyManager) compileFilterRulesForNodeLocked(node types.NodeView) ([]tailcfg.FilterRule, error) {
if pm == nil {
return nil, nil
}
// Check if we have cached compiled rules
if rules, ok := pm.compiledFilterRulesMap[node.ID()]; ok {
return rules, nil
}
// Compile per-node rules with autogroup:self expanded
rules, err := pm.pol.compileFilterRulesForNode(pm.users, node, pm.nodes)
if err != nil {
return nil, fmt.Errorf("compiling filter rules for node: %w", err)
}
// Cache the unreduced compiled rules
pm.compiledFilterRulesMap[node.ID()] = rules
return rules, nil
}
// filterForNodeLocked returns the filter rules for a specific node, already reduced
// to only include rules relevant to that node.
// It assumes pm.mu is already held, unlike FilterForNode, which acquires it.
// BuildPeerMap already holds the lock, so it needs a variant that doesn't re-acquire it.
func (pm *PolicyManager) filterForNodeLocked(node types.NodeView) ([]tailcfg.FilterRule, error) {
if pm == nil {
return nil, nil
}
if !pm.usesAutogroupSelf {
// For global filters, reduce to only rules relevant to this node.
// Cache the reduced filter per node for efficiency.
if rules, ok := pm.filterRulesMap[node.ID()]; ok {
return rules, nil
}
// Use policyutil.ReduceFilterRules for global filter reduction.
reducedFilter := policyutil.ReduceFilterRules(node, pm.filter)
pm.filterRulesMap[node.ID()] = reducedFilter
return reducedFilter, nil
}
// For autogroup:self, compile per-node rules then reduce them.
// Check if we have cached reduced rules for this node.
if rules, ok := pm.filterRulesMap[node.ID()]; ok {
return rules, nil
}
// Get unreduced compiled rules
compiledRules, err := pm.compileFilterRulesForNodeLocked(node)
if err != nil {
return nil, err
}
// Reduce the compiled rules to only destinations relevant to this node
reducedFilter := policyutil.ReduceFilterRules(node, compiledRules)
// Cache the reduced filter
pm.filterRulesMap[node.ID()] = reducedFilter
return reducedFilter, nil
}
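
The Locked suffix above follows a common Go convention: the exported method takes the mutex, while the unexported *Locked variant assumes the caller already holds it, so internal callers such as BuildPeerMap don't deadlock on the non-reentrant sync.Mutex. The same pattern in isolation, with a toy cache:

package main

import (
	"fmt"
	"sync"
)

type lazyCache struct {
	mu    sync.Mutex
	cache map[int]string
}

// Get is the public entry point: it acquires the lock, then delegates.
func (c *lazyCache) Get(k int) string {
	c.mu.Lock()
	defer c.mu.Unlock()
	return c.getLocked(k)
}

// getLocked assumes c.mu is held; safe to call from other locked methods.
func (c *lazyCache) getLocked(k int) string {
	if v, ok := c.cache[k]; ok {
		return v
	}
	v := fmt.Sprintf("computed-%d", k) // stand-in for expensive rule compilation
	c.cache[k] = v
	return v
}

func main() {
	c := &lazyCache{cache: make(map[int]string)}
	fmt.Println(c.Get(7), c.Get(7)) // second call hits the cache
}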
// FilterForNode returns the filter rules for a specific node, already reduced
// to only include rules relevant to that node.
// If the policy uses autogroup:self, this returns node-specific compiled rules.
// Otherwise, it returns the global filter reduced for this node.
func (pm *PolicyManager) FilterForNode(node types.NodeView) ([]tailcfg.FilterRule, error) {
if pm == nil {
return nil, nil
@@ -244,22 +397,36 @@ func (pm *PolicyManager) FilterForNode(node types.NodeView) ([]tailcfg.FilterRul
pm.mu.Lock()
defer pm.mu.Unlock()
return pm.filterForNodeLocked(node)
}
// MatchersForNode returns the matchers for peer relationship determination for a specific node.
// These are UNREDUCED matchers - they include all rules where the node could be either source or destination.
// This is different from FilterForNode which returns REDUCED rules for packet filtering.
//
// For global policies: returns the global matchers (same for all nodes)
// For autogroup:self: returns node-specific matchers from unreduced compiled rules
func (pm *PolicyManager) MatchersForNode(node types.NodeView) ([]matcher.Match, error) {
if pm == nil {
return nil, nil
}
pm.mu.Lock()
defer pm.mu.Unlock()
// For global policies, return the shared global matchers
if !pm.usesAutogroupSelf {
- return pm.filter, nil
return pm.matchers, nil
}
- if rules, ok := pm.filterRulesMap[node.ID()]; ok {
- return rules, nil
- }
- rules, err := pm.pol.compileFilterRulesForNode(pm.users, node, pm.nodes)
// For autogroup:self, get unreduced compiled rules and create matchers
compiledRules, err := pm.compileFilterRulesForNodeLocked(node)
if err != nil {
return nil, fmt.Errorf("compiling filter rules for node: %w", err)
return nil, err
}
- pm.filterRulesMap[node.ID()] = rules
- return rules, nil
// Create matchers from unreduced rules for peer relationship determination
return matcher.MatchesFromFilterRules(compiledRules), nil
}
// SetUsers updates the users in the policy manager and updates the filter rules.
@@ -300,22 +467,40 @@ func (pm *PolicyManager) SetNodes(nodes views.Slice[types.NodeView]) (bool, erro
pm.mu.Lock()
defer pm.mu.Unlock()
- // Clear cache based on what actually changed
- if pm.usesAutogroupSelf {
- // For autogroup:self, we need granular invalidation since rules depend on:
- // - User ownership (node.User().ID)
- // - Tag status (node.IsTagged())
- // - IP addresses (node.IPs())
- // - Node existence (added/removed)
- pm.invalidateAutogroupSelfCache(pm.nodes, nodes)
- } else {
- // For non-autogroup:self policies, we can clear everything
- clear(pm.filterRulesMap)
- }
oldNodeCount := pm.nodes.Len()
newNodeCount := nodes.Len()
// Invalidate cache entries for nodes that changed.
// For autogroup:self: invalidate all nodes belonging to affected users (peer changes).
// For global policies: invalidate only nodes whose properties changed (IPs, routes).
pm.invalidateNodeCache(nodes)
pm.nodes = nodes
- return pm.updateLocked()
nodesChanged := oldNodeCount != newNodeCount
// When nodes are added/removed, we must recompile filters because:
// 1. User/group aliases (like "user1@") resolve to node IPs
// 2. Filter compilation needs nodes to generate rules
// 3. Without nodes, filters compile to empty (0 rules)
//
// For autogroup:self: return true when nodes change even if the global filter
// hash didn't change. The global filter is empty for autogroup:self (each node
// has its own filter), so the hash never changes. But peer relationships DO
// change when nodes are added/removed, so we must signal this to trigger updates.
// For global policies: the filter must be recompiled to include the new nodes.
if nodesChanged {
// Recompile filter with the new node list
_, err := pm.updateLocked()
if err != nil {
return false, err
}
// Always return true when nodes changed, even if filter hash didn't change
// (can happen with autogroup:self or when nodes are added but don't affect rules)
return true, nil
}
return false, nil
}
func (pm *PolicyManager) NodeCanHaveTag(node types.NodeView, tag string) bool {
@@ -552,10 +737,12 @@ func (pm *PolicyManager) invalidateAutogroupSelfCache(oldNodes, newNodes views.S
// If we found the user and they're affected, clear this cache entry
if found {
if _, affected := affectedUsers[nodeUserID]; affected {
delete(pm.compiledFilterRulesMap, nodeID)
delete(pm.filterRulesMap, nodeID)
}
} else {
// Node not found in either old or new list, clear it
delete(pm.compiledFilterRulesMap, nodeID)
delete(pm.filterRulesMap, nodeID)
}
}
@@ -567,3 +754,50 @@ func (pm *PolicyManager) invalidateAutogroupSelfCache(oldNodes, newNodes views.S
Msg("Selectively cleared autogroup:self cache for affected users")
}
}
// invalidateNodeCache invalidates cache entries based on what changed.
func (pm *PolicyManager) invalidateNodeCache(newNodes views.Slice[types.NodeView]) {
if pm.usesAutogroupSelf {
// For autogroup:self, a node's filter depends on its peers (same user).
// When any node in a user changes, all nodes for that user need invalidation.
pm.invalidateAutogroupSelfCache(pm.nodes, newNodes)
} else {
// For global policies, a node's filter depends only on its own properties.
// Only invalidate nodes whose properties actually changed.
pm.invalidateGlobalPolicyCache(newNodes)
}
}
// invalidateGlobalPolicyCache invalidates only nodes whose properties affecting
// ReduceFilterRules changed. For global policies, each node's filter is independent.
func (pm *PolicyManager) invalidateGlobalPolicyCache(newNodes views.Slice[types.NodeView]) {
oldNodeMap := make(map[types.NodeID]types.NodeView)
for _, node := range pm.nodes.All() {
oldNodeMap[node.ID()] = node
}
newNodeMap := make(map[types.NodeID]types.NodeView)
for _, node := range newNodes.All() {
newNodeMap[node.ID()] = node
}
// Invalidate nodes whose properties changed
for nodeID, newNode := range newNodeMap {
oldNode, existed := oldNodeMap[nodeID]
if !existed {
// New node - no cache entry yet, will be lazily calculated
continue
}
if newNode.HasNetworkChanges(oldNode) {
delete(pm.filterRulesMap, nodeID)
}
}
// Remove deleted nodes from cache
for nodeID := range pm.filterRulesMap {
if _, exists := newNodeMap[nodeID]; !exists {
delete(pm.filterRulesMap, nodeID)
}
}
}

View File

@@ -1,6 +1,7 @@
package v2
import (
"net/netip"
"testing"
"github.com/google/go-cmp/cmp"
@@ -204,3 +205,237 @@ func TestInvalidateAutogroupSelfCache(t *testing.T) {
})
}
}
// TestInvalidateGlobalPolicyCache tests the cache invalidation logic for global policies.
func TestInvalidateGlobalPolicyCache(t *testing.T) {
mustIPPtr := func(s string) *netip.Addr {
ip := netip.MustParseAddr(s)
return &ip
}
tests := []struct {
name string
oldNodes types.Nodes
newNodes types.Nodes
initialCache map[types.NodeID][]tailcfg.FilterRule
expectedCacheAfter map[types.NodeID]bool // true = should exist, false = should not exist
}{
{
name: "node property changed - invalidates only that node",
oldNodes: types.Nodes{
&types.Node{ID: 1, IPv4: mustIPPtr("100.64.0.1")},
&types.Node{ID: 2, IPv4: mustIPPtr("100.64.0.2")},
},
newNodes: types.Nodes{
&types.Node{ID: 1, IPv4: mustIPPtr("100.64.0.99")}, // Changed
&types.Node{ID: 2, IPv4: mustIPPtr("100.64.0.2")}, // Unchanged
},
initialCache: map[types.NodeID][]tailcfg.FilterRule{
1: {},
2: {},
},
expectedCacheAfter: map[types.NodeID]bool{
1: false, // Invalidated
2: true, // Preserved
},
},
{
name: "multiple nodes changed",
oldNodes: types.Nodes{
&types.Node{ID: 1, IPv4: mustIPPtr("100.64.0.1")},
&types.Node{ID: 2, IPv4: mustIPPtr("100.64.0.2")},
&types.Node{ID: 3, IPv4: mustIPPtr("100.64.0.3")},
},
newNodes: types.Nodes{
&types.Node{ID: 1, IPv4: mustIPPtr("100.64.0.99")}, // Changed
&types.Node{ID: 2, IPv4: mustIPPtr("100.64.0.2")}, // Unchanged
&types.Node{ID: 3, IPv4: mustIPPtr("100.64.0.88")}, // Changed
},
initialCache: map[types.NodeID][]tailcfg.FilterRule{
1: {},
2: {},
3: {},
},
expectedCacheAfter: map[types.NodeID]bool{
1: false, // Invalidated
2: true, // Preserved
3: false, // Invalidated
},
},
{
name: "node deleted - removes from cache",
oldNodes: types.Nodes{
&types.Node{ID: 1, IPv4: mustIPPtr("100.64.0.1")},
&types.Node{ID: 2, IPv4: mustIPPtr("100.64.0.2")},
},
newNodes: types.Nodes{
&types.Node{ID: 2, IPv4: mustIPPtr("100.64.0.2")},
},
initialCache: map[types.NodeID][]tailcfg.FilterRule{
1: {},
2: {},
},
expectedCacheAfter: map[types.NodeID]bool{
1: false, // Deleted
2: true, // Preserved
},
},
{
name: "node added - no cache invalidation needed",
oldNodes: types.Nodes{
&types.Node{ID: 1, IPv4: mustIPPtr("100.64.0.1")},
},
newNodes: types.Nodes{
&types.Node{ID: 1, IPv4: mustIPPtr("100.64.0.1")},
&types.Node{ID: 2, IPv4: mustIPPtr("100.64.0.2")}, // New
},
initialCache: map[types.NodeID][]tailcfg.FilterRule{
1: {},
},
expectedCacheAfter: map[types.NodeID]bool{
1: true, // Preserved
2: false, // Not in cache (new node)
},
},
{
name: "no changes - preserves all cache",
oldNodes: types.Nodes{
&types.Node{ID: 1, IPv4: mustIPPtr("100.64.0.1")},
&types.Node{ID: 2, IPv4: mustIPPtr("100.64.0.2")},
},
newNodes: types.Nodes{
&types.Node{ID: 1, IPv4: mustIPPtr("100.64.0.1")},
&types.Node{ID: 2, IPv4: mustIPPtr("100.64.0.2")},
},
initialCache: map[types.NodeID][]tailcfg.FilterRule{
1: {},
2: {},
},
expectedCacheAfter: map[types.NodeID]bool{
1: true,
2: true,
},
},
{
name: "routes changed - invalidates that node only",
oldNodes: types.Nodes{
&types.Node{
ID: 1,
IPv4: mustIPPtr("100.64.0.1"),
Hostinfo: &tailcfg.Hostinfo{RoutableIPs: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24"), netip.MustParsePrefix("192.168.0.0/24")}},
ApprovedRoutes: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")},
},
&types.Node{ID: 2, IPv4: mustIPPtr("100.64.0.2")},
},
newNodes: types.Nodes{
&types.Node{
ID: 1,
IPv4: mustIPPtr("100.64.0.1"),
Hostinfo: &tailcfg.Hostinfo{RoutableIPs: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24"), netip.MustParsePrefix("192.168.0.0/24")}},
ApprovedRoutes: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")}, // Changed
},
&types.Node{ID: 2, IPv4: mustIPPtr("100.64.0.2")},
},
initialCache: map[types.NodeID][]tailcfg.FilterRule{
1: {},
2: {},
},
expectedCacheAfter: map[types.NodeID]bool{
1: false, // Invalidated
2: true, // Preserved
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
pm := &PolicyManager{
nodes: tt.oldNodes.ViewSlice(),
filterRulesMap: tt.initialCache,
usesAutogroupSelf: false,
}
pm.invalidateGlobalPolicyCache(tt.newNodes.ViewSlice())
// Verify cache state
for nodeID, shouldExist := range tt.expectedCacheAfter {
_, exists := pm.filterRulesMap[nodeID]
require.Equal(t, shouldExist, exists, "node %d cache existence mismatch", nodeID)
}
})
}
}
// TestAutogroupSelfReducedVsUnreducedRules verifies that:
// 1. BuildPeerMap uses unreduced compiled rules for determining peer relationships
// 2. FilterForNode returns reduced compiled rules for packet filters
func TestAutogroupSelfReducedVsUnreducedRules(t *testing.T) {
user1 := types.User{Model: gorm.Model{ID: 1}, Name: "user1", Email: "user1@headscale.net"}
user2 := types.User{Model: gorm.Model{ID: 2}, Name: "user2", Email: "user2@headscale.net"}
users := types.Users{user1, user2}
// Create two nodes
node1 := node("node1", "100.64.0.1", "fd7a:115c:a1e0::1", user1, nil)
node1.ID = 1
node2 := node("node2", "100.64.0.2", "fd7a:115c:a1e0::2", user2, nil)
node2.ID = 2
nodes := types.Nodes{node1, node2}
// Policy with autogroup:self - all members can reach their own devices
policyStr := `{
"acls": [
{
"action": "accept",
"src": ["autogroup:member"],
"dst": ["autogroup:self:*"]
}
]
}`
pm, err := NewPolicyManager([]byte(policyStr), users, nodes.ViewSlice())
require.NoError(t, err)
require.True(t, pm.usesAutogroupSelf, "policy should use autogroup:self")
// Test FilterForNode returns reduced rules
// For node1: should have rules where node1 is in destinations (its own IP)
filterNode1, err := pm.FilterForNode(nodes[0].View())
require.NoError(t, err)
// For node2: should have rules where node2 is in destinations (its own IP)
filterNode2, err := pm.FilterForNode(nodes[1].View())
require.NoError(t, err)
// FilterForNode should return reduced rules - verify they only contain the node's own IPs as destinations
// For node1, destinations should only be node1's IPs
node1IPs := []string{"100.64.0.1/32", "100.64.0.1", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::1"}
for _, rule := range filterNode1 {
for _, dst := range rule.DstPorts {
require.Contains(t, node1IPs, dst.IP,
"node1 filter should only contain node1's IPs as destinations")
}
}
// For node2, destinations should only be node2's IPs
node2IPs := []string{"100.64.0.2/32", "100.64.0.2", "fd7a:115c:a1e0::2/128", "fd7a:115c:a1e0::2"}
for _, rule := range filterNode2 {
for _, dst := range rule.DstPorts {
require.Contains(t, node2IPs, dst.IP,
"node2 filter should only contain node2's IPs as destinations")
}
}
// Test BuildPeerMap uses unreduced rules
peerMap := pm.BuildPeerMap(nodes.ViewSlice())
// According to the policy, user1 can reach autogroup:self (which expands to node1's own IPs for node1)
// So node1 should be able to reach itself, but since we're looking at peer relationships,
// node1 should NOT have itself in the peer map (nodes don't peer with themselves)
// node2 likewise has no peers: its only allowed destination is itself (autogroup:self)
// Verify peer relationships based on unreduced rules
// With unreduced rules, BuildPeerMap can properly determine that:
// - node1 can access autogroup:self (its own IPs)
// - node2 cannot access node1
require.Empty(t, peerMap[node1.ID], "node1 should have no peers (can only reach itself)")
require.Empty(t, peerMap[node2.ID], "node2 should have no peers")
}
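
The "reduction" this test exercises is destination-based: a node's packet filter keeps only rules that can deliver traffic to that node, while BuildPeerMap consumes the unreduced set so peer relationships can be derived from sources as well. A minimal sketch of the idea (exact string match only; the real reduction also handles prefix overlap):

func reduceForNode(rules []tailcfg.FilterRule, nodeIPs map[string]bool) []tailcfg.FilterRule {
	var reduced []tailcfg.FilterRule
	for _, rule := range rules {
		for _, dst := range rule.DstPorts {
			if nodeIPs[dst.IP] {
				// keep rules that target one of this node's own addresses
				reduced = append(reduced, rule)
				break
			}
		}
	}
	return reduced
}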

View File

@@ -20,9 +20,10 @@ const (
)
const (
put = 1
del = 2
update = 3
rebuildPeerMaps = 4
)
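
Style aside: the operation codes are consecutive, so the block could equivalently be written with iota; behavior is identical either way.

const (
	put = iota + 1  // 1
	del             // 2
	update          // 3
	rebuildPeerMaps // 4
)
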
const prometheusNamespace = "headscale"
@@ -142,6 +143,8 @@ type work struct {
updateFn UpdateNodeFunc
result chan struct{}
nodeResult chan types.NodeView // Channel to return the resulting node after batch application
// For rebuildPeerMaps operation
rebuildResult chan struct{}
}
// PutNode adds or updates a node in the store.
@@ -298,6 +301,9 @@ func (s *NodeStore) applyBatch(batch []work) {
// Track which work items need node results
nodeResultRequests := make(map[types.NodeID][]*work)
// Track rebuildPeerMaps operations
var rebuildOps []*work
for i := range batch {
w := &batch[i]
switch w.op {
@@ -321,6 +327,10 @@ func (s *NodeStore) applyBatch(batch []work) {
if w.nodeResult != nil {
nodeResultRequests[w.nodeID] = append(nodeResultRequests[w.nodeID], w)
}
case rebuildPeerMaps:
// rebuildPeerMaps doesn't modify nodes, it just forces the snapshot rebuild
// below to recalculate peer relationships using the current peersFunc
rebuildOps = append(rebuildOps, w)
}
}
@@ -347,9 +357,16 @@ func (s *NodeStore) applyBatch(batch []work) {
}
}
// Signal completion for rebuildPeerMaps operations
for _, w := range rebuildOps {
close(w.rebuildResult)
}
// Signal completion for all other work items
for _, w := range batch {
if w.op != rebuildPeerMaps {
close(w.result)
}
}
}
@@ -546,6 +563,22 @@ func (s *NodeStore) ListPeers(id types.NodeID) views.Slice[types.NodeView] {
return views.SliceOf(s.data.Load().peersByNode[id])
}
// RebuildPeerMaps rebuilds the peer relationship map using the current peersFunc.
// This must be called after policy changes because peersFunc uses PolicyManager's
// filters to determine which nodes can see each other. Without rebuilding, the
// peer map would use stale filter data until the next node add/delete.
func (s *NodeStore) RebuildPeerMaps() {
result := make(chan struct{})
w := work{
op: rebuildPeerMaps,
rebuildResult: result,
}
s.writeQueue <- w
<-result
}
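
RebuildPeerMaps reuses the synchronous handoff idiom of the other NodeStore operations. Stripped to its essentials it looks like this (simplified types, not the actual NodeStore internals):

type request struct{ done chan struct{} }

func worker(queue <-chan request) {
	for r := range queue {
		// ... apply the batch / rebuild the snapshot under the single writer ...
		close(r.done) // wake the blocked caller
	}
}

func enqueueAndWait(queue chan<- request) {
	r := request{done: make(chan struct{})}
	queue <- r
	<-r.done // returns only after the worker has processed the request
}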
// ListNodesByUser returns a slice of all nodes for a given user ID.
func (s *NodeStore) ListNodesByUser(uid types.UserID) views.Slice[types.NodeView] {
timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("list_by_user"))

View File

@@ -132,9 +132,10 @@ func NewState(cfg *types.Config) (*State, error) {
return nil, fmt.Errorf("init policy manager: %w", err)
}
// PolicyManager.BuildPeerMap handles both global and per-node filter complexity.
// This moves the complex peer relationship logic into the policy package where it belongs.
nodeStore := NewNodeStore(nodes, func(nodes []types.NodeView) map[types.NodeID][]types.NodeView {
return polMan.BuildPeerMap(views.SliceOf(nodes))
})
nodeStore.Start()
@@ -225,6 +226,12 @@ func (s *State) ReloadPolicy() ([]change.ChangeSet, error) {
return nil, fmt.Errorf("setting policy: %w", err)
}
// Rebuild peer maps after policy changes because the peersFunc in NodeStore
// uses the PolicyManager's filters. Without this, nodes won't see newly allowed
// peers until a node is added/removed, causing autogroup:self policies to not
// propagate correctly when switching between policy types.
s.nodeStore.RebuildPeerMaps()
cs := []change.ChangeSet{change.PolicyChange()}
// Always call autoApproveNodes during policy reload, regardless of whether
@@ -662,8 +669,7 @@ func (s *State) SetApprovedRoutes(nodeID types.NodeID, routes []netip.Prefix) (t
// RenameNode changes the display name of a node.
func (s *State) RenameNode(nodeID types.NodeID, newName string) (types.NodeView, change.ChangeSet, error) {
// Validate the new name before making any changes
if err := util.ValidateHostname(newName); err != nil {
return types.NodeView{}, change.EmptySet, fmt.Errorf("renaming node: %w", err)
}
@@ -798,6 +804,11 @@ func (s *State) FilterForNode(node types.NodeView) ([]tailcfg.FilterRule, error)
return s.polMan.FilterForNode(node)
}
// MatchersForNode returns matchers for peer relationship determination (unreduced).
func (s *State) MatchersForNode(node types.NodeView) ([]matcher.Match, error) {
return s.polMan.MatchersForNode(node)
}
// NodeCanHaveTag checks if a node is allowed to have a specific tag.
func (s *State) NodeCanHaveTag(node types.NodeView, tag string) bool {
return s.polMan.NodeCanHaveTag(node, tag)
@@ -1112,13 +1123,17 @@ func (s *State) HandleNodeFromAuthPath(
return types.NodeView{}, change.EmptySet, fmt.Errorf("failed to find user: %w", err)
}
// Ensure we have a valid hostname from the registration cache entry
hostname := util.EnsureHostname(
regEntry.Node.Hostinfo,
regEntry.Node.MachineKey.String(),
regEntry.Node.NodeKey.String(),
)
// Ensure we have valid hostinfo
validHostinfo := cmp.Or(regEntry.Node.Hostinfo, &tailcfg.Hostinfo{})
validHostinfo.Hostname = hostname
logHostinfoValidation(
regEntry.Node.MachineKey.ShortString(),
regEntry.Node.NodeKey.String(),
@@ -1284,13 +1299,17 @@ func (s *State) HandleNodeFromPreAuthKey(
return types.NodeView{}, change.EmptySet, err
}
// Ensure we have a valid hostname - handle nil/empty cases
hostname := util.EnsureHostname(
regReq.Hostinfo,
machineKey.String(),
regReq.NodeKey.String(),
)
// Ensure we have valid hostinfo
validHostinfo := cmp.Or(regReq.Hostinfo, &tailcfg.Hostinfo{})
validHostinfo.Hostname = hostname
logHostinfoValidation(
machineKey.ShortString(),
regReq.NodeKey.ShortString(),

View File

@@ -340,7 +340,7 @@ func LoadConfig(path string, isFile bool) error {
viper.SetDefault("prefixes.allocation", string(IPAllocationStrategySequential))
if err := viper.ReadInConfig(); err != nil {
if _, ok := err.(viper.ConfigFileNotFoundError); ok {
log.Warn().Msg("No config file found, using defaults")
return nil
}
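
viper reports a missing config file with the concrete ConfigFileNotFoundError struct, which does not wrap fs.ErrNotExist; that is why the errors.Is check never matched. The type assertion above fixes it; an errors.As form would also work and tolerates wrapping (equivalent sketch):

var notFound viper.ConfigFileNotFoundError
if errors.As(err, &notFound) {
	log.Warn().Msg("No config file found, using defaults")
	return nil
}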

View File

@@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"net/netip"
"regexp"
"slices"
"sort"
"strconv"
@@ -27,6 +28,8 @@ var (
ErrHostnameTooLong = errors.New("hostname too long, cannot exceed 255 ASCII chars")
ErrNodeHasNoGivenName = errors.New("node has no given name")
ErrNodeUserHasNoName = errors.New("node user has no name")
invalidDNSRegex = regexp.MustCompile("[^a-z0-9-.]+")
)
type (
@@ -144,7 +147,10 @@ func (ns Nodes) ViewSlice() views.Slice[NodeView] {
// GivenNameHasBeenChanged returns whether the `givenName` can be automatically changed based on the `Hostname` of the node.
func (node *Node) GivenNameHasBeenChanged() bool {
// Strip invalid DNS characters for givenName comparison
normalised := strings.ToLower(node.Hostname)
normalised = invalidDNSRegex.ReplaceAllString(normalised, "")
return node.GivenName == normalised
}
// IsExpired returns whether the node registration has expired.
@@ -531,20 +537,34 @@ func (node *Node) ApplyHostnameFromHostInfo(hostInfo *tailcfg.Hostinfo) {
return
}
newHostname := strings.ToLower(hostInfo.Hostname)
if err := util.ValidateHostname(newHostname); err != nil {
log.Warn().
Str("node.id", node.ID.String()).
Str("current_hostname", node.Hostname).
Str("rejected_hostname", hostInfo.Hostname).
Err(err).
Msg("Rejecting invalid hostname update from hostinfo")
return
}
if node.Hostname != newHostname {
log.Trace().
Str("node.id", node.ID.String()).
Str("old_hostname", node.Hostname).
Str("new_hostname", hostInfo.Hostname).
Str("new_hostname", newHostname).
Str("old_given_name", node.GivenName).
Bool("given_name_changed", node.GivenNameHasBeenChanged()).
Msg("Updating hostname from hostinfo")
if node.GivenNameHasBeenChanged() {
// Strip invalid DNS characters for givenName display
givenName := strings.ToLower(newHostname)
givenName = invalidDNSRegex.ReplaceAllString(givenName, "")
node.GivenName = givenName
}
node.Hostname = newHostname
log.Trace().
Str("node.id", node.ID.String()).
@@ -835,3 +855,22 @@ func (v NodeView) IPsAsString() []string {
}
return v.ж.IPsAsString()
}
// HasNetworkChanges checks if the node has network-related changes.
// Returns true if IPs, announced routes, or approved routes changed.
// This is primarily used for policy cache invalidation.
func (v NodeView) HasNetworkChanges(other NodeView) bool {
if !slices.Equal(v.IPs(), other.IPs()) {
return true
}
if !slices.Equal(v.AnnouncedRoutes(), other.AnnouncedRoutes()) {
return true
}
if !slices.Equal(v.SubnetRoutes(), other.SubnetRoutes()) {
return true
}
return false
}
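
Usage sketch (the helper is illustrative, not part of this diff): because only IPs and routes participate in the comparison, a hostname-only update reports no change and any cached entry survives.

func invalidateIfChanged(cache map[types.NodeID][]tailcfg.FilterRule, oldView, newView types.NodeView) {
	if newView.HasNetworkChanges(oldView) {
		// network-relevant fields differ; force recomputation on next use
		delete(cache, newView.ID())
	}
}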

View File

@@ -369,7 +369,7 @@ func TestApplyHostnameFromHostInfo(t *testing.T) {
},
want: Node{
GivenName: "manual-test.local",
Hostname: "NewHostName.Local",
Hostname: "newhostname.local",
},
},
{
@@ -383,7 +383,245 @@ func TestApplyHostnameFromHostInfo(t *testing.T) {
},
want: Node{
GivenName: "newhostname.local",
Hostname: "NewHostName.Local",
Hostname: "newhostname.local",
},
},
{
name: "invalid-hostname-with-emoji-rejected",
nodeBefore: Node{
GivenName: "valid-hostname",
Hostname: "valid-hostname",
},
change: &tailcfg.Hostinfo{
Hostname: "hostname-with-💩",
},
want: Node{
GivenName: "valid-hostname",
Hostname: "valid-hostname", // Should reject and keep old hostname
},
},
{
name: "invalid-hostname-with-unicode-rejected",
nodeBefore: Node{
GivenName: "valid-hostname",
Hostname: "valid-hostname",
},
change: &tailcfg.Hostinfo{
Hostname: "我的电脑",
},
want: Node{
GivenName: "valid-hostname",
Hostname: "valid-hostname", // Should keep old hostname
},
},
{
name: "invalid-hostname-with-special-chars-rejected",
nodeBefore: Node{
GivenName: "valid-hostname",
Hostname: "valid-hostname",
},
change: &tailcfg.Hostinfo{
Hostname: "node-with-special!@#$%",
},
want: Node{
GivenName: "valid-hostname",
Hostname: "valid-hostname", // Should reject and keep old hostname
},
},
{
name: "invalid-hostname-too-short-rejected",
nodeBefore: Node{
GivenName: "valid-hostname",
Hostname: "valid-hostname",
},
change: &tailcfg.Hostinfo{
Hostname: "a",
},
want: Node{
GivenName: "valid-hostname",
Hostname: "valid-hostname", // Should keep old hostname
},
},
{
name: "invalid-hostname-uppercase-accepted-lowercased",
nodeBefore: Node{
GivenName: "valid-hostname",
Hostname: "valid-hostname",
},
change: &tailcfg.Hostinfo{
Hostname: "ValidHostName",
},
want: Node{
GivenName: "validhostname", // GivenName follows hostname when it changes
Hostname: "validhostname", // Uppercase is lowercased, not rejected
},
},
{
name: "uppercase_to_lowercase_accepted",
nodeBefore: Node{
GivenName: "valid-hostname",
Hostname: "valid-hostname",
},
change: &tailcfg.Hostinfo{
Hostname: "User2-Host",
},
want: Node{
GivenName: "user2-host",
Hostname: "user2-host",
},
},
{
name: "at_sign_rejected",
nodeBefore: Node{
GivenName: "valid-hostname",
Hostname: "valid-hostname",
},
change: &tailcfg.Hostinfo{
Hostname: "Test@Host",
},
want: Node{
GivenName: "valid-hostname",
Hostname: "valid-hostname",
},
},
{
name: "chinese_chars_with_dash_rejected",
nodeBefore: Node{
GivenName: "valid-hostname",
Hostname: "valid-hostname",
},
change: &tailcfg.Hostinfo{
Hostname: "server-北京-01",
},
want: Node{
GivenName: "valid-hostname",
Hostname: "valid-hostname",
},
},
{
name: "chinese_only_rejected",
nodeBefore: Node{
GivenName: "valid-hostname",
Hostname: "valid-hostname",
},
change: &tailcfg.Hostinfo{
Hostname: "我的电脑",
},
want: Node{
GivenName: "valid-hostname",
Hostname: "valid-hostname",
},
},
{
name: "emoji_with_text_rejected",
nodeBefore: Node{
GivenName: "valid-hostname",
Hostname: "valid-hostname",
},
change: &tailcfg.Hostinfo{
Hostname: "laptop-🚀",
},
want: Node{
GivenName: "valid-hostname",
Hostname: "valid-hostname",
},
},
{
name: "mixed_chinese_emoji_rejected",
nodeBefore: Node{
GivenName: "valid-hostname",
Hostname: "valid-hostname",
},
change: &tailcfg.Hostinfo{
Hostname: "测试💻机器",
},
want: Node{
GivenName: "valid-hostname",
Hostname: "valid-hostname",
},
},
{
name: "only_emojis_rejected",
nodeBefore: Node{
GivenName: "valid-hostname",
Hostname: "valid-hostname",
},
change: &tailcfg.Hostinfo{
Hostname: "🎉🎊",
},
want: Node{
GivenName: "valid-hostname",
Hostname: "valid-hostname",
},
},
{
name: "only_at_signs_rejected",
nodeBefore: Node{
GivenName: "valid-hostname",
Hostname: "valid-hostname",
},
change: &tailcfg.Hostinfo{
Hostname: "@@@",
},
want: Node{
GivenName: "valid-hostname",
Hostname: "valid-hostname",
},
},
{
name: "starts_with_dash_rejected",
nodeBefore: Node{
GivenName: "valid-hostname",
Hostname: "valid-hostname",
},
change: &tailcfg.Hostinfo{
Hostname: "-test",
},
want: Node{
GivenName: "valid-hostname",
Hostname: "valid-hostname",
},
},
{
name: "ends_with_dash_rejected",
nodeBefore: Node{
GivenName: "valid-hostname",
Hostname: "valid-hostname",
},
change: &tailcfg.Hostinfo{
Hostname: "test-",
},
want: Node{
GivenName: "valid-hostname",
Hostname: "valid-hostname",
},
},
{
name: "too_long_hostname_rejected",
nodeBefore: Node{
GivenName: "valid-hostname",
Hostname: "valid-hostname",
},
change: &tailcfg.Hostinfo{
Hostname: strings.Repeat("t", 65),
},
want: Node{
GivenName: "valid-hostname",
Hostname: "valid-hostname",
},
},
{
name: "underscore_rejected",
nodeBefore: Node{
GivenName: "valid-hostname",
Hostname: "valid-hostname",
},
change: &tailcfg.Hostinfo{
Hostname: "test_node",
},
want: Node{
GivenName: "valid-hostname",
Hostname: "valid-hostname",
},
},
}
@@ -555,3 +793,179 @@ func TestNodeRegisterMethodToV1Enum(t *testing.T) {
})
}
}
// TestHasNetworkChanges tests the NodeView method for detecting
// when a node's network properties have changed.
func TestHasNetworkChanges(t *testing.T) {
mustIPPtr := func(s string) *netip.Addr {
ip := netip.MustParseAddr(s)
return &ip
}
tests := []struct {
name string
old *Node
new *Node
changed bool
}{
{
name: "no changes",
old: &Node{
ID: 1,
IPv4: mustIPPtr("100.64.0.1"),
IPv6: mustIPPtr("fd7a:115c:a1e0::1"),
Hostinfo: &tailcfg.Hostinfo{RoutableIPs: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")}},
ApprovedRoutes: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")},
},
new: &Node{
ID: 1,
IPv4: mustIPPtr("100.64.0.1"),
IPv6: mustIPPtr("fd7a:115c:a1e0::1"),
Hostinfo: &tailcfg.Hostinfo{RoutableIPs: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")}},
ApprovedRoutes: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")},
},
changed: false,
},
{
name: "IPv4 changed",
old: &Node{
ID: 1,
IPv4: mustIPPtr("100.64.0.1"),
IPv6: mustIPPtr("fd7a:115c:a1e0::1"),
},
new: &Node{
ID: 1,
IPv4: mustIPPtr("100.64.0.2"),
IPv6: mustIPPtr("fd7a:115c:a1e0::1"),
},
changed: true,
},
{
name: "IPv6 changed",
old: &Node{
ID: 1,
IPv4: mustIPPtr("100.64.0.1"),
IPv6: mustIPPtr("fd7a:115c:a1e0::1"),
},
new: &Node{
ID: 1,
IPv4: mustIPPtr("100.64.0.1"),
IPv6: mustIPPtr("fd7a:115c:a1e0::2"),
},
changed: true,
},
{
name: "RoutableIPs added",
old: &Node{
ID: 1,
IPv4: mustIPPtr("100.64.0.1"),
Hostinfo: &tailcfg.Hostinfo{},
},
new: &Node{
ID: 1,
IPv4: mustIPPtr("100.64.0.1"),
Hostinfo: &tailcfg.Hostinfo{RoutableIPs: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")}},
},
changed: true,
},
{
name: "RoutableIPs removed",
old: &Node{
ID: 1,
IPv4: mustIPPtr("100.64.0.1"),
Hostinfo: &tailcfg.Hostinfo{RoutableIPs: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")}},
},
new: &Node{
ID: 1,
IPv4: mustIPPtr("100.64.0.1"),
Hostinfo: &tailcfg.Hostinfo{},
},
changed: true,
},
{
name: "RoutableIPs changed",
old: &Node{
ID: 1,
IPv4: mustIPPtr("100.64.0.1"),
Hostinfo: &tailcfg.Hostinfo{RoutableIPs: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")}},
},
new: &Node{
ID: 1,
IPv4: mustIPPtr("100.64.0.1"),
Hostinfo: &tailcfg.Hostinfo{RoutableIPs: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")}},
},
changed: true,
},
{
name: "SubnetRoutes added",
old: &Node{
ID: 1,
IPv4: mustIPPtr("100.64.0.1"),
Hostinfo: &tailcfg.Hostinfo{RoutableIPs: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")}},
ApprovedRoutes: []netip.Prefix{},
},
new: &Node{
ID: 1,
IPv4: mustIPPtr("100.64.0.1"),
Hostinfo: &tailcfg.Hostinfo{RoutableIPs: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")}},
ApprovedRoutes: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")},
},
changed: true,
},
{
name: "SubnetRoutes removed",
old: &Node{
ID: 1,
IPv4: mustIPPtr("100.64.0.1"),
Hostinfo: &tailcfg.Hostinfo{RoutableIPs: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")}},
ApprovedRoutes: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")},
},
new: &Node{
ID: 1,
IPv4: mustIPPtr("100.64.0.1"),
Hostinfo: &tailcfg.Hostinfo{RoutableIPs: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")}},
ApprovedRoutes: []netip.Prefix{},
},
changed: true,
},
{
name: "SubnetRoutes changed",
old: &Node{
ID: 1,
IPv4: mustIPPtr("100.64.0.1"),
Hostinfo: &tailcfg.Hostinfo{RoutableIPs: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24"), netip.MustParsePrefix("192.168.0.0/24")}},
ApprovedRoutes: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")},
},
new: &Node{
ID: 1,
IPv4: mustIPPtr("100.64.0.1"),
Hostinfo: &tailcfg.Hostinfo{RoutableIPs: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24"), netip.MustParsePrefix("192.168.0.0/24")}},
ApprovedRoutes: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")},
},
changed: true,
},
{
name: "irrelevant property changed (Hostname)",
old: &Node{
ID: 1,
IPv4: mustIPPtr("100.64.0.1"),
Hostname: "old-name",
},
new: &Node{
ID: 1,
IPv4: mustIPPtr("100.64.0.1"),
Hostname: "new-name",
},
changed: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.new.View().HasNetworkChanges(tt.old.View())
if got != tt.changed {
t.Errorf("HasNetworkChanges() = %v, want %v", got, tt.changed)
}
})
}
}

View File

@@ -27,7 +27,7 @@ var (
invalidCharsInUserRegex = regexp.MustCompile("[^a-z0-9-.]+")
)
var ErrInvalidUserName = errors.New("invalid user name")
var ErrInvalidHostName = errors.New("invalid hostname")
// ValidateUsername checks if a username is valid.
// It must be at least 2 characters long, start with a letter, and contain
@@ -67,42 +67,86 @@ func ValidateUsername(username string) error {
return nil
}
// ValidateHostname checks if a hostname meets DNS requirements.
// This function does NOT modify the input - it only validates.
// The hostname must already be lowercase and contain only valid characters.
func ValidateHostname(name string) error {
if len(name) < 2 {
return errors.New("name must be at least 2 characters long")
return fmt.Errorf(
"hostname %q is too short, must be at least 2 characters",
name,
)
}
if len(name) > LabelHostnameLength {
return fmt.Errorf(
"DNS segment must not be over 63 chars. %v doesn't comply with this rule: %w",
"hostname %q is too long, must not exceed 63 characters",
name,
)
}
if strings.ToLower(name) != name {
return fmt.Errorf(
"DNS segment should be lowercase. %v doesn't comply with this rule: %w",
"hostname %q must be lowercase (try %q)",
name,
strings.ToLower(name),
)
}
if strings.HasPrefix(name, "-") || strings.HasSuffix(name, "-") {
return fmt.Errorf(
"hostname %q cannot start or end with a hyphen",
name,
)
}
if strings.HasPrefix(name, ".") || strings.HasSuffix(name, ".") {
return fmt.Errorf(
"hostname %q cannot start or end with a dot",
name,
)
}
if invalidDNSRegex.MatchString(name) {
return fmt.Errorf(
"DNS segment should only be composed of lowercase ASCII letters numbers, hyphen and dots. %v doesn't comply with these rules: %w",
"hostname %q contains invalid characters, only lowercase letters, numbers, hyphens and dots are allowed",
name,
)
}
return nil
}
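
A few illustrative calls against the rules above (inputs chosen for this note):

_ = ValidateHostname("web-01") // nil: lowercase, valid charset, within 2-63 chars
_ = ValidateHostname("Web-01") // error: must be lowercase
_ = ValidateHostname("-web")   // error: cannot start or end with a hyphen
_ = ValidateHostname("node_1") // error: underscore is outside [a-z0-9-.]
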
// NormaliseHostname transforms a string into a valid DNS hostname.
// Returns error if the transformation results in an invalid hostname.
//
// Transformations applied:
// - Converts to lowercase
// - Removes invalid DNS characters
// - Truncates to 63 characters if needed
//
// After transformation, validates the result.
func NormaliseHostname(name string) (string, error) {
// Early return if already valid
if err := ValidateHostname(name); err == nil {
return name, nil
}
// Transform to lowercase
name = strings.ToLower(name)
// Strip invalid DNS characters
name = invalidDNSRegex.ReplaceAllString(name, "")
// Truncate to DNS label limit
if len(name) > LabelHostnameLength {
name = name[:LabelHostnameLength]
}
// Validate result after transformation
if err := ValidateHostname(name); err != nil {
return "", fmt.Errorf(
"hostname invalid after normalisation: %w",
err,
)
}
return name, nil
}
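
Worked examples of the transformation (outcomes follow from the rules above and match the table-driven tests later in this diff):

n, err := NormaliseHostname("User's MacBook-Pro") // "usersmacbook-pro", nil (space and apostrophe stripped)
n, err = NormaliseHostname("node-🎉-🚀-test")      // "node---test", nil (emojis stripped, hyphens kept)
n, err = NormaliseHostname("💩")                  // "", error: result too short after stripping
_, _ = n, err
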
// generateMagicDNSRootDomains generates a list of DNS entries to be included in `Routes` in `MapResponse`.

View File

@@ -2,6 +2,7 @@ package util
import (
"net/netip"
"strings"
"testing"
"github.com/stretchr/testify/assert"
@@ -9,94 +10,173 @@ import (
"tailscale.com/util/must"
)
func TestNormaliseHostname(t *testing.T) {
type args struct {
name string
}
tests := []struct {
name string
args args
want string
wantErr bool
}{
{
name: "valid: user",
name: "valid: lowercase user",
args: args{name: "valid-user"},
want: "valid-user",
wantErr: false,
},
{
name: "invalid: capitalized user",
name: "normalise: capitalized user",
args: args{name: "Invalid-CapItaLIzed-user"},
want: "invalid-capitalized-user",
wantErr: false,
},
{
name: "invalid: email as user",
name: "normalise: email as user",
args: args{name: "foo.bar@example.com"},
want: "foo.barexample.com",
wantErr: false,
},
{
name: "invalid: chars in user name",
name: "normalise: chars in user name",
args: args{name: "super-user+name"},
want: "super-username",
wantErr: false,
},
{
name: "invalid: too long name for user",
name: "invalid: too long name truncated leaves trailing hyphen",
args: args{
name: "super-long-useruseruser-name-that-should-be-a-little-more-than-63-chars",
},
want: "",
wantErr: true,
},
{
name: "invalid: emoji stripped leaves trailing hyphen",
args: args{name: "hostname-with-💩"},
want: "",
wantErr: true,
},
{
name: "normalise: multiple emojis stripped",
args: args{name: "node-🎉-🚀-test"},
want: "node---test",
wantErr: false,
},
{
name: "invalid: only emoji becomes empty",
args: args{name: "💩"},
want: "",
wantErr: true,
},
{
name: "invalid: emoji at start leaves leading hyphen",
args: args{name: "🚀-rocket-node"},
want: "",
wantErr: true,
},
{
name: "invalid: emoji at end leaves trailing hyphen",
args: args{name: "node-test-🎉"},
want: "",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := NormaliseHostname(tt.args.name)
if (err != nil) != tt.wantErr {
t.Errorf("NormaliseHostname() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !tt.wantErr && got != tt.want {
t.Errorf("NormaliseHostname() = %v, want %v", got, tt.want)
}
})
}
}
func TestValidateHostname(t *testing.T) {
tests := []struct {
name string
hostname string
wantErr bool
errorContains string
}{
{
name: "User1.test",
hostname: "User1.Test",
dnsHostName: "user1.test",
name: "valid lowercase",
hostname: "valid-hostname",
wantErr: false,
},
{
name: "User'1$2.test",
hostname: "User'1$2.Test",
dnsHostName: "user12.test",
name: "uppercase rejected",
hostname: "MyHostname",
wantErr: true,
errorContains: "must be lowercase",
},
{
name: "User-^_12.local.test",
hostname: "User-^_12.local.Test",
dnsHostName: "user-12.local.test",
name: "too short",
hostname: "a",
wantErr: true,
errorContains: "too short",
},
{
name: "User-MacBook-Pro",
hostname: "User-MacBook-Pro",
dnsHostName: "user-macbook-pro",
name: "too long",
hostname: "a" + strings.Repeat("b", 63),
wantErr: true,
errorContains: "too long",
},
{
name: "User-Linux-Ubuntu/Fedora",
hostname: "User-Linux-Ubuntu/Fedora",
dnsHostName: "user-linux-ubuntufedora",
name: "emoji rejected",
hostname: "hostname-💩",
wantErr: true,
errorContains: "invalid characters",
},
{
name: "User-[Space]123",
hostname: "User-[ ]123",
dnsHostName: "user-123",
name: "starts with hyphen",
hostname: "-hostname",
wantErr: true,
errorContains: "cannot start or end with a hyphen",
},
{
name: "ends with hyphen",
hostname: "hostname-",
wantErr: true,
errorContains: "cannot start or end with a hyphen",
},
{
name: "starts with dot",
hostname: ".hostname",
wantErr: true,
errorContains: "cannot start or end with a dot",
},
{
name: "ends with dot",
hostname: "hostname.",
wantErr: true,
errorContains: "cannot start or end with a dot",
},
{
name: "special characters",
hostname: "host!@#$name",
wantErr: true,
errorContains: "invalid characters",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := ValidateHostname(tt.hostname)
if (err != nil) != tt.wantErr {
t.Errorf("ValidateHostname() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.wantErr && tt.errorContains != "" {
if err == nil || !strings.Contains(err.Error(), tt.errorContains) {
t.Errorf("ValidateHostname() error = %v, should contain %q", err, tt.errorContains)
}
}
})
}
}

View File

@@ -66,6 +66,11 @@ func MustGenerateRandomStringDNSSafe(size int) string {
return hash
}
// InvalidString returns a random, DNS-safe placeholder of the form
// "invalid-<8 chars>" for hostnames that cannot be salvaged.
func InvalidString() string {
hash, _ := GenerateRandomStringDNSSafe(8)
return "invalid-" + hash
}
func TailNodesToString(nodes []*tailcfg.Node) string {
temp := make([]string, len(nodes))

View File

@@ -1,6 +1,7 @@
package util
import (
"cmp"
"errors"
"fmt"
"net/netip"
@@ -264,54 +265,32 @@ func IsCI() bool {
// EnsureHostname guarantees a valid hostname for node registration.
// This function never fails - it always returns a valid hostname.
//
// Strategy:
// 1. If hostinfo is nil/empty → generate default from keys
// 2. If hostname is provided → normalise it
// 3. If normalisation fails → generate invalid-<random> replacement
//
// Returns the guaranteed-valid hostname to use.
func EnsureHostname(hostinfo *tailcfg.Hostinfo, machineKey, nodeKey string) string {
if hostinfo == nil || hostinfo.Hostname == "" {
key := cmp.Or(machineKey, nodeKey)
if key == "" {
return "unknown-node"
}
if nodeKey != "" {
keyPrefix := nodeKey
if len(nodeKey) > 8 {
keyPrefix = nodeKey[:8]
}
return fmt.Sprintf("node-%s", keyPrefix)
keyPrefix := key
if len(key) > 8 {
keyPrefix = key[:8]
}
return "unknown-node"
return fmt.Sprintf("node-%s", keyPrefix)
}
lowercased := strings.ToLower(hostinfo.Hostname)
if err := ValidateHostname(lowercased); err == nil {
return lowercased
}
return InvalidString()
}
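
Resulting behavior at a glance (inputs mirror the tests below; the random suffix varies per call):

// EnsureHostname(nil, "mkey12345678", "nkey12345678")               -> "node-mkey1234"
// EnsureHostname(&tailcfg.Hostinfo{Hostname: "User2-Host"}, mk, nk) -> "user2-host"
// EnsureHostname(&tailcfg.Hostinfo{Hostname: "@@@"}, mk, nk)        -> "invalid-<random>"
// EnsureHostname(nil, "", "")                                       -> "unknown-node"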

View File

@@ -3,6 +3,7 @@ package util
import (
"errors"
"net/netip"
"strings"
"testing"
"time"
@@ -795,7 +796,7 @@ over a maximum of 30 hops:
}
}
func TestEnsureHostname(t *testing.T) {
t.Parallel()
tests := []struct {
@@ -878,7 +879,7 @@ func TestSafeHostname(t *testing.T) {
},
machineKey: "mkey12345678",
nodeKey: "nkey12345678",
want: "123456789012345678901234567890123456789012345678901234567890123",
want: "invalid-",
},
{
name: "hostname_very_long_truncated",
@@ -887,7 +888,7 @@ func TestSafeHostname(t *testing.T) {
},
machineKey: "mkey12345678",
nodeKey: "nkey12345678",
want: "test-node-with-very-long-hostname-that-exceeds-dns-label-limits",
want: "invalid-",
},
{
name: "hostname_with_special_chars",
@@ -896,7 +897,7 @@ func TestSafeHostname(t *testing.T) {
},
machineKey: "mkey12345678",
nodeKey: "nkey12345678",
want: "node-with-special!@#$%",
want: "invalid-",
},
{
name: "hostname_with_unicode",
@@ -905,7 +906,7 @@ func TestSafeHostname(t *testing.T) {
},
machineKey: "mkey12345678",
nodeKey: "nkey12345678",
want: "node-ñoño-测试",
want: "invalid-",
},
{
name: "short_machine_key",
@@ -925,20 +926,160 @@ func TestSafeHostname(t *testing.T) {
nodeKey: "short",
want: "node-short",
},
{
name: "hostname_with_emoji_replaced",
hostinfo: &tailcfg.Hostinfo{
Hostname: "hostname-with-💩",
},
machineKey: "mkey12345678",
nodeKey: "nkey12345678",
want: "invalid-",
},
{
name: "hostname_only_emoji_replaced",
hostinfo: &tailcfg.Hostinfo{
Hostname: "🚀",
},
machineKey: "mkey12345678",
nodeKey: "nkey12345678",
want: "invalid-",
},
{
name: "hostname_with_multiple_emojis_replaced",
hostinfo: &tailcfg.Hostinfo{
Hostname: "node-🎉-🚀-test",
},
machineKey: "mkey12345678",
nodeKey: "nkey12345678",
want: "invalid-",
},
{
name: "uppercase_to_lowercase",
hostinfo: &tailcfg.Hostinfo{
Hostname: "User2-Host",
},
machineKey: "mkey12345678",
nodeKey: "nkey12345678",
want: "user2-host",
},
{
name: "underscore_removed",
hostinfo: &tailcfg.Hostinfo{
Hostname: "test_node",
},
machineKey: "mkey12345678",
nodeKey: "nkey12345678",
want: "invalid-",
},
{
name: "at_sign_invalid",
hostinfo: &tailcfg.Hostinfo{
Hostname: "Test@Host",
},
machineKey: "mkey12345678",
nodeKey: "nkey12345678",
want: "invalid-",
},
{
name: "chinese_chars_with_dash_invalid",
hostinfo: &tailcfg.Hostinfo{
Hostname: "server-北京-01",
},
machineKey: "mkey12345678",
nodeKey: "nkey12345678",
want: "invalid-",
},
{
name: "chinese_only_invalid",
hostinfo: &tailcfg.Hostinfo{
Hostname: "我的电脑",
},
machineKey: "mkey12345678",
nodeKey: "nkey12345678",
want: "invalid-",
},
{
name: "emoji_with_text_invalid",
hostinfo: &tailcfg.Hostinfo{
Hostname: "laptop-🚀",
},
machineKey: "mkey12345678",
nodeKey: "nkey12345678",
want: "invalid-",
},
{
name: "mixed_chinese_emoji_invalid",
hostinfo: &tailcfg.Hostinfo{
Hostname: "测试💻机器",
},
machineKey: "mkey12345678",
nodeKey: "nkey12345678",
want: "invalid-",
},
{
name: "only_emojis_invalid",
hostinfo: &tailcfg.Hostinfo{
Hostname: "🎉🎊",
},
machineKey: "mkey12345678",
nodeKey: "nkey12345678",
want: "invalid-",
},
{
name: "only_at_signs_invalid",
hostinfo: &tailcfg.Hostinfo{
Hostname: "@@@",
},
machineKey: "mkey12345678",
nodeKey: "nkey12345678",
want: "invalid-",
},
{
name: "starts_with_dash_invalid",
hostinfo: &tailcfg.Hostinfo{
Hostname: "-test",
},
machineKey: "mkey12345678",
nodeKey: "nkey12345678",
want: "invalid-",
},
{
name: "ends_with_dash_invalid",
hostinfo: &tailcfg.Hostinfo{
Hostname: "test-",
},
machineKey: "mkey12345678",
nodeKey: "nkey12345678",
want: "invalid-",
},
{
name: "very_long_hostname_truncated",
hostinfo: &tailcfg.Hostinfo{
Hostname: strings.Repeat("t", 70),
},
machineKey: "mkey12345678",
nodeKey: "nkey12345678",
want: "invalid-",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := EnsureHostname(tt.hostinfo, tt.machineKey, tt.nodeKey)
// For invalid hostnames, we just check the prefix since the random part varies
if strings.HasPrefix(tt.want, "invalid-") {
if !strings.HasPrefix(got, "invalid-") {
t.Errorf("EnsureHostname() = %v, want prefix %v", got, tt.want)
}
} else if got != tt.want {
t.Errorf("EnsureHostname() = %v, want %v", got, tt.want)
}
})
}
}
func TestEnsureHostnameWithHostinfo(t *testing.T) {
t.Parallel()
tests := []struct {
@@ -976,14 +1117,6 @@ func TestEnsureValidHostinfo(t *testing.T) {
machineKey: "mkey12345678",
nodeKey: "nkey12345678",
wantHostname: "node-mkey1234",
},
{
name: "empty_hostname_updated",
@@ -994,37 +1127,15 @@ func TestEnsureValidHostinfo(t *testing.T) {
machineKey: "mkey12345678",
nodeKey: "nkey12345678",
wantHostname: "node-mkey1234",
},
{
name: "long_hostname_truncated",
name: "long_hostname_rejected",
hostinfo: &tailcfg.Hostinfo{
Hostname: "test-node-with-very-long-hostname-that-exceeds-dns-label-limits-of-63-characters",
},
machineKey: "mkey12345678",
nodeKey: "nkey12345678",
wantHostname: "test-node-with-very-long-hostname-that-exceeds-dns-label-limits",
checkHostinfo: func(t *testing.T, hi *tailcfg.Hostinfo) {
if hi == nil {
t.Error("hostinfo should not be nil")
}
if hi.Hostname != "test-node-with-very-long-hostname-that-exceeds-dns-label-limits" {
t.Errorf("hostname = %v, want truncated", hi.Hostname)
}
if len(hi.Hostname) != 63 {
t.Errorf("hostname length = %v, want 63", len(hi.Hostname))
}
},
wantHostname: "invalid-",
},
{
name: "nil_hostinfo_node_key_only",
@@ -1128,23 +1239,20 @@ func TestEnsureValidHostinfo(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
gotHostname := EnsureHostname(tt.hostinfo, tt.machineKey, tt.nodeKey)
// For invalid hostnames, we just check the prefix since the random part varies
if strings.HasPrefix(tt.wantHostname, "invalid-") {
if !strings.HasPrefix(gotHostname, "invalid-") {
t.Errorf("EnsureHostname() = %v, want prefix %v", gotHostname, tt.wantHostname)
}
} else if gotHostname != tt.wantHostname {
t.Errorf("EnsureHostname() hostname = %v, want %v", gotHostname, tt.wantHostname)
}
})
}
}
func TestEnsureHostname_DNSLabelLimit(t *testing.T) {
t.Parallel()
testCases := []string{
@@ -1157,7 +1265,7 @@ func TestSafeHostname_DNSLabelLimit(t *testing.T) {
for i, hostname := range testCases {
t.Run(cmp.Diff("", ""), func(t *testing.T) {
hostinfo := &tailcfg.Hostinfo{Hostname: hostname}
result := EnsureHostname(hostinfo, "mkey", "nkey")
if len(result) > 63 {
t.Errorf("test case %d: hostname length = %d, want <= 63", i, len(result))
}
@@ -1165,7 +1273,7 @@ func TestSafeHostname_DNSLabelLimit(t *testing.T) {
}
}
func TestEnsureHostname_Idempotent(t *testing.T) {
t.Parallel()
originalHostinfo := &tailcfg.Hostinfo{
@@ -1173,16 +1281,10 @@ func TestEnsureValidHostinfo_Idempotent(t *testing.T) {
OS: "linux",
}
hostname1 := EnsureHostname(originalHostinfo, "mkey", "nkey")
hostname2 := EnsureHostname(originalHostinfo, "mkey", "nkey")
if hostname1 != hostname2 {
t.Errorf("hostnames not equal: %v != %v", hostname1, hostname2)
}
}

View File

@@ -3,12 +3,14 @@ package integration
import (
"fmt"
"net/netip"
"strconv"
"strings"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/integration/hsic"
@@ -319,12 +321,14 @@ func TestACLHostsInNetMapTable(t *testing.T) {
require.NoError(t, err)
for _, client := range allClients {
assert.EventuallyWithT(t, func(c *assert.CollectT) {
status, err := client.Status()
assert.NoError(c, err)
user := status.User[status.Self.UserID].LoginName
assert.Len(c, status.Peer, (testCase.want[user]))
}, 10*time.Second, 200*time.Millisecond, "Waiting for expected peer visibility")
}
})
}
@@ -782,75 +786,87 @@ func TestACLNamedHostsCanReach(t *testing.T) {
test3fqdnURL := fmt.Sprintf("http://%s/etc/hostname", test3fqdn)
// test1 can query test3
assert.EventuallyWithT(t, func(c *assert.CollectT) {
result, err := test1.Curl(test3ip4URL)
assert.NoError(c, err)
assert.Lenf(
c,
result,
13,
"failed to connect from test1 to test3 with URL %s, expected hostname of 13 chars, got %s",
test3ip4URL,
result,
)
}, 10*time.Second, 200*time.Millisecond, "test1 should reach test3 via IPv4")
assert.EventuallyWithT(t, func(c *assert.CollectT) {
result, err := test1.Curl(test3ip6URL)
assert.NoError(c, err)
assert.Lenf(
c,
result,
13,
"failed to connect from test1 to test3 with URL %s, expected hostname of 13 chars, got %s",
test3ip6URL,
result,
)
}, 10*time.Second, 200*time.Millisecond, "test1 should reach test3 via IPv6")
assert.EventuallyWithT(t, func(c *assert.CollectT) {
result, err := test1.Curl(test3fqdnURL)
assert.NoError(c, err)
assert.Lenf(
c,
result,
13,
"failed to connect from test1 to test3 with URL %s, expected hostname of 13 chars, got %s",
test3fqdnURL,
result,
)
}, 10*time.Second, 200*time.Millisecond, "test1 should reach test3 via FQDN")
// test2 can query test3
assert.EventuallyWithT(t, func(c *assert.CollectT) {
result, err := test2.Curl(test3ip4URL)
assert.NoError(c, err)
assert.Lenf(
c,
result,
13,
"failed to connect from test1 to test3 with URL %s, expected hostname of 13 chars, got %s",
test3ip4URL,
result,
)
}, 10*time.Second, 200*time.Millisecond, "test2 should reach test3 via IPv4")
assert.EventuallyWithT(t, func(c *assert.CollectT) {
result, err := test2.Curl(test3ip6URL)
assert.NoError(c, err)
assert.Lenf(
c,
result,
13,
"failed to connect from test1 to test3 with URL %s, expected hostname of 13 chars, got %s",
test3ip6URL,
result,
)
}, 10*time.Second, 200*time.Millisecond, "test2 should reach test3 via IPv6")
assert.EventuallyWithT(t, func(c *assert.CollectT) {
result, err := test2.Curl(test3fqdnURL)
assert.NoError(c, err)
assert.Lenf(
c,
result,
13,
"failed to connect from test1 to test3 with URL %s, expected hostname of 13 chars, got %s",
test3fqdnURL,
result,
)
}, 10*time.Second, 200*time.Millisecond, "test2 should reach test3 via FQDN")
// test3 cannot query test1
result, err := test3.Curl(test1ip4URL)
assert.Empty(t, result)
require.Error(t, err)
@@ -876,38 +892,44 @@ func TestACLNamedHostsCanReach(t *testing.T) {
require.Error(t, err)
// test1 can query test2
assert.EventuallyWithT(t, func(c *assert.CollectT) {
result, err := test1.Curl(test2ip4URL)
assert.NoError(c, err)
assert.Lenf(
c,
result,
13,
"failed to connect from test1 to test2 with URL %s, expected hostname of 13 chars, got %s",
test2ip4URL,
result,
)
}, 10*time.Second, 200*time.Millisecond, "test1 should reach test2 via IPv4")
assert.EventuallyWithT(t, func(c *assert.CollectT) {
result, err := test1.Curl(test2ip6URL)
assert.NoError(c, err)
assert.Lenf(
c,
result,
13,
"failed to connect from test1 to test2 with URL %s, expected hostname of 13 chars, got %s",
test2ip6URL,
result,
)
}, 10*time.Second, 200*time.Millisecond, "test1 should reach test2 via IPv6")
assert.EventuallyWithT(t, func(c *assert.CollectT) {
result, err := test1.Curl(test2fqdnURL)
assert.NoError(c, err)
assert.Lenf(
c,
result,
13,
"failed to connect from test1 to test2 with URL %s, expected hostname of 13 chars, got %s",
test2fqdnURL,
result,
)
}, 10*time.Second, 200*time.Millisecond, "test1 should reach test2 via FQDN")
// test2 cannot query test1
result, err = test2.Curl(test1ip4URL)
@@ -1050,50 +1072,63 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) {
test2fqdnURL := fmt.Sprintf("http://%s/etc/hostname", test2fqdn)
// test1 can query test2
assert.EventuallyWithT(t, func(c *assert.CollectT) {
result, err := test1.Curl(test2ipURL)
assert.NoError(c, err)
assert.Lenf(
c,
result,
13,
"failed to connect from test1 to test with URL %s, expected hostname of 13 chars, got %s",
test2ipURL,
result,
)
}, 10*time.Second, 200*time.Millisecond, "test1 should reach test2 via IPv4")
assert.EventuallyWithT(t, func(c *assert.CollectT) {
result, err := test1.Curl(test2ip6URL)
assert.NoError(c, err)
assert.Lenf(
c,
result,
13,
"failed to connect from test1 to test with URL %s, expected hostname of 13 chars, got %s",
test2ip6URL,
result,
)
}, 10*time.Second, 200*time.Millisecond, "test1 should reach test2 via IPv6")
assert.EventuallyWithT(t, func(c *assert.CollectT) {
result, err := test1.Curl(test2fqdnURL)
assert.NoError(c, err)
assert.Lenf(
c,
result,
13,
"failed to connect from test1 to test with URL %s, expected hostname of 13 chars, got %s",
test2fqdnURL,
result,
)
}, 10*time.Second, 200*time.Millisecond, "test1 should reach test2 via FQDN")
// test2 cannot query test1 (negative test case)
assert.EventuallyWithT(t, func(c *assert.CollectT) {
result, err := test2.Curl(test1ipURL)
assert.Error(c, err)
assert.Empty(c, result)
}, 10*time.Second, 200*time.Millisecond, "test2 should NOT reach test1 via IPv4")
assert.EventuallyWithT(t, func(c *assert.CollectT) {
result, err := test2.Curl(test1ip6URL)
assert.Error(c, err)
assert.Empty(c, result)
}, 10*time.Second, 200*time.Millisecond, "test2 should NOT reach test1 via IPv6")
assert.EventuallyWithT(t, func(c *assert.CollectT) {
result, err := test2.Curl(test1fqdnURL)
assert.Error(c, err)
assert.Empty(c, result)
}, 10*time.Second, 200*time.Millisecond, "test2 should NOT reach test1 via FQDN")
})
}
}
@@ -1266,9 +1301,15 @@ func TestACLAutogroupMember(t *testing.T) {
// Test that untagged nodes can access each other
for _, client := range allClients {
var clientIsUntagged bool
assert.EventuallyWithT(t, func(c *assert.CollectT) {
status, err := client.Status()
assert.NoError(c, err)
clientIsUntagged = status.Self.Tags == nil || status.Self.Tags.Len() == 0
assert.True(c, clientIsUntagged, "Expected client %s to be untagged for autogroup:member test", client.Hostname())
}, 10*time.Second, 200*time.Millisecond, "Waiting for client %s to be untagged", client.Hostname())
if !clientIsUntagged {
continue
}
@@ -1277,9 +1318,15 @@ func TestACLAutogroupMember(t *testing.T) {
continue
}
var peerIsUntagged bool
assert.EventuallyWithT(t, func(c *assert.CollectT) {
status, err := peer.Status()
assert.NoError(c, err)
peerIsUntagged = status.Self.Tags == nil || status.Self.Tags.Len() == 0
assert.True(c, peerIsUntagged, "Expected peer %s to be untagged for autogroup:member test", peer.Hostname())
}, 10*time.Second, 200*time.Millisecond, "Waiting for peer %s to be untagged", peer.Hostname())
if !peerIsUntagged {
continue
}
@@ -1468,21 +1515,23 @@ func TestACLAutogroupTagged(t *testing.T) {
// Explicitly verify tags on tagged nodes
for _, client := range taggedClients {
assert.EventuallyWithT(t, func(c *assert.CollectT) {
status, err := client.Status()
assert.NoError(c, err)
assert.NotNil(c, status.Self.Tags, "tagged node %s should have tags", client.Hostname())
assert.Positive(c, status.Self.Tags.Len(), "tagged node %s should have at least one tag", client.Hostname())
}, 10*time.Second, 200*time.Millisecond, "Waiting for tags to be applied to tagged nodes")
}
// Verify untagged nodes have no tags
for _, client := range untaggedClients {
assert.EventuallyWithT(t, func(c *assert.CollectT) {
status, err := client.Status()
assert.NoError(c, err)
if status.Self.Tags != nil {
assert.Equal(c, 0, status.Self.Tags.Len(), "untagged node %s should have no tags", client.Hostname())
}
}, 10*time.Second, 200*time.Millisecond, "Waiting to verify untagged nodes have no tags")
}
// Test that tagged nodes can communicate with each other
@@ -1603,9 +1652,11 @@ func TestACLAutogroupSelf(t *testing.T) {
url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
t.Logf("url from %s (user1) to %s (user1)", client.Hostname(), fqdn)
assert.EventuallyWithT(t, func(c *assert.CollectT) {
result, err := client.Curl(url)
assert.NoError(c, err)
assert.Len(c, result, 13)
}, 10*time.Second, 200*time.Millisecond, "user1 device should reach other user1 device")
}
}
@@ -1622,9 +1673,11 @@ func TestACLAutogroupSelf(t *testing.T) {
url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
t.Logf("url from %s (user2) to %s (user2)", client.Hostname(), fqdn)
assert.EventuallyWithT(t, func(c *assert.CollectT) {
result, err := client.Curl(url)
assert.NoError(c, err)
assert.Len(c, result, 13)
}, 10*time.Second, 200*time.Millisecond, "user2 device should reach other user2 device")
}
}
@@ -1657,3 +1710,388 @@ func TestACLAutogroupSelf(t *testing.T) {
}
}
}
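
All of the polling rewrites above use the same testify idiom: retry the closure until it passes or the deadline expires, asserting against the CollectT rather than the outer *testing.T so a transient failure only retries instead of aborting. The shape in isolation (client and url as in the tests above):

assert.EventuallyWithT(t, func(c *assert.CollectT) {
	result, err := client.Curl(url)
	assert.NoError(c, err)
	assert.Len(c, result, 13)
}, 10*time.Second, 200*time.Millisecond, "peer should become reachable")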
func TestACLPolicyPropagationOverTime(t *testing.T) {
IntegrationSkip(t)
spec := ScenarioSpec{
NodesPerUser: 2,
Users: []string{"user1", "user2"},
}
scenario, err := NewScenario(spec)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
err = scenario.CreateHeadscaleEnv(
[]tsic.Option{
// Install iptables to enable packet filtering for ACL tests.
// Packet filters are essential for testing autogroup:self and other ACL policies.
tsic.WithDockerEntrypoint([]string{
"/bin/sh",
"-c",
"/bin/sleep 3 ; apk add python3 curl iptables ip6tables ; update-ca-certificates ; python3 -m http.server --bind :: 80 & tailscaled --tun=tsdev",
}),
tsic.WithDockerWorkdir("/"),
},
hsic.WithTestName("aclpropagation"),
hsic.WithPolicyMode(types.PolicyModeDB),
)
require.NoError(t, err)
_, err = scenario.ListTailscaleClientsFQDNs()
require.NoError(t, err)
err = scenario.WaitForTailscaleSync()
require.NoError(t, err)
user1Clients, err := scenario.ListTailscaleClients("user1")
require.NoError(t, err)
user2Clients, err := scenario.ListTailscaleClients("user2")
require.NoError(t, err)
allClients := append(user1Clients, user2Clients...)
headscale, err := scenario.Headscale()
require.NoError(t, err)
// Define the three policies we'll cycle through
allowAllPolicy := &policyv2.Policy{
ACLs: []policyv2.ACL{
{
Action: "accept",
Sources: []policyv2.Alias{wildcard()},
Destinations: []policyv2.AliasWithPorts{
aliasWithPorts(wildcard(), tailcfg.PortRangeAny),
},
},
},
}
autogroupSelfPolicy := &policyv2.Policy{
ACLs: []policyv2.ACL{
{
Action: "accept",
Sources: []policyv2.Alias{ptr.To(policyv2.AutoGroupMember)},
Destinations: []policyv2.AliasWithPorts{
aliasWithPorts(ptr.To(policyv2.AutoGroupSelf), tailcfg.PortRangeAny),
},
},
},
}
user1ToUser2Policy := &policyv2.Policy{
ACLs: []policyv2.ACL{
{
Action: "accept",
Sources: []policyv2.Alias{usernamep("user1@")},
Destinations: []policyv2.AliasWithPorts{
aliasWithPorts(usernamep("user2@"), tailcfg.PortRangeAny),
},
},
},
}
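// For reference, the autogroup:self policy above corresponds roughly to this
// HuJSON form (sketch only; exact field names per headscale's policy format):
//
//	{
//	  "acls": [
//	    {"action": "accept", "src": ["autogroup:member"], "dst": ["autogroup:self:*"]}
//	  ]
//	}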
// Run through the policy cycle 5 times
for i := range 5 {
iteration := i + 1 // range 5 gives 0-4, we want 1-5 for logging
t.Logf("=== Iteration %d/5 ===", iteration)
// Phase 1: Allow all policy
t.Logf("Iteration %d: Setting allow-all policy", iteration)
err = headscale.SetPolicy(allowAllPolicy)
require.NoError(t, err)
// Wait for peer lists to sync with allow-all policy
t.Logf("Iteration %d: Phase 1 - Waiting for peer lists to sync with allow-all policy", iteration)
err = scenario.WaitForTailscaleSync()
require.NoError(t, err, "iteration %d: Phase 1 - failed to sync after allow-all policy", iteration)
// Test all-to-all connectivity after state is settled
t.Logf("Iteration %d: Phase 1 - Testing all-to-all connectivity", iteration)
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
for _, client := range allClients {
for _, peer := range allClients {
if client.ContainerID() == peer.ContainerID() {
continue
}
fqdn, err := peer.FQDN()
if !assert.NoError(ct, err, "iteration %d: failed to get FQDN for %s", iteration, peer.Hostname()) {
continue
}
url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
result, err := client.Curl(url)
assert.NoError(ct, err, "iteration %d: %s should reach %s with allow-all policy", iteration, client.Hostname(), fqdn)
assert.Len(ct, result, 13, "iteration %d: response from %s to %s should be valid", iteration, client.Hostname(), fqdn)
}
}
}, 90*time.Second, 500*time.Millisecond, "iteration %d: Phase 1 - all connectivity tests with allow-all policy", iteration)
// Phase 2: Autogroup:self policy (only same user can access)
t.Logf("Iteration %d: Phase 2 - Setting autogroup:self policy", iteration)
err = headscale.SetPolicy(autogroupSelfPolicy)
require.NoError(t, err)
// Wait for peer lists to sync with autogroup:self - ensures cross-user peers are removed
t.Logf("Iteration %d: Phase 2 - Waiting for peer lists to sync with autogroup:self", iteration)
err = scenario.WaitForTailscaleSyncPerUser(60*time.Second, 500*time.Millisecond)
require.NoError(t, err, "iteration %d: Phase 2 - failed to sync after autogroup:self policy", iteration)
// Test ALL connectivity (positive and negative) in one block after state is settled
t.Logf("Iteration %d: Phase 2 - Testing all connectivity with autogroup:self", iteration)
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
// Positive: user1 can access user1's nodes
for _, client := range user1Clients {
for _, peer := range user1Clients {
if client.ContainerID() == peer.ContainerID() {
continue
}
fqdn, err := peer.FQDN()
if !assert.NoError(ct, err, "iteration %d: failed to get FQDN for user1 peer %s", iteration, peer.Hostname()) {
continue
}
url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
result, err := client.Curl(url)
assert.NoError(ct, err, "iteration %d: user1 node %s should reach user1 node %s", iteration, client.Hostname(), peer.Hostname())
assert.Len(ct, result, 13, "iteration %d: response from %s to %s should be valid", iteration, client.Hostname(), peer.Hostname())
}
}
// Positive: user2 can access user2's nodes
for _, client := range user2Clients {
for _, peer := range user2Clients {
if client.ContainerID() == peer.ContainerID() {
continue
}
fqdn, err := peer.FQDN()
if !assert.NoError(ct, err, "iteration %d: failed to get FQDN for user2 peer %s", iteration, peer.Hostname()) {
continue
}
url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
result, err := client.Curl(url)
assert.NoError(ct, err, "iteration %d: user2 %s should reach user2's node %s", iteration, client.Hostname(), fqdn)
assert.Len(ct, result, 13, "iteration %d: response from %s to %s should be valid", iteration, client.Hostname(), fqdn)
}
}
// Negative: user1 cannot access user2's nodes
for _, client := range user1Clients {
for _, peer := range user2Clients {
fqdn, err := peer.FQDN()
if !assert.NoError(ct, err, "iteration %d: failed to get FQDN for user2 peer %s", iteration, peer.Hostname()) {
continue
}
url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
result, err := client.Curl(url)
assert.Error(ct, err, "iteration %d: user1 %s should NOT reach user2's node %s with autogroup:self", iteration, client.Hostname(), fqdn)
assert.Empty(ct, result, "iteration %d: user1 %s->user2 %s should fail", iteration, client.Hostname(), fqdn)
}
}
// Negative: user2 cannot access user1's nodes
for _, client := range user2Clients {
for _, peer := range user1Clients {
fqdn, err := peer.FQDN()
if !assert.NoError(ct, err, "iteration %d: failed to get FQDN for user1 peer %s", iteration, peer.Hostname()) {
continue
}
url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
result, err := client.Curl(url)
assert.Error(ct, err, "iteration %d: user2 node %s should NOT reach user1 node %s", iteration, client.Hostname(), peer.Hostname())
assert.Empty(ct, result, "iteration %d: user2->user1 connection from %s to %s should fail", iteration, client.Hostname(), peer.Hostname())
}
}
}, 90*time.Second, 500*time.Millisecond, "iteration %d: Phase 2 - all connectivity tests with autogroup:self", iteration)
// Phase 2b: Add a new node to user1 and validate policy propagation
t.Logf("Iteration %d: Phase 2b - Adding new node to user1 during autogroup:self policy", iteration)
// Add a new node with the same options as the initial setup
// Get the network to use (scenario uses first network in list)
networks := scenario.Networks()
require.NotEmpty(t, networks, "scenario should have at least one network")
newClient := scenario.MustAddAndLoginClient(t, "user1", "all", headscale,
tsic.WithNetfilter("off"),
tsic.WithDockerEntrypoint([]string{
"/bin/sh",
"-c",
"/bin/sleep 3 ; apk add python3 curl ; update-ca-certificates ; python3 -m http.server --bind :: 80 & tailscaled --tun=tsdev",
}),
tsic.WithDockerWorkdir("/"),
tsic.WithNetwork(networks[0]),
)
t.Logf("Iteration %d: Phase 2b - Added and logged in new node %s", iteration, newClient.Hostname())
// Wait for peer lists to sync after new node addition (now 3 user1 nodes, still autogroup:self)
t.Logf("Iteration %d: Phase 2b - Waiting for peer lists to sync after new node addition", iteration)
err = scenario.WaitForTailscaleSyncPerUser(60*time.Second, 500*time.Millisecond)
require.NoError(t, err, "iteration %d: Phase 2b - failed to sync after new node addition", iteration)
// Test ALL connectivity (positive and negative) in one block after state is settled
t.Logf("Iteration %d: Phase 2b - Testing all connectivity after new node addition", iteration)
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
// Re-fetch client list to ensure latest state
user1ClientsWithNew, err := scenario.ListTailscaleClients("user1")
assert.NoError(ct, err, "iteration %d: failed to list user1 clients", iteration)
assert.Len(ct, user1ClientsWithNew, 3, "iteration %d: user1 should have 3 nodes", iteration)
// Positive: all user1 nodes can access each other
for _, client := range user1ClientsWithNew {
for _, peer := range user1ClientsWithNew {
if client.ContainerID() == peer.ContainerID() {
continue
}
fqdn, err := peer.FQDN()
if !assert.NoError(ct, err, "iteration %d: failed to get FQDN for peer %s", iteration, peer.Hostname()) {
continue
}
url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
result, err := client.Curl(url)
assert.NoError(ct, err, "iteration %d: user1 node %s should reach user1 node %s", iteration, client.Hostname(), peer.Hostname())
assert.Len(ct, result, 13, "iteration %d: response from %s to %s should be valid", iteration, client.Hostname(), peer.Hostname())
}
}
// Negative: user1 nodes cannot access user2's nodes
for _, client := range user1ClientsWithNew {
for _, peer := range user2Clients {
fqdn, err := peer.FQDN()
if !assert.NoError(ct, err, "iteration %d: failed to get FQDN for user2 peer %s", iteration, peer.Hostname()) {
continue
}
url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
result, err := client.Curl(url)
assert.Error(ct, err, "iteration %d: user1 node %s should NOT reach user2 node %s", iteration, client.Hostname(), peer.Hostname())
assert.Empty(ct, result, "iteration %d: user1->user2 connection from %s to %s should fail", iteration, client.Hostname(), peer.Hostname())
}
}
}, 90*time.Second, 500*time.Millisecond, "iteration %d: Phase 2b - all connectivity tests after new node addition", iteration)
// Delete the newly added node before Phase 3
t.Logf("Iteration %d: Phase 2b - Deleting the newly added node from user1", iteration)
// Get the node list and find the newest node (highest ID)
var nodeList []*v1.Node
var nodeToDeleteID uint64
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
nodeList, err = headscale.ListNodes("user1")
assert.NoError(ct, err)
assert.Len(ct, nodeList, 3, "should have 3 user1 nodes before deletion")
// Find the node with the highest ID (the newest one)
for _, node := range nodeList {
if node.GetId() > nodeToDeleteID {
nodeToDeleteID = node.GetId()
}
}
}, 10*time.Second, 500*time.Millisecond, "iteration %d: Phase 2b - listing nodes before deletion", iteration)
// Delete the node via headscale helper
t.Logf("Iteration %d: Phase 2b - Deleting node ID %d from headscale", iteration, nodeToDeleteID)
err = headscale.DeleteNode(nodeToDeleteID)
require.NoError(t, err, "iteration %d: failed to delete node %d", iteration, nodeToDeleteID)
// Remove the deleted client from the scenario's user.Clients map
// This is necessary for WaitForTailscaleSyncPerUser to calculate correct peer counts
t.Logf("Iteration %d: Phase 2b - Removing deleted client from scenario", iteration)
for clientName, client := range scenario.users["user1"].Clients {
status := client.MustStatus()
nodeID, err := strconv.ParseUint(string(status.Self.ID), 10, 64)
if err != nil {
continue
}
if nodeID == nodeToDeleteID {
delete(scenario.users["user1"].Clients, clientName)
t.Logf("Iteration %d: Phase 2b - Removed client %s (node ID %d) from scenario", iteration, clientName, nodeToDeleteID)
break
}
}
// Verify the node has been deleted
t.Logf("Iteration %d: Phase 2b - Verifying node deletion (expecting 2 user1 nodes)", iteration)
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
nodeListAfter, err := headscale.ListNodes("user1")
assert.NoError(ct, err, "failed to list nodes after deletion")
assert.Len(ct, nodeListAfter, 2, "iteration %d: should have 2 user1 nodes after deletion, got %d", iteration, len(nodeListAfter))
}, 10*time.Second, 500*time.Millisecond, "iteration %d: Phase 2b - node should be deleted", iteration)
// Wait for sync after deletion to ensure peer counts are correct
// Use WaitForTailscaleSyncPerUser because autogroup:self is still active,
// so nodes only see same-user peers, not all nodes
t.Logf("Iteration %d: Phase 2b - Waiting for sync after node deletion (with autogroup:self)", iteration)
err = scenario.WaitForTailscaleSyncPerUser(60*time.Second, 500*time.Millisecond)
require.NoError(t, err, "iteration %d: failed to sync after node deletion", iteration)
// Refresh client lists after deletion to ensure we don't reference the deleted node
user1Clients, err = scenario.ListTailscaleClients("user1")
require.NoError(t, err, "iteration %d: failed to refresh user1 client list after deletion", iteration)
user2Clients, err = scenario.ListTailscaleClients("user2")
require.NoError(t, err, "iteration %d: failed to refresh user2 client list after deletion", iteration)
// Create NEW slice instead of appending to old allClients which still has deleted client
allClients = make([]TailscaleClient, 0, len(user1Clients)+len(user2Clients))
allClients = append(allClients, user1Clients...)
allClients = append(allClients, user2Clients...)
t.Logf("Iteration %d: Phase 2b completed - New node added, validated, and removed successfully", iteration)
// Phase 3: User1 can access user2 but not reverse
t.Logf("Iteration %d: Phase 3 - Setting user1->user2 directional policy", iteration)
err = headscale.SetPolicy(user1ToUser2Policy)
require.NoError(t, err)
// Note: Cannot use WaitForTailscaleSync() here because directional policy means
// user2 nodes don't see user1 nodes in their peer list (asymmetric visibility).
// The EventuallyWithT block below will handle waiting for policy propagation.
// Test ALL connectivity (positive and negative) in one block after policy settles
t.Logf("Iteration %d: Phase 3 - Testing all connectivity with directional policy", iteration)
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
// Positive: user1 can access user2's nodes
for _, client := range user1Clients {
for _, peer := range user2Clients {
fqdn, err := peer.FQDN()
if !assert.NoError(ct, err, "iteration %d: failed to get FQDN for user2 peer %s", iteration, peer.Hostname()) {
continue
}
url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
result, err := client.Curl(url)
assert.NoError(ct, err, "iteration %d: user1 node %s should reach user2 node %s", iteration, client.Hostname(), peer.Hostname())
assert.Len(ct, result, 13, "iteration %d: response from %s to %s should be valid", iteration, client.Hostname(), peer.Hostname())
}
}
// Negative: user2 cannot access user1's nodes
for _, client := range user2Clients {
for _, peer := range user1Clients {
fqdn, err := peer.FQDN()
if !assert.NoError(ct, err, "iteration %d: failed to get FQDN for user1 peer %s", iteration, peer.Hostname()) {
continue
}
url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
result, err := client.Curl(url)
assert.Error(ct, err, "iteration %d: user2 node %s should NOT reach user1 node %s", iteration, client.Hostname(), peer.Hostname())
assert.Empty(ct, result, "iteration %d: user2->user1 from %s to %s should fail", iteration, client.Hostname(), peer.Hostname())
}
}
}, 90*time.Second, 500*time.Millisecond, "iteration %d: Phase 3 - all connectivity tests with directional policy", iteration)
t.Logf("=== Iteration %d/5 completed successfully - All 3 phases passed ===", iteration)
}
t.Log("All 5 iterations completed successfully - ACL propagation is working correctly")
}


@@ -0,0 +1,657 @@
package integration
import (
"crypto/tls"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"testing"
"time"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/integration/hsic"
"github.com/juanfont/headscale/integration/tsic"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/encoding/protojson"
)
// TestAPIAuthenticationBypass tests that the API authentication middleware
// properly blocks unauthorized requests and does not leak sensitive data.
// This test reproduces the security issue described in:
// - https://github.com/juanfont/headscale/issues/2809
// - https://github.com/juanfont/headscale/pull/2810
//
// The bug: When authentication fails, the middleware writes "Unauthorized"
// but doesn't return early, allowing the handler to execute and append
// sensitive data to the response.
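//
// Illustrative sketch of the bug class (not headscale's actual middleware;
// tokenValid is a placeholder): without the early return marked below, the
// wrapped handler still runs and appends its response after "Unauthorized".
//
//	func auth(next http.Handler) http.Handler {
//		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
//			if !tokenValid(r) {
//				http.Error(w, "Unauthorized", http.StatusUnauthorized)
//				return // the fix: stop here instead of falling through
//			}
//			next.ServeHTTP(w, r)
//		})
//	}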
func TestAPIAuthenticationBypass(t *testing.T) {
IntegrationSkip(t)
spec := ScenarioSpec{
Users: []string{"user1", "user2", "user3"},
}
scenario, err := NewScenario(spec)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("apiauthbypass"))
require.NoError(t, err)
headscale, err := scenario.Headscale()
require.NoError(t, err)
// Create an API key using the CLI
var validAPIKey string
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
apiKeyOutput, err := headscale.Execute(
[]string{
"headscale",
"apikeys",
"create",
"--expiration",
"24h",
},
)
assert.NoError(ct, err)
assert.NotEmpty(ct, apiKeyOutput)
validAPIKey = strings.TrimSpace(apiKeyOutput)
}, 20*time.Second, 1*time.Second)
// Get the API endpoint
endpoint := headscale.GetEndpoint()
apiURL := fmt.Sprintf("%s/api/v1/user", endpoint)
// Create HTTP client
client := &http.Client{
Timeout: 10 * time.Second,
Transport: &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, //nolint:gosec
},
}
t.Run("HTTP_NoAuthHeader", func(t *testing.T) {
// Test 1: Request without any Authorization header
// Expected: Should return 401 with ONLY "Unauthorized" text, no user data
req, err := http.NewRequest("GET", apiURL, nil)
require.NoError(t, err)
resp, err := client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
// Should return 401 Unauthorized
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode,
"Expected 401 status code for request without auth header")
bodyStr := string(body)
// Should contain "Unauthorized" message
assert.Contains(t, bodyStr, "Unauthorized",
"Response should contain 'Unauthorized' message")
// Should NOT contain user data after "Unauthorized"
// This is the security bypass: if the users array is present, auth was bypassed
var jsonCheck map[string]interface{}
jsonErr := json.Unmarshal(body, &jsonCheck)
// If we can unmarshal JSON and it contains "users", that's the bypass
if jsonErr == nil {
assert.NotContains(t, jsonCheck, "users",
"SECURITY ISSUE: Response should NOT contain 'users' data when unauthorized")
assert.NotContains(t, jsonCheck, "user",
"SECURITY ISSUE: Response should NOT contain 'user' data when unauthorized")
}
// Additional check: response should not contain "user1", "user2", "user3"
assert.NotContains(t, bodyStr, "user1",
"SECURITY ISSUE: Response should NOT leak user 'user1' data")
assert.NotContains(t, bodyStr, "user2",
"SECURITY ISSUE: Response should NOT leak user 'user2' data")
assert.NotContains(t, bodyStr, "user3",
"SECURITY ISSUE: Response should NOT leak user 'user3' data")
// Response should be minimal, just "Unauthorized"
// Allow some variation in response format but body should be small
assert.Less(t, len(bodyStr), 100,
"SECURITY ISSUE: Unauthorized response body should be minimal, got: %s", bodyStr)
})
t.Run("HTTP_InvalidAuthHeader", func(t *testing.T) {
// Test 2: Request with invalid Authorization header (missing "Bearer " prefix)
// Expected: Should return 401 with ONLY "Unauthorized" text, no user data
req, err := http.NewRequest("GET", apiURL, nil)
require.NoError(t, err)
req.Header.Set("Authorization", "InvalidToken")
resp, err := client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode,
"Expected 401 status code for invalid auth header format")
bodyStr := string(body)
assert.Contains(t, bodyStr, "Unauthorized")
// Should not leak user data
assert.NotContains(t, bodyStr, "user1",
"SECURITY ISSUE: Response should NOT leak user data")
assert.NotContains(t, bodyStr, "user2",
"SECURITY ISSUE: Response should NOT leak user data")
assert.NotContains(t, bodyStr, "user3",
"SECURITY ISSUE: Response should NOT leak user data")
assert.Less(t, len(bodyStr), 100,
"SECURITY ISSUE: Unauthorized response should be minimal")
})
t.Run("HTTP_InvalidBearerToken", func(t *testing.T) {
// Test 3: Request with Bearer prefix but invalid token
// Expected: Should return 401 with ONLY "Unauthorized" text, no user data
// Note: Both malformed and properly formatted invalid tokens should return 401
req, err := http.NewRequest("GET", apiURL, nil)
require.NoError(t, err)
req.Header.Set("Authorization", "Bearer invalid-token-12345")
resp, err := client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode,
"Expected 401 status code for invalid bearer token")
bodyStr := string(body)
assert.Contains(t, bodyStr, "Unauthorized")
// Should not leak user data
assert.NotContains(t, bodyStr, "user1",
"SECURITY ISSUE: Response should NOT leak user data")
assert.NotContains(t, bodyStr, "user2",
"SECURITY ISSUE: Response should NOT leak user data")
assert.NotContains(t, bodyStr, "user3",
"SECURITY ISSUE: Response should NOT leak user data")
assert.Less(t, len(bodyStr), 100,
"SECURITY ISSUE: Unauthorized response should be minimal")
})
t.Run("HTTP_ValidAPIKey", func(t *testing.T) {
// Test 4: Request with valid API key
// Expected: Should return 200 with user data (this is the authorized case)
req, err := http.NewRequest("GET", apiURL, nil)
require.NoError(t, err)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", validAPIKey))
resp, err := client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
// Should succeed with valid auth
assert.Equal(t, http.StatusOK, resp.StatusCode,
"Expected 200 status code with valid API key")
// Should be able to parse as protobuf JSON
var response v1.ListUsersResponse
err = protojson.Unmarshal(body, &response)
assert.NoError(t, err, "Response should be valid protobuf JSON with valid API key")
// Should contain our test users
users := response.GetUsers()
assert.Len(t, users, 3, "Should have 3 users")
userNames := make([]string, len(users))
for i, u := range users {
userNames[i] = u.GetName()
}
assert.Contains(t, userNames, "user1")
assert.Contains(t, userNames, "user2")
assert.Contains(t, userNames, "user3")
})
}
// TestAPIAuthenticationBypassCurl tests the same security issue using curl
// from inside a container, which is closer to how the issue was discovered.
func TestAPIAuthenticationBypassCurl(t *testing.T) {
IntegrationSkip(t)
spec := ScenarioSpec{
Users: []string{"testuser1", "testuser2"},
}
scenario, err := NewScenario(spec)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("apiauthcurl"))
require.NoError(t, err)
headscale, err := scenario.Headscale()
require.NoError(t, err)
// Create a valid API key
apiKeyOutput, err := headscale.Execute(
[]string{
"headscale",
"apikeys",
"create",
"--expiration",
"24h",
},
)
require.NoError(t, err)
validAPIKey := strings.TrimSpace(apiKeyOutput)
endpoint := headscale.GetEndpoint()
apiURL := fmt.Sprintf("%s/api/v1/user", endpoint)
t.Run("Curl_NoAuth", func(t *testing.T) {
// Execute curl from inside the headscale container without auth
curlOutput, err := headscale.Execute(
[]string{
"curl",
"-s",
"-w",
"\nHTTP_CODE:%{http_code}",
apiURL,
},
)
require.NoError(t, err)
// Parse the output
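// Example raw output (illustrative): "Unauthorized\nHTTP_CODE:401".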
lines := strings.Split(curlOutput, "\n")
var httpCode string
var responseBody string
for _, line := range lines {
if strings.HasPrefix(line, "HTTP_CODE:") {
httpCode = strings.TrimPrefix(line, "HTTP_CODE:")
} else {
responseBody += line
}
}
// Should return 401
assert.Equal(t, "401", httpCode,
"Curl without auth should return 401")
// Should contain Unauthorized
assert.Contains(t, responseBody, "Unauthorized",
"Response should contain 'Unauthorized'")
// Should NOT leak user data
assert.NotContains(t, responseBody, "testuser1",
"SECURITY ISSUE: Should not leak user data")
assert.NotContains(t, responseBody, "testuser2",
"SECURITY ISSUE: Should not leak user data")
// Response should be small (just "Unauthorized")
assert.Less(t, len(responseBody), 100,
"SECURITY ISSUE: Unauthorized response should be minimal, got: %s", responseBody)
})
t.Run("Curl_InvalidAuth", func(t *testing.T) {
// Execute curl with invalid auth header
curlOutput, err := headscale.Execute(
[]string{
"curl",
"-s",
"-H",
"Authorization: InvalidToken",
"-w",
"\nHTTP_CODE:%{http_code}",
apiURL,
},
)
require.NoError(t, err)
lines := strings.Split(curlOutput, "\n")
var httpCode string
var responseBody string
for _, line := range lines {
if strings.HasPrefix(line, "HTTP_CODE:") {
httpCode = strings.TrimPrefix(line, "HTTP_CODE:")
} else {
responseBody += line
}
}
assert.Equal(t, "401", httpCode)
assert.Contains(t, responseBody, "Unauthorized")
assert.NotContains(t, responseBody, "testuser1",
"SECURITY ISSUE: Should not leak user data")
assert.NotContains(t, responseBody, "testuser2",
"SECURITY ISSUE: Should not leak user data")
})
t.Run("Curl_ValidAuth", func(t *testing.T) {
// Execute curl with valid API key
curlOutput, err := headscale.Execute(
[]string{
"curl",
"-s",
"-H",
fmt.Sprintf("Authorization: Bearer %s", validAPIKey),
"-w",
"\nHTTP_CODE:%{http_code}",
apiURL,
},
)
require.NoError(t, err)
lines := strings.Split(curlOutput, "\n")
var httpCode string
var responseBody string
for _, line := range lines {
if strings.HasPrefix(line, "HTTP_CODE:") {
httpCode = strings.TrimPrefix(line, "HTTP_CODE:")
} else {
responseBody += line
}
}
// Should succeed
assert.Equal(t, "200", httpCode,
"Curl with valid API key should return 200")
// Should contain user data
var response v1.ListUsersResponse
err = protojson.Unmarshal([]byte(responseBody), &response)
assert.NoError(t, err, "Response should be valid protobuf JSON")
users := response.GetUsers()
assert.Len(t, users, 2, "Should have 2 users")
})
}
// TestGRPCAuthenticationBypass tests that the gRPC authentication interceptor
// properly blocks unauthorized requests.
// This test verifies that the gRPC API does not have the same bypass issue
// as the HTTP API middleware.
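//
// Illustrative sketch of why the interceptor pattern avoids the bypass (not
// headscale's actual code; validToken is a placeholder): a unary interceptor
// that returns an error never invokes the handler, so no response data can be
// written after an auth failure.
//
//	func authInterceptor(ctx context.Context, req any,
//		info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
//		if !validToken(ctx) {
//			return nil, status.Error(codes.Unauthenticated, "invalid token")
//		}
//		return handler(ctx, req)
//	}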
func TestGRPCAuthenticationBypass(t *testing.T) {
IntegrationSkip(t)
spec := ScenarioSpec{
Users: []string{"grpcuser1", "grpcuser2"},
}
scenario, err := NewScenario(spec)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
// We need TLS for remote gRPC connections
err = scenario.CreateHeadscaleEnv(
[]tsic.Option{},
hsic.WithTestName("grpcauthtest"),
hsic.WithTLS(),
hsic.WithConfigEnv(map[string]string{
// Enable gRPC on the standard port
"HEADSCALE_GRPC_LISTEN_ADDR": "0.0.0.0:50443",
}),
)
require.NoError(t, err)
headscale, err := scenario.Headscale()
require.NoError(t, err)
// Create a valid API key
apiKeyOutput, err := headscale.Execute(
[]string{
"headscale",
"apikeys",
"create",
"--expiration",
"24h",
},
)
require.NoError(t, err)
validAPIKey := strings.TrimSpace(apiKeyOutput)
// Get the gRPC endpoint
// For gRPC, we need to use the hostname and port 50443
grpcAddress := fmt.Sprintf("%s:50443", headscale.GetHostname())
t.Run("gRPC_NoAPIKey", func(t *testing.T) {
// Test 1: Try to use CLI without API key (should fail)
// When HEADSCALE_CLI_ADDRESS is set but HEADSCALE_CLI_API_KEY is not set,
// the CLI should fail immediately
_, err := headscale.Execute(
[]string{
"sh", "-c",
fmt.Sprintf("HEADSCALE_CLI_ADDRESS=%s HEADSCALE_CLI_INSECURE=true headscale users list --output json 2>&1", grpcAddress),
},
)
// Should fail - CLI exits when API key is missing
assert.Error(t, err,
"gRPC connection without API key should fail")
})
t.Run("gRPC_InvalidAPIKey", func(t *testing.T) {
// Test 2: Try to use CLI with invalid API key (should fail with auth error)
output, err := headscale.Execute(
[]string{
"sh", "-c",
fmt.Sprintf("HEADSCALE_CLI_ADDRESS=%s HEADSCALE_CLI_API_KEY=invalid-key-12345 HEADSCALE_CLI_INSECURE=true headscale users list --output json 2>&1", grpcAddress),
},
)
// Should fail with authentication error
assert.Error(t, err,
"gRPC connection with invalid API key should fail")
// Should contain authentication error message
outputStr := strings.ToLower(output)
assert.True(t,
strings.Contains(outputStr, "unauthenticated") ||
strings.Contains(outputStr, "invalid token") ||
strings.Contains(outputStr, "failed to validate token") ||
strings.Contains(outputStr, "authentication"),
"Error should indicate authentication failure, got: %s", output)
// Should NOT leak user data
assert.NotContains(t, output, "grpcuser1",
"SECURITY ISSUE: gRPC should not leak user data with invalid auth")
assert.NotContains(t, output, "grpcuser2",
"SECURITY ISSUE: gRPC should not leak user data with invalid auth")
})
t.Run("gRPC_ValidAPIKey", func(t *testing.T) {
// Test 3: Use CLI with valid API key (should succeed)
output, err := headscale.Execute(
[]string{
"sh", "-c",
fmt.Sprintf("HEADSCALE_CLI_ADDRESS=%s HEADSCALE_CLI_API_KEY=%s HEADSCALE_CLI_INSECURE=true headscale users list --output json", grpcAddress, validAPIKey),
},
)
// Should succeed
assert.NoError(t, err,
"gRPC connection with valid API key should succeed, output: %s", output)
// CLI outputs the users array directly, not wrapped in ListUsersResponse
// Parse as JSON array (CLI uses json.Marshal, not protojson)
var users []*v1.User
err = json.Unmarshal([]byte(output), &users)
assert.NoError(t, err, "Response should be valid JSON array")
assert.Len(t, users, 2, "Should have 2 users")
userNames := make([]string, len(users))
for i, u := range users {
userNames[i] = u.GetName()
}
assert.Contains(t, userNames, "grpcuser1")
assert.Contains(t, userNames, "grpcuser2")
})
}
// TestCLIWithConfigAuthenticationBypass tests that the headscale CLI
// with --config flag does not have authentication bypass issues when
// connecting to a remote server.
// Note: When using --config with a local unix socket, no auth is needed.
// This test focuses on remote gRPC connections which require API keys.
func TestCLIWithConfigAuthenticationBypass(t *testing.T) {
IntegrationSkip(t)
spec := ScenarioSpec{
Users: []string{"cliuser1", "cliuser2"},
}
scenario, err := NewScenario(spec)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
err = scenario.CreateHeadscaleEnv(
[]tsic.Option{},
hsic.WithTestName("cliconfigauth"),
hsic.WithTLS(),
hsic.WithConfigEnv(map[string]string{
"HEADSCALE_GRPC_LISTEN_ADDR": "0.0.0.0:50443",
}),
)
require.NoError(t, err)
headscale, err := scenario.Headscale()
require.NoError(t, err)
// Create a valid API key
apiKeyOutput, err := headscale.Execute(
[]string{
"headscale",
"apikeys",
"create",
"--expiration",
"24h",
},
)
require.NoError(t, err)
validAPIKey := strings.TrimSpace(apiKeyOutput)
grpcAddress := fmt.Sprintf("%s:50443", headscale.GetHostname())
// Create a config file for testing
configWithoutKey := fmt.Sprintf(`
cli:
address: %s
timeout: 5s
insecure: true
`, grpcAddress)
configWithInvalidKey := fmt.Sprintf(`
cli:
address: %s
api_key: invalid-key-12345
timeout: 5s
insecure: true
`, grpcAddress)
configWithValidKey := fmt.Sprintf(`
cli:
address: %s
api_key: %s
timeout: 5s
insecure: true
`, grpcAddress, validAPIKey)
t.Run("CLI_Config_NoAPIKey", func(t *testing.T) {
// Create config file without API key
err := headscale.WriteFile("/tmp/config_no_key.yaml", []byte(configWithoutKey))
require.NoError(t, err)
// Try to use CLI with config that has no API key
_, err = headscale.Execute(
[]string{
"headscale",
"--config", "/tmp/config_no_key.yaml",
"users", "list",
"--output", "json",
},
)
// Should fail
assert.Error(t, err,
"CLI with config missing API key should fail")
})
t.Run("CLI_Config_InvalidAPIKey", func(t *testing.T) {
// Create config file with invalid API key
err := headscale.WriteFile("/tmp/config_invalid_key.yaml", []byte(configWithInvalidKey))
require.NoError(t, err)
// Try to use CLI with invalid API key
output, err := headscale.Execute(
[]string{
"sh", "-c",
"headscale --config /tmp/config_invalid_key.yaml users list --output json 2>&1",
},
)
// Should fail
assert.Error(t, err,
"CLI with invalid API key should fail")
// Should indicate authentication failure
outputStr := strings.ToLower(output)
assert.True(t,
strings.Contains(outputStr, "unauthenticated") ||
strings.Contains(outputStr, "invalid token") ||
strings.Contains(outputStr, "failed to validate token") ||
strings.Contains(outputStr, "authentication"),
"Error should indicate authentication failure, got: %s", output)
// Should NOT leak user data
assert.NotContains(t, output, "cliuser1",
"SECURITY ISSUE: CLI should not leak user data with invalid auth")
assert.NotContains(t, output, "cliuser2",
"SECURITY ISSUE: CLI should not leak user data with invalid auth")
})
t.Run("CLI_Config_ValidAPIKey", func(t *testing.T) {
// Create config file with valid API key
err := headscale.WriteFile("/tmp/config_valid_key.yaml", []byte(configWithValidKey))
require.NoError(t, err)
// Use CLI with valid API key
output, err := headscale.Execute(
[]string{
"headscale",
"--config", "/tmp/config_valid_key.yaml",
"users", "list",
"--output", "json",
},
)
// Should succeed
assert.NoError(t, err,
"CLI with valid API key should succeed")
// CLI outputs the users array directly, not wrapped in ListUsersResponse
// Parse as JSON array (CLI uses json.Marshal, not protojson)
var users []*v1.User
err = json.Unmarshal([]byte(output), &users)
assert.NoError(t, err, "Response should be valid JSON array")
assert.Len(t, users, 2, "Should have 2 users")
userNames := make([]string, len(users))
for i, u := range users {
userNames[i] = u.GetName()
}
assert.Contains(t, userNames, "cliuser1")
assert.Contains(t, userNames, "cliuser2")
})
}


@@ -74,14 +74,21 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
clientIPs[client] = ips
}
listNodes, err := headscale.ListNodes()
assert.Len(t, allClients, len(listNodes))
nodeCountBeforeLogout := len(listNodes)
t.Logf("node count before logout: %d", nodeCountBeforeLogout)
var listNodes []*v1.Node
var nodeCountBeforeLogout int
assert.EventuallyWithT(t, func(c *assert.CollectT) {
var err error
listNodes, err = headscale.ListNodes()
assert.NoError(c, err)
assert.Len(c, listNodes, len(allClients))
for _, node := range listNodes {
assertLastSeenSet(t, node)
}
for _, node := range listNodes {
assertLastSeenSetWithCollect(c, node)
}
}, 10*time.Second, 200*time.Millisecond, "Waiting for expected node list before logout")
nodeCountBeforeLogout = len(listNodes)
t.Logf("node count before logout: %d", nodeCountBeforeLogout)
for _, client := range allClients {
err := client.Logout()
@@ -188,11 +195,16 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
}
}
listNodes, err = headscale.ListNodes()
require.Len(t, listNodes, nodeCountBeforeLogout)
for _, node := range listNodes {
assertLastSeenSet(t, node)
}
assert.EventuallyWithT(t, func(c *assert.CollectT) {
var err error
listNodes, err = headscale.ListNodes()
assert.NoError(c, err)
assert.Len(c, listNodes, nodeCountBeforeLogout)
for _, node := range listNodes {
assertLastSeenSetWithCollect(c, node)
}
}, 10*time.Second, 200*time.Millisecond, "Waiting for node list after relogin")
})
}
}
@@ -238,9 +250,16 @@ func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) {
requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected after initial login", 120*time.Second)
requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP after initial login", 3*time.Minute)
listNodes, err := headscale.ListNodes()
assert.Len(t, allClients, len(listNodes))
nodeCountBeforeLogout := len(listNodes)
var listNodes []*v1.Node
var nodeCountBeforeLogout int
assert.EventuallyWithT(t, func(c *assert.CollectT) {
var err error
listNodes, err = headscale.ListNodes()
assert.NoError(c, err)
assert.Len(c, listNodes, len(allClients))
}, 10*time.Second, 200*time.Millisecond, "Waiting for expected node list before logout")
nodeCountBeforeLogout = len(listNodes)
t.Logf("node count before logout: %d", nodeCountBeforeLogout)
for _, client := range allClients {
@@ -371,9 +390,16 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) {
requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected after initial login", 120*time.Second)
requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP after initial login", 3*time.Minute)
listNodes, err := headscale.ListNodes()
assert.Len(t, allClients, len(listNodes))
nodeCountBeforeLogout := len(listNodes)
var listNodes []*v1.Node
var nodeCountBeforeLogout int
assert.EventuallyWithT(t, func(c *assert.CollectT) {
var err error
listNodes, err = headscale.ListNodes()
assert.NoError(c, err)
assert.Len(c, listNodes, len(allClients))
}, 10*time.Second, 200*time.Millisecond, "Waiting for expected node list before logout")
nodeCountBeforeLogout = len(listNodes)
t.Logf("node count before logout: %d", nodeCountBeforeLogout)
for _, client := range allClients {


@@ -901,15 +901,18 @@ func TestOIDCFollowUpUrl(t *testing.T) {
// a little bit more than HEADSCALE_TUNING_REGISTER_CACHE_EXPIRATION
time.Sleep(2 * time.Minute)
st, err := ts.Status()
require.NoError(t, err)
assert.Equal(t, "NeedsLogin", st.BackendState)
var newUrl *url.URL
assert.EventuallyWithT(t, func(c *assert.CollectT) {
st, err := ts.Status()
assert.NoError(c, err)
assert.Equal(c, "NeedsLogin", st.BackendState)
// get new AuthURL from daemon
newUrl, err := url.Parse(st.AuthURL)
require.NoError(t, err)
// get new AuthURL from daemon
newUrl, err = url.Parse(st.AuthURL)
assert.NoError(c, err)
assert.NotEqual(t, u.String(), st.AuthURL, "AuthURL should change")
assert.NotEqual(c, u.String(), st.AuthURL, "AuthURL should change")
}, 10*time.Second, 200*time.Millisecond, "Waiting for registration cache to expire and status to reflect NeedsLogin")
_, err = doLoginURL(ts.Hostname(), newUrl)
require.NoError(t, err)
@@ -943,9 +946,11 @@ func TestOIDCFollowUpUrl(t *testing.T) {
t.Fatalf("unexpected users: %s", diff)
}
listNodes, err := headscale.ListNodes()
require.NoError(t, err)
assert.Len(t, listNodes, 1)
assert.EventuallyWithT(t, func(c *assert.CollectT) {
listNodes, err := headscale.ListNodes()
assert.NoError(c, err)
assert.Len(c, listNodes, 1)
}, 10*time.Second, 200*time.Millisecond, "Waiting for expected node list after OIDC login")
}
// TestOIDCReloginSameNodeSameUser tests the scenario where a single Tailscale client

File diff suppressed because it is too large


@@ -25,6 +25,7 @@ type ControlServer interface {
CreateUser(user string) (*v1.User, error)
CreateAuthKey(user uint64, reusable bool, ephemeral bool) (*v1.PreAuthKey, error)
ListNodes(users ...string) ([]*v1.Node, error)
DeleteNode(nodeID uint64) error
NodesByUser() (map[string][]*v1.Node, error)
NodesByName() (map[string]*v1.Node, error)
ListUsers() ([]*v1.User, error)
@@ -38,4 +39,5 @@ type ControlServer interface {
PrimaryRoutes() (*routes.DebugRoutes, error)
DebugBatcher() (*hscontrol.DebugBatcherInfo, error)
DebugNodeStore() (map[types.NodeID]types.Node, error)
DebugFilter() ([]tailcfg.FilterRule, error)
}


@@ -86,6 +86,108 @@ func TestPingAllByIP(t *testing.T) {
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
}
// TestPingAllByIPRandomClientPort is a variant of TestPingAllByIP that validates
// direct connections between nodes with randomize_client_port enabled. This test
// ensures that nodes can establish direct peer-to-peer connections without relying
// on DERP relay servers, and that the randomize_client_port feature works correctly.
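//
// Background (summary, not verified against the exact implementation): with
// randomize_client_port enabled, headscale instructs clients via the map
// response to pick a random WireGuard source port instead of the default
// 41641, exercising different NAT traversal paths for direct connections.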
func TestPingAllByIPRandomClientPort(t *testing.T) {
IntegrationSkip(t)
spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions),
Users: []string{"user1", "user2"},
MaxWait: dockertestMaxWait(),
}
scenario, err := NewScenario(spec)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
err = scenario.CreateHeadscaleEnv(
[]tsic.Option{},
hsic.WithTestName("pingdirect"),
hsic.WithEmbeddedDERPServerOnly(),
hsic.WithTLS(),
hsic.WithIPAllocationStrategy(types.IPAllocationStrategyRandom),
hsic.WithConfigEnv(map[string]string{
"HEADSCALE_RANDOMIZE_CLIENT_PORT": "true",
}),
)
requireNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients()
requireNoErrListClients(t, err)
allIps, err := scenario.ListTailscaleClientsIPs()
requireNoErrListClientIPs(t, err)
err = scenario.WaitForTailscaleSync()
requireNoErrSync(t, err)
hs, err := scenario.Headscale()
require.NoError(t, err)
// Extract node IDs for validation
expectedNodes := make([]types.NodeID, 0, len(allClients))
for _, client := range allClients {
status := client.MustStatus()
nodeID, err := strconv.ParseUint(string(status.Self.ID), 10, 64)
require.NoError(t, err, "failed to parse node ID")
expectedNodes = append(expectedNodes, types.NodeID(nodeID))
}
requireAllClientsOnline(t, hs, expectedNodes, true, "all clients should be online across all systems", 30*time.Second)
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
return x.String()
})
// Perform pings to establish connections
success := pingAllHelper(t, allClients, allAddrs)
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
// Validate that connections are direct (not relayed through DERP)
// We check that each client has direct connections to its peers
t.Logf("Validating direct connections...")
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
for _, client := range allClients {
status, err := client.Status()
assert.NoError(ct, err, "failed to get status for client %s", client.Hostname())
if err != nil {
continue
}
// Check each peer to see if we have a direct connection
directCount := 0
relayedCount := 0
for _, peerKey := range status.Peers() {
peerStatus := status.Peer[peerKey]
// CurAddr indicates the current address being used to communicate with this peer
// Direct connections have CurAddr set to an actual IP:port
// DERP-relayed connections either have no CurAddr or it contains the DERP magic IP
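// For example (illustrative values): a direct peer might report
// CurAddr "172.18.0.3:41641", while a relayed peer reports an empty
// CurAddr or one containing 127.3.3.40, with Relay naming a DERP region.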
if peerStatus.CurAddr != "" && !strings.Contains(peerStatus.CurAddr, "127.3.3.40") {
// This is a direct connection - CurAddr contains the actual peer IP:port
directCount++
t.Logf("Client %s -> Peer %s: DIRECT connection via %s (relay: %s)",
client.Hostname(), peerStatus.HostName, peerStatus.CurAddr, peerStatus.Relay)
} else {
// This is a relayed connection through DERP
relayedCount++
t.Logf("Client %s -> Peer %s: RELAYED connection (CurAddr: %s, relay: %s)",
client.Hostname(), peerStatus.HostName, peerStatus.CurAddr, peerStatus.Relay)
}
}
// Assert that we have at least some direct connections
// In a local Docker network, we should be able to establish direct connections
assert.Greater(ct, directCount, 0,
"Client %s should have at least one direct connection, got %d direct and %d relayed",
client.Hostname(), directCount, relayedCount)
}
}, 60*time.Second, 2*time.Second, "validating direct connections between peers")
}
func TestPingAllByIPPublicDERP(t *testing.T) {
IntegrationSkip(t)
@@ -514,7 +616,7 @@ func TestUpdateHostnameFromClient(t *testing.T) {
hostnames := map[string]string{
"1": "user1-host",
"2": "User2-Host",
"2": "user2-host",
"3": "user3-host",
}
@@ -541,8 +643,7 @@ func TestUpdateHostnameFromClient(t *testing.T) {
// update hostnames using the up command
for _, client := range allClients {
status, err := client.Status()
require.NoError(t, err)
status := client.MustStatus()
command := []string{
"tailscale",
@@ -577,7 +678,11 @@ func TestUpdateHostnameFromClient(t *testing.T) {
for _, node := range nodes {
hostname := hostnames[strconv.FormatUint(node.GetId(), 10)]
assert.Equal(ct, hostname, node.GetName(), "Node name should match hostname")
assert.Equal(ct, util.ConvertWithFQDNRules(hostname), node.GetGivenName(), "Given name should match FQDN rules")
// GivenName is normalized (lowercase, invalid chars stripped)
normalised, err := util.NormaliseHostname(hostname)
assert.NoError(ct, err)
assert.Equal(ct, normalised, node.GetGivenName(), "Given name should match FQDN rules")
}
}, 20*time.Second, 1*time.Second)
@@ -638,8 +743,7 @@ func TestUpdateHostnameFromClient(t *testing.T) {
}, 60*time.Second, 2*time.Second)
for _, client := range allClients {
status, err := client.Status()
require.NoError(t, err)
status := client.MustStatus()
command := []string{
"tailscale",
@@ -675,12 +779,13 @@ func TestUpdateHostnameFromClient(t *testing.T) {
for _, node := range nodes {
hostname := hostnames[strconv.FormatUint(node.GetId(), 10)]
givenName := fmt.Sprintf("%d-givenname", node.GetId())
if node.GetName() != hostname+"NEW" || node.GetGivenName() != givenName {
// Hostnames are lowercased before being stored, so "NEW" becomes "new"
if node.GetName() != hostname+"new" || node.GetGivenName() != givenName {
return false
}
}
return true
}, time.Second, 50*time.Millisecond, "hostname updates should be reflected in node list with NEW suffix")
}, time.Second, 50*time.Millisecond, "hostname updates should be reflected in node list with new suffix")
}
func TestExpireNode(t *testing.T) {
@@ -768,26 +873,25 @@ func TestExpireNode(t *testing.T) {
// Verify that the expired node has been marked in all peers list.
for _, client := range allClients {
status, err := client.Status()
require.NoError(t, err)
if client.Hostname() == node.GetName() {
continue
}
if client.Hostname() != node.GetName() {
t.Logf("available peers of %s: %v", client.Hostname(), status.Peers())
assert.EventuallyWithT(t, func(c *assert.CollectT) {
status, err := client.Status()
assert.NoError(c, err)
// Ensures that the node is present, and that it is expired.
if peerStatus, ok := status.Peer[expiredNodeKey]; ok {
requireNotNil(t, peerStatus.Expired)
assert.NotNil(t, peerStatus.KeyExpiry)
peerStatus, ok := status.Peer[expiredNodeKey]
assert.True(c, ok, "expired node key should be present in peer list")
if ok {
assert.NotNil(c, peerStatus.Expired)
assert.NotNil(c, peerStatus.KeyExpiry)
t.Logf(
"node %q should have a key expire before %s, was %s",
peerStatus.HostName,
now.String(),
peerStatus.KeyExpiry,
)
if peerStatus.KeyExpiry != nil {
assert.Truef(
t,
c,
peerStatus.KeyExpiry.Before(now),
"node %q should have a key expire before %s, was %s",
peerStatus.HostName,
@@ -797,7 +901,7 @@ func TestExpireNode(t *testing.T) {
}
assert.Truef(
t,
c,
peerStatus.Expired,
"node %q should be expired, expired is %v",
peerStatus.HostName,
@@ -806,24 +910,112 @@ func TestExpireNode(t *testing.T) {
_, stderr, _ := client.Execute([]string{"tailscale", "ping", node.GetName()})
if !strings.Contains(stderr, "node key has expired") {
t.Errorf(
c.Errorf(
"expected to be unable to ping expired host %q from %q",
node.GetName(),
client.Hostname(),
)
}
} else {
t.Errorf("failed to find node %q with nodekey (%s) in mapresponse, should be present even if it is expired", node.GetName(), expiredNodeKey)
}
} else {
if status.Self.KeyExpiry != nil {
assert.Truef(t, status.Self.KeyExpiry.Before(now), "node %q should have a key expire before %s, was %s", status.Self.HostName, now.String(), status.Self.KeyExpiry)
}
}, 10*time.Second, 200*time.Millisecond, "Waiting for expired node status to propagate")
}
}
// NeedsLogin means that the node has understood that it is no longer
// valid.
assert.Equalf(t, "NeedsLogin", status.BackendState, "checking node %q", status.Self.HostName)
// TestSetNodeExpiryInFuture tests setting an arbitrary expiration date.
// The new expiration date should be stored in the db and propagated to all peers.
func TestSetNodeExpiryInFuture(t *testing.T) {
IntegrationSkip(t)
spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions),
Users: []string{"user1"},
}
scenario, err := NewScenario(spec)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("expirenodefuture"))
requireNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients()
requireNoErrListClients(t, err)
err = scenario.WaitForTailscaleSync()
requireNoErrSync(t, err)
headscale, err := scenario.Headscale()
require.NoError(t, err)
targetExpiry := time.Now().Add(2 * time.Hour).Round(time.Second).UTC()
result, err := headscale.Execute(
[]string{
"headscale", "nodes", "expire",
"--identifier", "1",
"--output", "json",
"--expiry", targetExpiry.Format(time.RFC3339),
},
)
require.NoError(t, err)
var node v1.Node
err = json.Unmarshal([]byte(result), &node)
require.NoError(t, err)
require.True(t, node.GetExpiry().AsTime().After(time.Now()))
require.WithinDuration(t, targetExpiry, node.GetExpiry().AsTime(), 2*time.Second)
var nodeKey key.NodePublic
err = nodeKey.UnmarshalText([]byte(node.GetNodeKey()))
require.NoError(t, err)
for _, client := range allClients {
if client.Hostname() == node.GetName() {
continue
}
assert.EventuallyWithT(
t, func(ct *assert.CollectT) {
status, err := client.Status()
assert.NoError(ct, err)
peerStatus, ok := status.Peer[nodeKey]
assert.True(ct, ok, "node key should be present in peer list")
if !ok {
return
}
assert.NotNil(ct, peerStatus.KeyExpiry)
assert.NotNil(ct, peerStatus.Expired)
if peerStatus.KeyExpiry != nil {
assert.WithinDuration(
ct,
targetExpiry,
*peerStatus.KeyExpiry,
5*time.Second,
"node %q should have key expiry near the requested future time",
peerStatus.HostName,
)
assert.Truef(
ct,
peerStatus.KeyExpiry.After(time.Now()),
"node %q should have a key expiry timestamp in the future",
peerStatus.HostName,
)
}
assert.Falsef(
ct,
peerStatus.Expired,
"node %q should not be marked as expired",
peerStatus.HostName,
)
}, 3*time.Minute, 5*time.Second, "Waiting for future expiry to propagate",
)
}
}
@@ -861,11 +1053,13 @@ func TestNodeOnlineStatus(t *testing.T) {
t.Logf("before expire: %d successful pings out of %d", success, len(allClients)*len(allIps))
for _, client := range allClients {
status, err := client.Status()
require.NoError(t, err)
assert.EventuallyWithT(t, func(c *assert.CollectT) {
status, err := client.Status()
assert.NoError(c, err)
// Assert that we have the original count - self
assert.Len(t, status.Peers(), len(MustTestVersions)-1)
// Assert that we have the original count - self
assert.Len(c, status.Peers(), len(MustTestVersions)-1)
}, 10*time.Second, 200*time.Millisecond, "Waiting for expected peer count")
}
headscale, err := scenario.Headscale()


@@ -507,6 +507,11 @@ func assertLastSeenSet(t *testing.T, node *v1.Node) {
assert.NotNil(t, node.GetLastSeen())
}
func assertLastSeenSetWithCollect(c *assert.CollectT, node *v1.Node) {
assert.NotNil(c, node)
assert.NotNil(c, node.GetLastSeen())
}
// assertTailscaleNodesLogout verifies that all provided Tailscale clients
// are in the logged-out state (NeedsLogin).
func assertTailscaleNodesLogout(t assert.TestingT, clients []TailscaleClient) {
@@ -633,50 +638,50 @@ func assertValidNetmap(t *testing.T, client TailscaleClient) {
t.Logf("Checking netmap of %q", client.Hostname())
netmap, err := client.Netmap()
if err != nil {
t.Fatalf("getting netmap for %q: %s", client.Hostname(), err)
}
assert.EventuallyWithT(t, func(c *assert.CollectT) {
netmap, err := client.Netmap()
assert.NoError(c, err, "getting netmap for %q", client.Hostname())
assert.Truef(t, netmap.SelfNode.Hostinfo().Valid(), "%q does not have Hostinfo", client.Hostname())
if hi := netmap.SelfNode.Hostinfo(); hi.Valid() {
assert.LessOrEqual(t, 1, netmap.SelfNode.Hostinfo().Services().Len(), "%q does not have enough services, got: %v", client.Hostname(), netmap.SelfNode.Hostinfo().Services())
}
assert.NotEmptyf(t, netmap.SelfNode.AllowedIPs(), "%q does not have any allowed IPs", client.Hostname())
assert.NotEmptyf(t, netmap.SelfNode.Addresses(), "%q does not have any addresses", client.Hostname())
assert.Truef(t, netmap.SelfNode.Online().Get(), "%q is not online", client.Hostname())
assert.Falsef(t, netmap.SelfNode.Key().IsZero(), "%q does not have a valid NodeKey", client.Hostname())
assert.Falsef(t, netmap.SelfNode.Machine().IsZero(), "%q does not have a valid MachineKey", client.Hostname())
assert.Falsef(t, netmap.SelfNode.DiscoKey().IsZero(), "%q does not have a valid DiscoKey", client.Hostname())
for _, peer := range netmap.Peers {
assert.NotEqualf(t, "127.3.3.40:0", peer.LegacyDERPString(), "peer (%s) has no home DERP in %q's netmap, got: %s", peer.ComputedName(), client.Hostname(), peer.LegacyDERPString())
assert.NotEqualf(t, 0, peer.HomeDERP(), "peer (%s) has no home DERP in %q's netmap, got: %d", peer.ComputedName(), client.Hostname(), peer.HomeDERP())
assert.Truef(t, peer.Hostinfo().Valid(), "peer (%s) of %q does not have Hostinfo", peer.ComputedName(), client.Hostname())
if hi := peer.Hostinfo(); hi.Valid() {
assert.LessOrEqualf(t, 3, peer.Hostinfo().Services().Len(), "peer (%s) of %q does not have enough services, got: %v", peer.ComputedName(), client.Hostname(), peer.Hostinfo().Services())
// Netinfo is not always set
// assert.Truef(t, hi.NetInfo().Valid(), "peer (%s) of %q does not have NetInfo", peer.ComputedName(), client.Hostname())
if ni := hi.NetInfo(); ni.Valid() {
assert.NotEqualf(t, 0, ni.PreferredDERP(), "peer (%s) has no home DERP in %q's netmap, got: %s", peer.ComputedName(), client.Hostname(), peer.Hostinfo().NetInfo().PreferredDERP())
}
assert.Truef(c, netmap.SelfNode.Hostinfo().Valid(), "%q does not have Hostinfo", client.Hostname())
if hi := netmap.SelfNode.Hostinfo(); hi.Valid() {
assert.LessOrEqual(c, 1, netmap.SelfNode.Hostinfo().Services().Len(), "%q does not have enough services, got: %v", client.Hostname(), netmap.SelfNode.Hostinfo().Services())
}
assert.NotEmptyf(t, peer.Endpoints(), "peer (%s) of %q does not have any endpoints", peer.ComputedName(), client.Hostname())
assert.NotEmptyf(t, peer.AllowedIPs(), "peer (%s) of %q does not have any allowed IPs", peer.ComputedName(), client.Hostname())
assert.NotEmptyf(t, peer.Addresses(), "peer (%s) of %q does not have any addresses", peer.ComputedName(), client.Hostname())
assert.NotEmptyf(c, netmap.SelfNode.AllowedIPs(), "%q does not have any allowed IPs", client.Hostname())
assert.NotEmptyf(c, netmap.SelfNode.Addresses(), "%q does not have any addresses", client.Hostname())
assert.Truef(t, peer.Online().Get(), "peer (%s) of %q is not online", peer.ComputedName(), client.Hostname())
assert.Truef(c, netmap.SelfNode.Online().Get(), "%q is not online", client.Hostname())
assert.Falsef(t, peer.Key().IsZero(), "peer (%s) of %q does not have a valid NodeKey", peer.ComputedName(), client.Hostname())
assert.Falsef(t, peer.Machine().IsZero(), "peer (%s) of %q does not have a valid MachineKey", peer.ComputedName(), client.Hostname())
assert.Falsef(t, peer.DiscoKey().IsZero(), "peer (%s) of %q does not have a valid DiscoKey", peer.ComputedName(), client.Hostname())
}
assert.Falsef(c, netmap.SelfNode.Key().IsZero(), "%q does not have a valid NodeKey", client.Hostname())
assert.Falsef(c, netmap.SelfNode.Machine().IsZero(), "%q does not have a valid MachineKey", client.Hostname())
assert.Falsef(c, netmap.SelfNode.DiscoKey().IsZero(), "%q does not have a valid DiscoKey", client.Hostname())
for _, peer := range netmap.Peers {
assert.NotEqualf(c, "127.3.3.40:0", peer.LegacyDERPString(), "peer (%s) has no home DERP in %q's netmap, got: %s", peer.ComputedName(), client.Hostname(), peer.LegacyDERPString())
assert.NotEqualf(c, 0, peer.HomeDERP(), "peer (%s) has no home DERP in %q's netmap, got: %d", peer.ComputedName(), client.Hostname(), peer.HomeDERP())
assert.Truef(c, peer.Hostinfo().Valid(), "peer (%s) of %q does not have Hostinfo", peer.ComputedName(), client.Hostname())
if hi := peer.Hostinfo(); hi.Valid() {
assert.LessOrEqualf(c, 3, peer.Hostinfo().Services().Len(), "peer (%s) of %q does not have enough services, got: %v", peer.ComputedName(), client.Hostname(), peer.Hostinfo().Services())
// Netinfo is not always set
// assert.Truef(c, hi.NetInfo().Valid(), "peer (%s) of %q does not have NetInfo", peer.ComputedName(), client.Hostname())
if ni := hi.NetInfo(); ni.Valid() {
assert.NotEqualf(c, 0, ni.PreferredDERP(), "peer (%s) has no home DERP in %q's netmap, got: %s", peer.ComputedName(), client.Hostname(), peer.Hostinfo().NetInfo().PreferredDERP())
}
}
assert.NotEmptyf(c, peer.Endpoints(), "peer (%s) of %q does not have any endpoints", peer.ComputedName(), client.Hostname())
assert.NotEmptyf(c, peer.AllowedIPs(), "peer (%s) of %q does not have any allowed IPs", peer.ComputedName(), client.Hostname())
assert.NotEmptyf(c, peer.Addresses(), "peer (%s) of %q does not have any addresses", peer.ComputedName(), client.Hostname())
assert.Truef(c, peer.Online().Get(), "peer (%s) of %q is not online", peer.ComputedName(), client.Hostname())
assert.Falsef(c, peer.Key().IsZero(), "peer (%s) of %q does not have a valid NodeKey", peer.ComputedName(), client.Hostname())
assert.Falsef(c, peer.Machine().IsZero(), "peer (%s) of %q does not have a valid MachineKey", peer.ComputedName(), client.Hostname())
assert.Falsef(c, peer.DiscoKey().IsZero(), "peer (%s) of %q does not have a valid DiscoKey", peer.ComputedName(), client.Hostname())
}
}, 10*time.Second, 200*time.Millisecond, "Waiting for valid netmap for %q", client.Hostname())
}
// assertValidStatus validates that a client's status has all required fields for proper operation.
@@ -920,3 +925,125 @@ func oidcMockUser(username string, emailVerified bool) mockoidc.MockUser {
EmailVerified: emailVerified,
}
}
// GetUserByName retrieves a user by name from the headscale server.
// This is a common pattern used when creating preauth keys or managing users.
func GetUserByName(headscale ControlServer, username string) (*v1.User, error) {
users, err := headscale.ListUsers()
if err != nil {
return nil, fmt.Errorf("failed to list users: %w", err)
}
for _, u := range users {
if u.GetName() == username {
return u, nil
}
}
return nil, fmt.Errorf("user %s not found", username)
}
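// A minimal usage sketch (hypothetical test snippet, not part of this change):
// resolve the numeric user ID from a username before creating a preauth key,
// mirroring the pattern the doc comment above describes.
//
//	user, err := GetUserByName(headscale, "user1")
//	if err != nil {
//		t.Fatalf("looking up user: %v", err)
//	}
//	key, err := scenario.CreatePreAuthKey(user.GetId(), true, false)
//	if err != nil {
//		t.Fatalf("creating preauth key: %v", err)
//	}
//	_ = key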
// FindNewClient finds a client that is in the new list but not in the original list.
// This is useful when dynamically adding nodes during tests and needing to identify
// which client was just added.
func FindNewClient(original, updated []TailscaleClient) (TailscaleClient, error) {
for _, client := range updated {
isOriginal := false
for _, origClient := range original {
if client.Hostname() == origClient.Hostname() {
isOriginal = true
break
}
}
if !isOriginal {
return client, nil
}
}
return nil, fmt.Errorf("no new client found")
}
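// Sketch of the intended call pattern (assumed usernames and version, for
// illustration only): snapshot the client list, add a node, snapshot again,
// and diff the two lists to find the addition.
//
//	before, _ := scenario.ListTailscaleClients("user1")
//	_ = scenario.CreateTailscaleNodesInUser("user1", "head", 1)
//	after, _ := scenario.ListTailscaleClients("user1")
//	added, err := FindNewClient(before, after)
//	if err != nil {
//		t.Fatalf("no new client: %v", err)
//	}
//	_ = added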
// AddAndLoginClient adds a new tailscale client to a user and logs it in.
// This combines the common pattern of:
// 1. Creating a new node
// 2. Finding the new node in the client list
// 3. Getting the user to create a preauth key
// 4. Logging in the new node
func (s *Scenario) AddAndLoginClient(
t *testing.T,
username string,
version string,
headscale ControlServer,
tsOpts ...tsic.Option,
) (TailscaleClient, error) {
t.Helper()
// Get the original client list
originalClients, err := s.ListTailscaleClients(username)
if err != nil {
return nil, fmt.Errorf("failed to list original clients: %w", err)
}
// Create the new node
err = s.CreateTailscaleNodesInUser(username, version, 1, tsOpts...)
if err != nil {
return nil, fmt.Errorf("failed to create tailscale node: %w", err)
}
// Wait for the new node to appear in the client list
var newClient TailscaleClient
_, err = backoff.Retry(t.Context(), func() (struct{}, error) {
updatedClients, err := s.ListTailscaleClients(username)
if err != nil {
return struct{}{}, fmt.Errorf("failed to list updated clients: %w", err)
}
if len(updatedClients) != len(originalClients)+1 {
return struct{}{}, fmt.Errorf("expected %d clients, got %d", len(originalClients)+1, len(updatedClients))
}
newClient, err = FindNewClient(originalClients, updatedClients)
if err != nil {
return struct{}{}, fmt.Errorf("failed to find new client: %w", err)
}
return struct{}{}, nil
}, backoff.WithBackOff(backoff.NewConstantBackOff(500*time.Millisecond)), backoff.WithMaxElapsedTime(10*time.Second))
if err != nil {
return nil, fmt.Errorf("timeout waiting for new client: %w", err)
}
// Get the user and create preauth key
user, err := GetUserByName(headscale, username)
if err != nil {
return nil, fmt.Errorf("failed to get user: %w", err)
}
authKey, err := s.CreatePreAuthKey(user.GetId(), true, false)
if err != nil {
return nil, fmt.Errorf("failed to create preauth key: %w", err)
}
// Login the new client
err = newClient.Login(headscale.GetEndpoint(), authKey.GetKey())
if err != nil {
return nil, fmt.Errorf("failed to login new client: %w", err)
}
return newClient, nil
}
// MustAddAndLoginClient is like AddAndLoginClient but fails the test on error.
func (s *Scenario) MustAddAndLoginClient(
t *testing.T,
username string,
version string,
headscale ControlServer,
tsOpts ...tsic.Option,
) TailscaleClient {
t.Helper()
client, err := s.AddAndLoginClient(t, username, version, headscale, tsOpts...)
require.NoError(t, err)
return client
}
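// Hypothetical call site (assumed test name, username, and version string),
// showing how a test grows a user by one logged-in node without repeating the
// create/find/key/login boilerplate:
//
//	func TestAddNodeDuringRun(t *testing.T) {
//		// ... scenario and headscale already set up ...
//		client := scenario.MustAddAndLoginClient(t, "user1", "head", headscale)
//		_ = client.MustIPv4()
//	}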

View File

@@ -1082,6 +1082,30 @@ func (t *HeadscaleInContainer) ListNodes(
return ret, nil
}
func (t *HeadscaleInContainer) DeleteNode(nodeID uint64) error {
command := []string{
"headscale",
"nodes",
"delete",
"--identifier",
fmt.Sprintf("%d", nodeID),
"--output",
"json",
"--force",
}
_, _, err := dockertestutil.ExecuteCommand(
t.container,
command,
[]string{},
)
if err != nil {
return fmt.Errorf("failed to execute delete node command: %w", err)
}
return nil
}
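// A short usage sketch (hypothetical surrounding test, for illustration):
// delete the first listed node by its ID and fail the test on error.
//
//	nodes, _ := headscale.ListNodes()
//	if err := headscale.DeleteNode(nodes[0].GetId()); err != nil {
//		t.Fatalf("deleting node: %v", err)
//	}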
func (t *HeadscaleInContainer) NodesByUser() (map[string][]*v1.Node, error) {
nodes, err := t.ListNodes()
if err != nil {
@@ -1208,26 +1232,26 @@ func (h *HeadscaleInContainer) writePolicy(pol *policyv2.Policy) error {
}
func (h *HeadscaleInContainer) PID() (int, error) {
// Use pidof to find the headscale process, which is more reliable than grep
// as it only looks for the actual binary name, not processes that contain
// "headscale" in their command line (like the dlv debugger).
output, err := h.Execute([]string{"pidof", "headscale"})
if err != nil {
// pidof returns exit code 1 when no process is found
return 0, os.ErrNotExist
}
// pidof returns space-separated PIDs on a single line
pidStrs := strings.Fields(strings.TrimSpace(output))
if len(pidStrs) == 0 {
return 0, os.ErrNotExist
}
pids := make([]int, 0, len(pidStrs))
for _, pidStr := range pidStrs {
pidInt, err := strconv.Atoi(pidStr)
if err != nil {
return 0, fmt.Errorf("parsing PID %q: %w", pidStr, err)
}
}
// We don't care about the root pid for the container
if pidInt == 1 {
@@ -1242,7 +1266,9 @@ func (h *HeadscaleInContainer) PID() (int, error) {
case 1:
return pids[0], nil
default:
return 0, errors.New("multiple headscale processes running")
// If we still have multiple PIDs, return the first one as a fallback
// This can happen in edge cases during startup/shutdown
return pids[0], nil
}
}
@@ -1397,3 +1423,38 @@ func (t *HeadscaleInContainer) DebugNodeStore() (map[types.NodeID]types.Node, er
return nodeStore, nil
}
// DebugFilter fetches the current filter rules from the debug endpoint.
func (t *HeadscaleInContainer) DebugFilter() ([]tailcfg.FilterRule, error) {
// Execute curl inside the container to access the debug endpoint locally
command := []string{
"curl", "-s", "-H", "Accept: application/json", "http://localhost:9090/debug/filter",
}
result, err := t.Execute(command)
if err != nil {
return nil, fmt.Errorf("fetching filter from debug endpoint: %w", err)
}
var filterRules []tailcfg.FilterRule
if err := json.Unmarshal([]byte(result), &filterRules); err != nil {
return nil, fmt.Errorf("decoding filter response: %w", err)
}
return filterRules, nil
}
// DebugPolicy fetches the current policy from the debug endpoint.
func (t *HeadscaleInContainer) DebugPolicy() (string, error) {
// Execute curl inside the container to access the debug endpoint locally
command := []string{
"curl", "-s", "http://localhost:9090/debug/policy",
}
result, err := t.Execute(command)
if err != nil {
return "", fmt.Errorf("fetching policy from debug endpoint: %w", err)
}
return result, nil
}
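// Example check (a sketch, assuming a running scenario): fetch the live
// filter via the debug endpoint and assert it is non-empty once a policy
// has been installed.
//
//	rules, err := headscale.DebugFilter()
//	if err != nil {
//		t.Fatalf("fetching filter: %v", err)
//	}
//	if len(rules) == 0 {
//		t.Fatal("expected at least one filter rule")
//	}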

View File

@@ -24,6 +24,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
xmaps "golang.org/x/exp/maps"
"tailscale.com/envknob"
"tailscale.com/ipn/ipnstate"
"tailscale.com/net/tsaddr"
"tailscale.com/tailcfg"
@@ -1358,16 +1359,8 @@ func TestSubnetRouteACL(t *testing.T) {
// Sort nodes by ID
sort.SliceStable(allClients, func(i, j int) bool {
statusI := allClients[i].MustStatus()
statusJ := allClients[j].MustStatus()
return statusI.Self.ID < statusJ.Self.ID
})
@@ -1475,9 +1468,7 @@ func TestSubnetRouteACL(t *testing.T) {
requirePeerSubnetRoutesWithCollect(c, srs1PeerStatus, []netip.Prefix{netip.MustParsePrefix(expectedRoutes["1"])})
}, 5*time.Second, 200*time.Millisecond, "Verifying client can see subnet routes from router")
// Wait for packet filter updates to propagate to client netmap
wantClientFilter := []filter.Match{
{
IPProto: views.SliceOf([]ipproto.Proto{
@@ -1503,13 +1494,16 @@ func TestSubnetRouteACL(t *testing.T) {
},
}
assert.EventuallyWithT(t, func(c *assert.CollectT) {
clientNm, err := client.Netmap()
assert.NoError(c, err)
if diff := cmpdiff.Diff(wantClientFilter, clientNm.PacketFilter, util.ViewSliceIPProtoComparer, util.PrefixComparer); diff != "" {
assert.Fail(c, fmt.Sprintf("Client (%s) filter, unexpected result (-want +got):\n%s", client.Hostname(), diff))
}
}, 10*time.Second, 200*time.Millisecond, "Waiting for client packet filter to update")
// Wait for packet filter updates to propagate to subnet router netmap
wantSubnetFilter := []filter.Match{
{
IPProto: views.SliceOf([]ipproto.Proto{
@@ -1553,9 +1547,14 @@ func TestSubnetRouteACL(t *testing.T) {
},
}
assert.EventuallyWithT(t, func(c *assert.CollectT) {
subnetNm, err := subRouter1.Netmap()
assert.NoError(c, err)
if diff := cmpdiff.Diff(wantSubnetFilter, subnetNm.PacketFilter, util.ViewSliceIPProtoComparer, util.PrefixComparer); diff != "" {
assert.Fail(c, fmt.Sprintf("Subnet (%s) filter, unexpected result (-want +got):\n%s", subRouter1.Hostname(), diff))
}
}, 10*time.Second, 200*time.Millisecond, "Waiting for subnet router packet filter to update")
}
// TestEnablingExitRoutes tests enabling exit routes for clients.
@@ -1592,12 +1591,16 @@ func TestEnablingExitRoutes(t *testing.T) {
err = scenario.WaitForTailscaleSync()
requireNoErrSync(t, err)
var nodes []*v1.Node
assert.EventuallyWithT(t, func(c *assert.CollectT) {
var err error
nodes, err = headscale.ListNodes()
assert.NoError(c, err)
assert.Len(c, nodes, 2)
requireNodeRouteCountWithCollect(c, nodes[0], 2, 0, 0)
requireNodeRouteCountWithCollect(c, nodes[1], 2, 0, 0)
}, 10*time.Second, 200*time.Millisecond, "Waiting for route advertisements to propagate")
// Verify that no routes have been sent to the client,
// they are not yet enabled.
@@ -2213,11 +2216,31 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
},
}
// Check if we should run the full matrix of tests
// By default, we only run a minimal subset to avoid overwhelming Docker/disk
// Set HEADSCALE_INTEGRATION_FULL_MATRIX=1 to run all combinations
fullMatrix := envknob.Bool("HEADSCALE_INTEGRATION_FULL_MATRIX")
// Minimal test set: 3 tests covering all key dimensions
// - Both auth methods (authkey, webauth)
// - All 3 approver types (tag, user, group)
// - Both policy modes (database, file)
// - Both advertiseDuringUp values (true, false)
minimalTestSet := map[string]bool{
"authkey-tag-advertiseduringup-false-pol-database": true, // authkey + database + tag + false
"webauth-user-advertiseduringup-true-pol-file": true, // webauth + file + user + true
"authkey-group-advertiseduringup-false-pol-file": true, // authkey + file + group + false
}
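// To exercise every combination locally (assumed invocation; adjust the
// package path and -run pattern to your checkout):
//
//	HEADSCALE_INTEGRATION_FULL_MATRIX=1 go test ./integration -run TestAutoApproveMultiNetwork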
for _, tt := range tests {
for _, polMode := range []types.PolicyMode{types.PolicyModeDB, types.PolicyModeFile} {
for _, advertiseDuringUp := range []bool{false, true} {
name := fmt.Sprintf("%s-advertiseduringup-%t-pol-%s", tt.name, advertiseDuringUp, polMode)
t.Run(name, func(t *testing.T) {
// Skip tests not in minimal set unless full matrix is enabled
if !fullMatrix && !minimalTestSet[name] {
t.Skip("Skipping to reduce test matrix size. Set HEADSCALE_INTEGRATION_FULL_MATRIX=1 to run all tests.")
}
scenario, err := NewScenario(tt.spec)
require.NoErrorf(t, err, "failed to create scenario: %s", err)
defer scenario.ShutdownAssertNoPanics(t)

View File

@@ -693,6 +693,35 @@ func (s *Scenario) WaitForTailscaleSync() error {
return err
}
// WaitForTailscaleSyncPerUser blocks execution until each TailscaleClient has the expected
// number of peers for its user. This is useful for policies like autogroup:self where nodes
// only see same-user peers, not all nodes in the network.
func (s *Scenario) WaitForTailscaleSyncPerUser(timeout, retryInterval time.Duration) error {
var allErrors []error
for _, user := range s.users {
// Calculate expected peer count: number of nodes in this user minus 1 (self)
expectedPeers := len(user.Clients) - 1
for _, client := range user.Clients {
c := client
expectedCount := expectedPeers
user.syncWaitGroup.Go(func() error {
return c.WaitForPeers(expectedCount, timeout, retryInterval)
})
}
if err := user.syncWaitGroup.Wait(); err != nil {
allErrors = append(allErrors, err)
}
}
if len(allErrors) > 0 {
return multierr.New(allErrors...)
}
return nil
}
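// Likely call site (a sketch, assuming an autogroup:self policy is active):
// with per-user visibility the global sync helper would wait forever, so a
// test waits for same-user peers only.
//
//	err := scenario.WaitForTailscaleSyncPerUser(30*time.Second, 500*time.Millisecond)
//	if err != nil {
//		t.Fatalf("per-user sync: %v", err)
//	}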
// WaitForTailscaleSyncWithPeerCount blocks execution until all the TailscaleClient reports
// to have all other TailscaleClients present in their netmap.NetworkMap.
func (s *Scenario) WaitForTailscaleSyncWithPeerCount(peerCount int, timeout, retryInterval time.Duration) error {

View File

@@ -14,6 +14,7 @@ import (
"tailscale.com/net/netcheck"
"tailscale.com/types/key"
"tailscale.com/types/netmap"
"tailscale.com/wgengine/filter"
)
// nolint
@@ -36,6 +37,7 @@ type TailscaleClient interface {
MustIPv4() netip.Addr
MustIPv6() netip.Addr
FQDN() (string, error)
MustFQDN() string
Status(...bool) (*ipnstate.Status, error)
MustStatus() *ipnstate.Status
Netmap() (*netmap.NetworkMap, error)
@@ -52,6 +54,7 @@ type TailscaleClient interface {
ContainerID() string
MustID() types.NodeID
ReadFile(path string) ([]byte, error)
PacketFilter() ([]filter.Match, error)
// FailingPeersAsString returns a formatted-ish multi-line-string of peers in the client
// and a bool indicating if the clients online count and peer count is equal.

View File

@@ -18,6 +18,7 @@ import (
"strings"
"time"
"github.com/cenkalti/backoff/v5"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/juanfont/headscale/integration/dockertestutil"
@@ -32,6 +33,7 @@ import (
"tailscale.com/types/key"
"tailscale.com/types/netmap"
"tailscale.com/util/multierr"
"tailscale.com/wgengine/filter"
)
const (
@@ -597,28 +599,39 @@ func (t *TailscaleInContainer) IPs() ([]netip.Addr, error) {
return t.ips, nil
}
// Retry with exponential backoff to handle eventual consistency
ips, err := backoff.Retry(context.Background(), func() ([]netip.Addr, error) {
command := []string{
"tailscale",
"ip",
}
result, _, err := t.Execute(command)
if err != nil {
return nil, fmt.Errorf("%s failed to get IPs: %w", t.hostname, err)
}
ips := make([]netip.Addr, 0)
for address := range strings.SplitSeq(result, "\n") {
address = strings.TrimSuffix(address, "\n")
if len(address) < 1 {
continue
}
ip, err := netip.ParseAddr(address)
if err != nil {
return nil, fmt.Errorf("failed to parse IP %s: %w", address, err)
}
ips = append(ips, ip)
}
if len(ips) == 0 {
return nil, fmt.Errorf("no IPs returned yet for %s", t.hostname)
}
return ips, nil
}, backoff.WithBackOff(backoff.NewExponentialBackOff()), backoff.WithMaxElapsedTime(10*time.Second))
if err != nil {
return nil, fmt.Errorf("failed to get IPs for %s after retries: %w", t.hostname, err)
}
return ips, nil
@@ -629,7 +642,6 @@ func (t *TailscaleInContainer) MustIPs() []netip.Addr {
if err != nil {
panic(err)
}
return ips
}
@@ -646,16 +658,15 @@ func (t *TailscaleInContainer) IPv4() (netip.Addr, error) {
}
}
return netip.Addr{}, fmt.Errorf("no IPv4 address found for %s", t.hostname)
}
func (t *TailscaleInContainer) MustIPv4() netip.Addr {
ip, err := t.IPv4()
if err != nil {
panic(err)
}
return ip
}
func (t *TailscaleInContainer) MustIPv6() netip.Addr {
@@ -900,12 +911,33 @@ func (t *TailscaleInContainer) FQDN() (string, error) {
return t.fqdn, nil
}
// Retry with exponential backoff to handle eventual consistency
fqdn, err := backoff.Retry(context.Background(), func() (string, error) {
status, err := t.Status()
if err != nil {
return "", fmt.Errorf("failed to get status: %w", err)
}
if status.Self.DNSName == "" {
return "", fmt.Errorf("FQDN not yet available")
}
return status.Self.DNSName, nil
}, backoff.WithBackOff(backoff.NewExponentialBackOff()), backoff.WithMaxElapsedTime(10*time.Second))
if err != nil {
return "", fmt.Errorf("failed to get FQDN for %s after retries: %w", t.hostname, err)
}
return fqdn, nil
}
// MustFQDN returns the FQDN as a string of the Tailscale instance, panicking on error.
func (t *TailscaleInContainer) MustFQDN() string {
fqdn, err := t.FQDN()
if err != nil {
panic(err)
}
return fqdn
}
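// The backoff/v5 API used in IPs and FQDN above is generic: backoff.Retry
// returns the operation's value directly instead of requiring a variable
// captured by the closure. A condensed sketch of the pattern (illustrative
// values only):
//
//	n, err := backoff.Retry(context.Background(),
//		func() (int, error) { return 42, nil },
//		backoff.WithBackOff(backoff.NewExponentialBackOff()),
//		backoff.WithMaxElapsedTime(10*time.Second))
//	if err != nil {
//		panic(err)
//	}
//	_ = n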
// FailingPeersAsString returns a formatted-ish multi-line-string of peers in the client
@@ -1353,3 +1385,18 @@ func (t *TailscaleInContainer) GetNodePrivateKey() (*key.NodePrivate, error) {
return &p.Persist.PrivateNodeKey, nil
}
// PacketFilter returns the current packet filter rules from the client's network map.
// This is useful for verifying that policy changes have propagated to the client.
func (t *TailscaleInContainer) PacketFilter() ([]filter.Match, error) {
if !util.TailscaleVersionNewerOrEqual("1.56", t.version) {
return nil, fmt.Errorf("tsic.PacketFilter() requires Tailscale 1.56+, current version: %s", t.version)
}
nm, err := t.Netmap()
if err != nil {
return nil, fmt.Errorf("failed to get netmap: %w", err)
}
return nm.PacketFilter, nil
}
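// Example assertion (a sketch; pairs with the packet filter tests above):
// poll until the client's packet filter is non-empty after a policy change.
//
//	assert.EventuallyWithT(t, func(c *assert.CollectT) {
//		matches, err := client.PacketFilter()
//		assert.NoError(c, err)
//		assert.NotEmpty(c, matches)
//	}, 10*time.Second, 200*time.Millisecond, "waiting for packet filter")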

View File

@@ -82,7 +82,10 @@ message DeleteNodeRequest { uint64 node_id = 1; }
message DeleteNodeResponse {}
message ExpireNodeRequest {
uint64 node_id = 1;
google.protobuf.Timestamp expiry = 2;
}
message ExpireNodeResponse { Node node = 1; }
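// A client-side sketch in Go (assumed generated names following protoc-gen-go
// conventions: v1.ExpireNodeRequest with NodeId and Expiry fields, and an
// ExpireNode RPC on the gRPC client), expiring a node one hour from now via
// the new timestamp field:
//
//	req := &v1.ExpireNodeRequest{
//		NodeId: 42,
//		Expiry: timestamppb.New(time.Now().Add(time.Hour)),
//	}
//	_, err := grpcClient.ExpireNode(ctx, req)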

View File

@@ -136,7 +136,7 @@ func writeCapabilityVersionsToFile(versions map[string]tailcfg.CapabilityVersion
}
// Write to file
err = os.WriteFile(outputFile, formatted, 0o644)
if err != nil {
return fmt.Errorf("error writing file: %w", err)
}