Mirror of https://github.com/juanfont/headscale.git (synced 2026-02-13 19:27:40 +01:00)
Compare commits
39 Commits
| SHA1 |
|---|
| e0d8c3c877 |
| c1b468f9f4 |
| 900f4b7b75 |
| 64f23136a2 |
| 0f6d312ada |
| 20dff82f95 |
| 31c4331a91 |
| ce580f8245 |
| bfb6fd80df |
| 3acce2da87 |
| 4a9a329339 |
| dd16567c52 |
| e0a436cefc |
| 53cdeff129 |
| 7148a690d0 |
| 4e73133b9f |
| 4f8724151e |
| 91730e2a1d |
| b5090a01ec |
| 27f5641341 |
| cf3d30b6f6 |
| 58020696fe |
| e44b402fe4 |
| 835b7eb960 |
| 95b1fd636e |
| 834ac27779 |
| 4a4032a4b0 |
| 29aa08df0e |
| 0b1727c337 |
| 08fe2e4d6c |
| cb29cade46 |
| f27298c759 |
| 8baa14ef4a |
| ebdbe03639 |
| f735502eae |
| 53d17aa321 |
| 14f833bdb9 |
| 9e50071df9 |
| c907b0d323 |
.github/ISSUE_TEMPLATE/bug_report.yaml (vendored, 6 lines changed)

```diff
@@ -6,8 +6,7 @@ body:
   - type: checkboxes
     attributes:
       label: Is this a support request?
-      description:
-        This issue tracker is for bugs and feature requests only. If you need
+      description: This issue tracker is for bugs and feature requests only. If you need
         help, please use ask in our Discord community
       options:
         - label: This is not a support request
@@ -15,8 +14,7 @@ body:
   - type: checkboxes
     attributes:
       label: Is there an existing issue for this?
-      description:
-        Please search to see if an issue already exists for the bug you
+      description: Please search to see if an issue already exists for the bug you
         encountered.
       options:
         - label: I have searched the existing issues
```
.github/ISSUE_TEMPLATE/config.yml (vendored, 8 lines changed)

```diff
@@ -3,9 +3,9 @@ blank_issues_enabled: false

 # Contact links
 contact_links:
-  - name: "headscale usage documentation"
-    url: "https://github.com/juanfont/headscale/blob/main/docs"
-    about: "Find documentation about how to configure and run headscale."
   - name: "headscale Discord community"
-    url: "https://discord.gg/xGj2TuqyxY"
+    url: "https://discord.gg/c84AZQhmpx"
     about: "Please ask and answer questions about usage of headscale here."
+  - name: "headscale usage documentation"
+    url: "https://headscale.net/"
+    about: "Find documentation about how to configure and run headscale."
```
.github/label-response/needs-more-info.md (vendored, new file, 80 lines added)

````markdown
Thank you for taking the time to report this issue.

To help us investigate and resolve this, we need more information. Please provide the following:

> [!TIP]
> Most issues turn out to be configuration errors rather than bugs. We encourage you to discuss your problem in our [Discord community](https://discord.gg/c84AZQhmpx) **before** opening an issue. The community can often help identify misconfigurations quickly, saving everyone time.

## Required Information

### Environment Details

- **Headscale version**: (run `headscale version`)
- **Tailscale client version**: (run `tailscale version`)
- **Operating System**: (e.g., Ubuntu 24.04, macOS 14, Windows 11)
- **Deployment method**: (binary, Docker, Kubernetes, etc.)
- **Reverse proxy**: (if applicable: nginx, Traefik, Caddy, etc. - include configuration)

### Debug Information

Please follow our [Debugging and Troubleshooting Guide](https://headscale.net/stable/ref/debug/) and provide:

1. **Client netmap dump** (from affected Tailscale client):

   ```bash
   tailscale debug netmap > netmap.json
   ```

2. **Client status dump** (from affected Tailscale client):

   ```bash
   tailscale status --json > status.json
   ```

3. **Tailscale client logs** (if experiencing client issues):

   ```bash
   tailscale debug daemon-logs
   ```

   > [!IMPORTANT]
   > We need logs from **multiple nodes** to understand the full picture:
   >
   > - The node(s) initiating connections
   > - The node(s) being connected to
   >
   > Without logs from both sides, we cannot diagnose connectivity issues.

4. **Headscale server logs** with `log.level: trace` enabled

5. **Headscale configuration** (with sensitive values redacted - see rules below)

6. **ACL/Policy configuration** (if using ACLs)

7. **Proxy/Docker configuration** (if applicable - nginx.conf, docker-compose.yml, Traefik config, etc.)

## Formatting Requirements

- **Attach long files** - Do not paste large logs or configurations inline. Use GitHub file attachments or GitHub Gists.
- **Use proper Markdown** - Format code blocks, logs, and configurations with appropriate syntax highlighting.
- **Structure your response** - Use the headings above to organize your information clearly.

## Redaction Rules

> [!CAUTION]
> **Replace, do not remove.** Removing information makes debugging impossible.

When redacting sensitive information:

- ✅ **Replace consistently** - If you change `alice@company.com` to `user1@example.com`, use `user1@example.com` everywhere (logs, config, policy, etc.)
- ✅ **Use meaningful placeholders** - `user1@example.com`, `bob@example.com`, `my-secret-key` are acceptable
- ❌ **Never remove information** - Gaps in data prevent us from correlating events across logs
- ❌ **Never redact IP addresses** - We need the actual IPs to trace network paths and identify issues

**If redaction rules are not followed, we will be unable to debug the issue and will have to close it.**

---

**Note:** This issue will be automatically closed in 3 days if no additional information is provided. Once you reply with the requested information, the `needs-more-info` label will be removed automatically.

If you need help gathering this information, please visit our [Discord community](https://discord.gg/c84AZQhmpx).
````
.github/label-response/support-request.md (vendored, new file, 15 lines added)

```markdown
Thank you for reaching out.

This issue tracker is used for **bug reports and feature requests** only. Your question appears to be a support or configuration question rather than a bug report.

For help with setup, configuration, or general questions, please visit our [Discord community](https://discord.gg/c84AZQhmpx) where the community and maintainers can assist you in real-time.

**Before posting in Discord, please check:**

- [Documentation](https://headscale.net/)
- [FAQ](https://headscale.net/stable/faq/)
- [Debugging and Troubleshooting Guide](https://headscale.net/stable/ref/debug/)

If after troubleshooting you determine this is actually a bug, please open a new issue with the required debug information from the troubleshooting guide.

This issue has been automatically closed.
```
.github/workflows/needs-more-info-comment.yml (vendored, new file, 28 lines added)

```yaml
name: Needs More Info - Post Comment

on:
  issues:
    types: [labeled]

jobs:
  post-comment:
    if: >-
      github.event.label.name == 'needs-more-info' &&
      github.repository == 'juanfont/headscale'
    runs-on: ubuntu-latest
    permissions:
      issues: write
      contents: read
    steps:
      - name: Checkout repository
        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
        with:
          sparse-checkout: .github/label-response/needs-more-info.md
          sparse-checkout-cone-mode: false

      - name: Post instruction comment
        run: gh issue comment "$NUMBER" --body-file .github/label-response/needs-more-info.md
        env:
          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
          GH_REPO: ${{ github.repository }}
          NUMBER: ${{ github.event.issue.number }}
```
.github/workflows/needs-more-info-timer.yml (vendored, new file, 31 lines added)

```yaml
name: Needs More Info - Timer

on:
  schedule:
    - cron: "0 0 * * *" # Daily at midnight UTC
  issue_comment:
    types: [created]
  workflow_dispatch:

jobs:
  manage-needs-more-info:
    if: >-
      github.repository == 'juanfont/headscale' &&
      (github.event_name != 'issue_comment' || github.event.comment.user.type != 'Bot')
    runs-on: ubuntu-latest
    permissions:
      issues: write
    steps:
      - name: Manage needs-more-info issues
        uses: tiangolo/issue-manager@2fb3484ec9279485df8659e8ec73de262431737d # v0.6.0
        with:
          token: ${{ secrets.GITHUB_TOKEN }}
          config: >
            {
              "needs-more-info": {
                "delay": "P3D",
                "message": "This issue has been automatically closed because no additional information was provided within 3 days.\n\nIf you now have the requested information, please feel free to reopen this issue and provide the details. We're happy to help once we have enough context to investigate.\n\nThank you for your understanding.",
                "remove_label_on_comment": true,
                "remove_label_on_close": true
              }
            }
```
.github/workflows/stale.yml (vendored, 2 lines changed)

```diff
@@ -23,5 +23,5 @@ jobs:
             since being marked as stale."
           days-before-pr-stale: -1
           days-before-pr-close: -1
-          exempt-issue-labels: "no-stale-bot"
+          exempt-issue-labels: "no-stale-bot,needs-more-info"
           repo-token: ${{ secrets.GITHUB_TOKEN }}
```
.github/workflows/support-request.yml (vendored, new file, 30 lines added)

```yaml
name: Support Request - Close Issue

on:
  issues:
    types: [labeled]

jobs:
  close-support-request:
    if: >-
      github.event.label.name == 'support-request' &&
      github.repository == 'juanfont/headscale'
    runs-on: ubuntu-latest
    permissions:
      issues: write
      contents: read
    steps:
      - name: Checkout repository
        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
        with:
          sparse-checkout: .github/label-response/support-request.md
          sparse-checkout-cone-mode: false

      - name: Post comment and close issue
        run: |
          gh issue comment "$NUMBER" --body-file .github/label-response/support-request.md
          gh issue close "$NUMBER" --reason "not planned"
        env:
          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
          GH_REPO: ${{ github.repository }}
          NUMBER: ${{ github.event.issue.number }}
```
```diff
@@ -18,6 +18,7 @@ linters:
     - lll
     - maintidx
     - makezero
     - mnd
     - musttag
     - nestif
     - nolintlint
@@ -37,6 +38,23 @@ linters:
         time.Sleep is forbidden.
         In tests: use assert.EventuallyWithT for polling/waiting patterns.
         In production code: use a backoff strategy (e.g., cenkalti/backoff) or proper synchronization primitives.
+      # Forbid inline string literals in zerolog field methods - use zf.* constants
+      - pattern: '\.(Str|Int|Int8|Int16|Int32|Int64|Uint|Uint8|Uint16|Uint32|Uint64|Float32|Float64|Bool|Dur|Time|TimeDiff|Strs|Ints|Uints|Floats|Bools|Any|Interface)\("[^"]+"'
+        msg: >-
+          Use zf.* constants for zerolog field names instead of string literals.
+          Import "github.com/juanfont/headscale/hscontrol/util/zlog/zf" and use
+          constants like zf.NodeID, zf.UserName, etc. Add new constants to
+          hscontrol/util/zlog/zf/fields.go if needed.
+      # Forbid ptr.To - use Go 1.26 new(expr) instead
+      - pattern: 'ptr\.To\('
+        msg: >-
+          ptr.To is forbidden. Use Go 1.26's new(expr) syntax instead.
+          Example: ptr.To(value) → new(value)
+      # Forbid tsaddr.SortPrefixes - use slices.SortFunc with netip.Prefix.Compare
+      - pattern: 'tsaddr\.SortPrefixes'
+        msg: >-
+          tsaddr.SortPrefixes is forbidden. Use Go 1.26's netip.Prefix.Compare instead.
+          Example: slices.SortFunc(prefixes, netip.Prefix.Compare)
     analyze-types: true
   gocritic:
     disabled-checks:
```
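A minimal sketch (not taken from the diff) of the style these new forbidigo rules push code toward: zerolog field names come from the `zf` constants package rather than inline string literals, and prefix sorting uses the standard library instead of `tsaddr.SortPrefixes`. `zf.Address` is one of the constants this changeset itself uses; the `ptr.To(v)` to `new(v)` replacement mentioned in the rule requires the Go 1.26 syntax and is only noted in a comment here.

```go
package main

import (
	"net/netip"
	"slices"

	"github.com/juanfont/headscale/hscontrol/util/zlog/zf"
	"github.com/rs/zerolog/log"
)

func example(prefixes []netip.Prefix, addr string) {
	// Field name from a zf constant instead of an inline "address" literal.
	log.Info().Str(zf.Address, addr).Msg("connecting")

	// Standard-library sort instead of tsaddr.SortPrefixes.
	slices.SortFunc(prefixes, netip.Prefix.Compare)

	// Per the lint message, ptr.To(value) becomes new(value) once Go 1.26 is in use.
}
```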
```diff
@@ -2,7 +2,7 @@
 version: 2
 before:
   hooks:
-    - go mod tidy -compat=1.25
+    - go mod tidy -compat=1.26
    - go mod vendor

 release:
```
```diff
@@ -43,26 +43,12 @@ repos:
       entry: prettier --write --list-different
       language: system
       exclude: ^docs/
-      types_or:
-        [
-          javascript,
-          jsx,
-          ts,
-          tsx,
-          yaml,
-          json,
-          toml,
-          html,
-          css,
-          scss,
-          sass,
-          markdown,
-        ]
+      types_or: [javascript, jsx, ts, tsx, yaml, json, toml, html, css, scss, sass, markdown]

     # golangci-lint for Go code quality
     - id: golangci-lint
       name: golangci-lint
-      entry: nix develop --command golangci-lint run --new-from-rev=HEAD~1 --timeout=5m --fix
+      entry: golangci-lint run --new-from-rev=HEAD~1 --timeout=5m --fix
       language: system
       types: [go]
       pass_filenames: false
```
CHANGELOG.md (28 lines changed)

```diff
@@ -2,6 +2,34 @@

+## 0.29.0 (202x-xx-xx)
+
+**Minimum supported Tailscale client version: v1.76.0**
+
+### Tailscale ACL compatibility improvements
+
+Extensive test cases were systematically generated using Tailscale clients and the official SaaS
+to understand how the packet filter should be generated. We discovered a few differences, but
+overall our implementation was very close.
+[#3036](https://github.com/juanfont/headscale/pull/3036)
+
+### BREAKING
+
+- **ACL Policy**: Wildcard (`*`) in ACL sources and destinations now resolves to Tailscale's CGNAT range (`100.64.0.0/10`) and ULA range (`fd7a:115c:a1e0::/48`) instead of all IPs (`0.0.0.0/0` and `::/0`) [#3036](https://github.com/juanfont/headscale/pull/3036)
+  - This better matches Tailscale's security model where `*` means "any node in the tailnet" rather than "any IP address"
+  - Policies relying on wildcard to match non-Tailscale IPs will need to use explicit CIDR ranges instead
+  - **Note**: Users with non-standard IP ranges configured in `prefixes.ipv4` or `prefixes.ipv6` (which is unsupported and produces a warning) will need to explicitly specify their CIDR ranges in ACL rules instead of using `*`
+- **ACL Policy**: Validate autogroup:self source restrictions matching Tailscale behavior - tags, hosts, and IPs are rejected as sources for autogroup:self destinations [#3036](https://github.com/juanfont/headscale/pull/3036)
+  - Policies using tags, hosts, or IP addresses as sources for autogroup:self destinations will now fail validation
+- **ACL Policy**: The `proto:icmp` protocol name now only includes ICMPv4 (protocol 1), matching Tailscale behavior [#3036](https://github.com/juanfont/headscale/pull/3036)
+  - Previously, `proto:icmp` included both ICMPv4 and ICMPv6
+  - Use `proto:ipv6-icmp` or protocol number `58` explicitly for ICMPv6
+
+### Changes
+
+- **ACL Policy**: Add ICMP and IPv6-ICMP protocols to default filter rules when no protocol is specified [#3036](https://github.com/juanfont/headscale/pull/3036)
+- **ACL Policy**: Fix autogroup:self handling for tagged nodes - tagged nodes no longer incorrectly receive autogroup:self filter rules [#3036](https://github.com/juanfont/headscale/pull/3036)
+- **ACL Policy**: Use CIDR format for autogroup:self destination IPs matching Tailscale behavior [#3036](https://github.com/juanfont/headscale/pull/3036)
+- **ACL Policy**: Merge filter rules with identical SrcIPs and IPProto matching Tailscale behavior - multiple ACL rules with the same source now produce a single FilterRule with combined DstPorts [#3036](https://github.com/juanfont/headscale/pull/3036)
+
 ## 0.28.0 (2026-02-04)

 **Minimum supported Tailscale client version: v1.74.0**
```
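A small illustration (not headscale code) of the wildcard change described in the 0.29.0 notes above: `*` previously expanded to all of IPv4/IPv6 and now expands to the Tailscale CGNAT and ULA ranges, so non-tailnet addresses need explicit CIDRs in policies.

```go
package main

import (
	"fmt"
	"net/netip"
)

func main() {
	// What '*' used to mean: every IP address.
	oldWildcard := []netip.Prefix{
		netip.MustParsePrefix("0.0.0.0/0"),
		netip.MustParsePrefix("::/0"),
	}

	// What '*' means from 0.29.0: any node in the tailnet.
	newWildcard := []netip.Prefix{
		netip.MustParsePrefix("100.64.0.0/10"),       // Tailscale CGNAT range
		netip.MustParsePrefix("fd7a:115c:a1e0::/48"), // Tailscale ULA range
	}

	fmt.Println(oldWildcard, newWildcard)
}
```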
```diff
@@ -1,6 +1,6 @@
 # For testing purposes only

-FROM golang:alpine AS build-env
+FROM golang:1.26rc2-alpine AS build-env

 WORKDIR /go/src
```
```diff
@@ -2,7 +2,7 @@
 # and are in no way endorsed by Headscale's maintainers as an
 # official nor supported release or distribution.

-FROM docker.io/golang:1.25-trixie AS builder
+FROM docker.io/golang:1.26rc2-trixie AS builder
 ARG VERSION=dev
 ENV GOPATH /go
 WORKDIR /go/src/headscale
```
```diff
@@ -4,7 +4,7 @@
 # This Dockerfile is more or less lifted from tailscale/tailscale
 # to ensure a similar build process when testing the HEAD of tailscale.

-FROM golang:1.25-alpine AS build-env
+FROM golang:1.26rc2-alpine AS build-env

 WORKDIR /go/src
```
```diff
@@ -67,6 +67,8 @@ For NixOS users, a module is available in [`nix/`](./nix/).

 ## Talks

+- Fosdem 2026 (video): [Headscale & Tailscale: The complementary open source clone](https://fosdem.org/2026/schedule/event/KYQ3LL-headscale-the-complementary-open-source-clone/)
+  - presented by Kristoffer Dalby
 - Fosdem 2023 (video): [Headscale: How we are using integration testing to reimplement Tailscale](https://fosdem.org/2023/schedule/event/goheadscale/)
   - presented by Juan Font Alonso and Kristoffer Dalby
```
```diff
@@ -14,7 +14,7 @@ import (
 )

 const (
-	// 90 days.
+	// DefaultAPIKeyExpiry is 90 days.
 	DefaultAPIKeyExpiry = "90d"
 )
```
```diff
@@ -16,7 +16,7 @@ var configTestCmd = &cobra.Command{
 Run: func(cmd *cobra.Command, args []string) {
 _, err := newHeadscaleServerWithConfig()
 if err != nil {
-log.Fatal().Caller().Err(err).Msg("Error initializing")
+log.Fatal().Caller().Err(err).Msg("error initializing")
 }
 },
 }
```
```diff
@@ -19,10 +19,12 @@ func init() {
 rootCmd.AddCommand(debugCmd)

 createNodeCmd.Flags().StringP("name", "", "", "Name")

 err := createNodeCmd.MarkFlagRequired("name")
 if err != nil {
 log.Fatal().Err(err).Msg("")
 }

 createNodeCmd.Flags().StringP("user", "u", "", "User")

 createNodeCmd.Flags().StringP("namespace", "n", "", "User")
@@ -34,11 +36,14 @@ func init() {
 if err != nil {
 log.Fatal().Err(err).Msg("")
 }

 createNodeCmd.Flags().StringP("key", "k", "", "Key")

 err = createNodeCmd.MarkFlagRequired("key")
 if err != nil {
 log.Fatal().Err(err).Msg("")
 }

 createNodeCmd.Flags().
 StringSliceP("route", "r", []string{}, "List (or repeated flags) of routes to advertise")
```
```diff
@@ -1,8 +1,8 @@
 package cli

 import (
+"context"
 "encoding/json"
-"errors"
 "fmt"
 "net"
 "net/http"
@@ -10,6 +10,7 @@ import (
 "strconv"
 "time"

+"github.com/juanfont/headscale/hscontrol/util/zlog/zf"
 "github.com/oauth2-proxy/mockoidc"
 "github.com/rs/zerolog/log"
 "github.com/spf13/cobra"
@@ -19,6 +20,7 @@ const (
 errMockOidcClientIDNotDefined = Error("MOCKOIDC_CLIENT_ID not defined")
 errMockOidcClientSecretNotDefined = Error("MOCKOIDC_CLIENT_SECRET not defined")
 errMockOidcPortNotDefined = Error("MOCKOIDC_PORT not defined")
+errMockOidcUsersNotDefined = Error("MOCKOIDC_USERS not defined")
 refreshTTL = 60 * time.Minute
 )
@@ -35,7 +37,7 @@ var mockOidcCmd = &cobra.Command{
 Run: func(cmd *cobra.Command, args []string) {
 err := mockOIDC()
 if err != nil {
-log.Error().Err(err).Msgf("Error running mock OIDC server")
+log.Error().Err(err).Msgf("error running mock OIDC server")
 os.Exit(1)
 }
 },
@@ -46,41 +48,47 @@ func mockOIDC() error {
 if clientID == "" {
 return errMockOidcClientIDNotDefined
 }

 clientSecret := os.Getenv("MOCKOIDC_CLIENT_SECRET")
 if clientSecret == "" {
 return errMockOidcClientSecretNotDefined
 }

 addrStr := os.Getenv("MOCKOIDC_ADDR")
 if addrStr == "" {
 return errMockOidcPortNotDefined
 }

 portStr := os.Getenv("MOCKOIDC_PORT")
 if portStr == "" {
 return errMockOidcPortNotDefined
 }

 accessTTLOverride := os.Getenv("MOCKOIDC_ACCESS_TTL")
 if accessTTLOverride != "" {
 newTTL, err := time.ParseDuration(accessTTLOverride)
 if err != nil {
 return err
 }

 accessTTL = newTTL
 }

 userStr := os.Getenv("MOCKOIDC_USERS")
 if userStr == "" {
-return errors.New("MOCKOIDC_USERS not defined")
+return errMockOidcUsersNotDefined
 }

 var users []mockoidc.MockUser

 err := json.Unmarshal([]byte(userStr), &users)
 if err != nil {
 return fmt.Errorf("unmarshalling users: %w", err)
 }

-log.Info().Interface("users", users).Msg("loading users from JSON")
+log.Info().Interface(zf.Users, users).Msg("loading users from JSON")

-log.Info().Msgf("Access token TTL: %s", accessTTL)
+log.Info().Msgf("access token TTL: %s", accessTTL)

 port, err := strconv.Atoi(portStr)
 if err != nil {
@@ -92,7 +100,7 @@ func mockOIDC() error {
 return err
 }

-listener, err := net.Listen("tcp", fmt.Sprintf("%s:%d", addrStr, port))
+listener, err := new(net.ListenConfig).Listen(context.Background(), "tcp", fmt.Sprintf("%s:%d", addrStr, port))
 if err != nil {
 return err
 }
@@ -101,8 +109,10 @@ func mockOIDC() error {
 if err != nil {
 return err
 }
-log.Info().Msgf("Mock OIDC server listening on %s", listener.Addr().String())
-log.Info().Msgf("Issuer: %s", mock.Issuer())
+
+log.Info().Msgf("mock OIDC server listening on %s", listener.Addr().String())
+log.Info().Msgf("issuer: %s", mock.Issuer())

 c := make(chan struct{})
 <-c
@@ -133,12 +143,13 @@ func getMockOIDC(clientID string, clientSecret string, users []mockoidc.MockUser
 ErrorQueue: &mockoidc.ErrorQueue{},
 }

-mock.AddMiddleware(func(h http.Handler) http.Handler {
+_ = mock.AddMiddleware(func(h http.Handler) http.Handler {
 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-log.Info().Msgf("Request: %+v", r)
+log.Info().Msgf("request: %+v", r)
 h.ServeHTTP(w, r)

 if r.Response != nil {
-log.Info().Msgf("Response: %+v", r.Response)
+log.Info().Msgf("response: %+v", r.Response)
 }
 })
 })
```
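A brief sketch of why the listener line above changes: `net.ListenConfig.Listen` takes a context, which the plain `net.Listen` call cannot, so the diff constructs a zero-value `ListenConfig` and listens through it.

```go
package main

import (
	"context"
	"fmt"
	"net"
)

// listen mirrors the pattern used in the mockoidc hunk: a context-aware
// replacement for net.Listen("tcp", addr:port).
func listen(ctx context.Context, addr string, port int) (net.Listener, error) {
	return new(net.ListenConfig).Listen(ctx, "tcp", fmt.Sprintf("%s:%d", addr, port))
}
```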
```diff
@@ -26,6 +26,7 @@ func init() {
 listNodesNamespaceFlag := listNodesCmd.Flags().Lookup("namespace")
 listNodesNamespaceFlag.Deprecated = deprecateNamespaceMessage
 listNodesNamespaceFlag.Hidden = true

 nodeCmd.AddCommand(listNodesCmd)

 listNodeRoutesCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)")
@@ -42,42 +43,51 @@ func init() {
 if err != nil {
 log.Fatal(err.Error())
 }

 registerNodeCmd.Flags().StringP("key", "k", "", "Key")

 err = registerNodeCmd.MarkFlagRequired("key")
 if err != nil {
 log.Fatal(err.Error())
 }

 nodeCmd.AddCommand(registerNodeCmd)

 expireNodeCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)")
 expireNodeCmd.Flags().StringP("expiry", "e", "", "Set expire to (RFC3339 format, e.g. 2025-08-27T10:00:00Z), or leave empty to expire immediately.")

 err = expireNodeCmd.MarkFlagRequired("identifier")
 if err != nil {
 log.Fatal(err.Error())
 }

 nodeCmd.AddCommand(expireNodeCmd)

 renameNodeCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)")

 err = renameNodeCmd.MarkFlagRequired("identifier")
 if err != nil {
 log.Fatal(err.Error())
 }

 nodeCmd.AddCommand(renameNodeCmd)

 deleteNodeCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)")

 err = deleteNodeCmd.MarkFlagRequired("identifier")
 if err != nil {
 log.Fatal(err.Error())
 }

 nodeCmd.AddCommand(deleteNodeCmd)

 tagCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)")
-tagCmd.MarkFlagRequired("identifier")
+_ = tagCmd.MarkFlagRequired("identifier")
 tagCmd.Flags().StringSliceP("tags", "t", []string{}, "List of tags to add to the node")
 nodeCmd.AddCommand(tagCmd)

 approveRoutesCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)")
-approveRoutesCmd.MarkFlagRequired("identifier")
+_ = approveRoutesCmd.MarkFlagRequired("identifier")
 approveRoutesCmd.Flags().StringSliceP("routes", "r", []string{}, `List of routes that will be approved (comma-separated, e.g. "10.0.0.0/8,192.168.0.0/24" or empty string to remove all approved routes)`)
 nodeCmd.AddCommand(approveRoutesCmd)
@@ -233,10 +243,7 @@ var listNodeRoutesCmd = &cobra.Command{
 return
 }

-tableData, err := nodeRoutesToPtables(nodes)
-if err != nil {
-ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output)
-}
+tableData := nodeRoutesToPtables(nodes)

 err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
 if err != nil {
@@ -506,15 +513,21 @@ func nodesToPtables(
 ephemeral = true
 }

-var lastSeen time.Time
-var lastSeenTime string
+var (
+lastSeen time.Time
+lastSeenTime string
+)

 if node.GetLastSeen() != nil {
 lastSeen = node.GetLastSeen().AsTime()
 lastSeenTime = lastSeen.Format("2006-01-02 15:04:05")
 }

-var expiry time.Time
-var expiryTime string
+var (
+expiry time.Time
+expiryTime string
+)

 if node.GetExpiry() != nil {
 expiry = node.GetExpiry().AsTime()
 expiryTime = expiry.Format("2006-01-02 15:04:05")
@@ -523,6 +536,7 @@ func nodesToPtables(
 }

 var machineKey key.MachinePublic

 err := machineKey.UnmarshalText(
 []byte(node.GetMachineKey()),
 )
@@ -531,6 +545,7 @@ func nodesToPtables(
 }

 var nodeKey key.NodePublic

 err = nodeKey.UnmarshalText(
 []byte(node.GetNodeKey()),
 )
@@ -572,8 +587,11 @@ func nodesToPtables(
 user = pterm.LightYellow(node.GetUser().GetName())
 }

-var IPV4Address string
-var IPV6Address string
+var (
+IPV4Address string
+IPV6Address string
+)

 for _, addr := range node.GetIpAddresses() {
 if netip.MustParseAddr(addr).Is4() {
 IPV4Address = addr
@@ -608,7 +626,7 @@ func nodesToPtables(

 func nodeRoutesToPtables(
 nodes []*v1.Node,
-) (pterm.TableData, error) {
+) pterm.TableData {
 tableHeader := []string{
 "ID",
 "Hostname",
@@ -632,7 +650,7 @@ func nodeRoutesToPtables(
 )
 }

-return tableData, nil
+return tableData
 }

 var tagCmd = &cobra.Command{
```
```diff
@@ -16,7 +16,7 @@ import (
 )

 const (
-bypassFlag = "bypass-grpc-and-access-database-directly"
+bypassFlag = "bypass-grpc-and-access-database-directly" //nolint:gosec // not a credential
 )

 func init() {
@@ -26,16 +26,22 @@ func init() {
 policyCmd.AddCommand(getPolicy)

 setPolicy.Flags().StringP("file", "f", "", "Path to a policy file in HuJSON format")
-if err := setPolicy.MarkFlagRequired("file"); err != nil {
+
+err := setPolicy.MarkFlagRequired("file")
+if err != nil {
 log.Fatal().Err(err).Msg("")
 }

 setPolicy.Flags().BoolP(bypassFlag, "", false, "Uses the headscale config to directly access the database, bypassing gRPC and does not require the server to be running")
 policyCmd.AddCommand(setPolicy)

 checkPolicy.Flags().StringP("file", "f", "", "Path to a policy file in HuJSON format")
-if err := checkPolicy.MarkFlagRequired("file"); err != nil {
+
+err = checkPolicy.MarkFlagRequired("file")
+if err != nil {
 log.Fatal().Err(err).Msg("")
 }

 policyCmd.AddCommand(checkPolicy)
 }
@@ -173,7 +179,7 @@ var setPolicy = &cobra.Command{
 defer cancel()
 defer conn.Close()

-if _, err := client.SetPolicy(ctx, request); err != nil {
+if _, err := client.SetPolicy(ctx, request); err != nil { //nolint:noinlineerr
 ErrorOutput(err, fmt.Sprintf("Failed to set ACL Policy: %s", err), output)
 }
 }
```
```diff
@@ -45,15 +45,16 @@ func initConfig() {
 if cfgFile == "" {
 cfgFile = os.Getenv("HEADSCALE_CONFIG")
 }

 if cfgFile != "" {
 err := types.LoadConfig(cfgFile, true)
 if err != nil {
-log.Fatal().Caller().Err(err).Msgf("Error loading config file %s", cfgFile)
+log.Fatal().Caller().Err(err).Msgf("error loading config file %s", cfgFile)
 }
 } else {
 err := types.LoadConfig("", false)
 if err != nil {
-log.Fatal().Caller().Err(err).Msgf("Error loading config")
+log.Fatal().Caller().Err(err).Msgf("error loading config")
 }
 }
@@ -80,6 +81,7 @@ func initConfig() {
 Repository: "headscale",
 TagFilterFunc: filterPreReleasesIfStable(func() string { return versionInfo.Version }),
 }

 res, err := latest.Check(githubTag, versionInfo.Version)
 if err == nil && res.Outdated {
 //nolint
@@ -101,6 +103,7 @@ func isPreReleaseVersion(version string) bool {
 return true
 }
 }

 return false
 }
@@ -140,7 +143,8 @@ https://github.com/juanfont/headscale`,
 }

 func Execute() {
-if err := rootCmd.Execute(); err != nil {
+err := rootCmd.Execute()
+if err != nil {
 fmt.Fprintln(os.Stderr, err)
 os.Exit(1)
 }
```
```diff
@@ -23,18 +23,17 @@ var serveCmd = &cobra.Command{
 Run: func(cmd *cobra.Command, args []string) {
 app, err := newHeadscaleServerWithConfig()
 if err != nil {
-var squibbleErr squibble.ValidationError
-if errors.As(err, &squibbleErr) {
+if squibbleErr, ok := errors.AsType[squibble.ValidationError](err); ok {
 fmt.Printf("SQLite schema failed to validate:\n")
 fmt.Println(squibbleErr.Diff)
 }

-log.Fatal().Caller().Err(err).Msg("Error initializing")
+log.Fatal().Caller().Err(err).Msg("error initializing")
 }

 err = app.Serve()
 if err != nil && !errors.Is(err, http.ErrServerClosed) {
-log.Fatal().Caller().Err(err).Msg("Headscale ran into an error and had to shut down.")
+log.Fatal().Caller().Err(err).Msg("headscale ran into an error and had to shut down")
 }
 },
 }
```
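A hedged sketch of the two error-matching shapes in the serve.go hunk above. `errors.AsType` is the generic helper the diff switches to; its `errors.AsType[T](err) (T, bool)` signature is assumed from the diff itself and is only available on newer Go, so it is shown commented out here while the classic `errors.As` form compiles everywhere.

```go
package main

import (
	"errors"
	"fmt"
)

// ValidationError stands in for squibble.ValidationError (illustrative only).
type ValidationError struct{ Diff string }

func (e ValidationError) Error() string { return "schema validation failed" }

func report(err error) {
	// Old shape: declare the target, pass a pointer to errors.As.
	var vErr ValidationError
	if errors.As(err, &vErr) {
		fmt.Println(vErr.Diff)
	}

	// New shape used in the diff (newer Go): value and ok returned directly.
	// if vErr, ok := errors.AsType[ValidationError](err); ok {
	// 	fmt.Println(vErr.Diff)
	// }
}
```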
```diff
@@ -8,12 +8,19 @@ import (

 v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
 "github.com/juanfont/headscale/hscontrol/util"
+"github.com/juanfont/headscale/hscontrol/util/zlog/zf"
 "github.com/pterm/pterm"
 "github.com/rs/zerolog/log"
 "github.com/spf13/cobra"
 "google.golang.org/grpc/status"
 )

+// CLI user errors.
+var (
+errFlagRequired = errors.New("--name or --identifier flag is required")
+errMultipleUsersMatch = errors.New("multiple users match query, specify an ID")
+)

 func usernameAndIDFlag(cmd *cobra.Command) {
 cmd.Flags().Int64P("identifier", "i", -1, "User identifier (ID)")
 cmd.Flags().StringP("name", "n", "", "Username")
@@ -23,12 +30,12 @@ func usernameAndIDFlag(cmd *cobra.Command) {
 // If both are empty, it will exit the program with an error.
 func usernameAndIDFromFlag(cmd *cobra.Command) (uint64, string) {
 username, _ := cmd.Flags().GetString("name")

 identifier, _ := cmd.Flags().GetInt64("identifier")
 if username == "" && identifier < 0 {
-err := errors.New("--name or --identifier flag is required")
 ErrorOutput(
-err,
-"Cannot rename user: "+status.Convert(err).Message(),
+errFlagRequired,
+"Cannot rename user: "+status.Convert(errFlagRequired).Message(),
 "",
 )
 }
@@ -50,7 +57,8 @@ func init() {
 userCmd.AddCommand(renameUserCmd)
 usernameAndIDFlag(renameUserCmd)
 renameUserCmd.Flags().StringP("new-name", "r", "", "New username")
-renameNodeCmd.MarkFlagRequired("new-name")
+
+_ = renameNodeCmd.MarkFlagRequired("new-name")
 }

 var errMissingParameter = errors.New("missing parameters")
@@ -81,7 +89,7 @@ var createUserCmd = &cobra.Command{
 defer cancel()
 defer conn.Close()

-log.Trace().Interface("client", client).Msg("Obtained gRPC client")
+log.Trace().Interface(zf.Client, client).Msg("obtained gRPC client")

 request := &v1.CreateUserRequest{Name: userName}
@@ -94,7 +102,7 @@ var createUserCmd = &cobra.Command{
 }

 if pictureURL, _ := cmd.Flags().GetString("picture-url"); pictureURL != "" {
-if _, err := url.Parse(pictureURL); err != nil {
+if _, err := url.Parse(pictureURL); err != nil { //nolint:noinlineerr
 ErrorOutput(
 err,
 fmt.Sprintf(
@@ -107,7 +115,7 @@ var createUserCmd = &cobra.Command{
 request.PictureUrl = pictureURL
 }

-log.Trace().Interface("request", request).Msg("Sending CreateUser request")
+log.Trace().Interface(zf.Request, request).Msg("sending CreateUser request")
 response, err := client.CreateUser(ctx, request)
 if err != nil {
 ErrorOutput(
@@ -148,7 +156,7 @@ var destroyUserCmd = &cobra.Command{
 }

 if len(users.GetUsers()) != 1 {
-err := errors.New("Unable to determine user to delete, query returned multiple users, use ID")
+err := errMultipleUsersMatch
 ErrorOutput(
 err,
 "Error: "+status.Convert(err).Message(),
@@ -276,7 +284,7 @@ var renameUserCmd = &cobra.Command{
 }

 if len(users.GetUsers()) != 1 {
-err := errors.New("Unable to determine user to delete, query returned multiple users, use ID")
+err := errMultipleUsersMatch
 ErrorOutput(
 err,
 "Error: "+status.Convert(err).Message(),
```
```diff
@@ -11,6 +11,7 @@ import (
 "github.com/juanfont/headscale/hscontrol"
 "github.com/juanfont/headscale/hscontrol/types"
 "github.com/juanfont/headscale/hscontrol/util"
+"github.com/juanfont/headscale/hscontrol/util/zlog/zf"
 "github.com/rs/zerolog/log"
 "google.golang.org/grpc"
 "google.golang.org/grpc/credentials"
@@ -57,7 +58,7 @@ func newHeadscaleCLIWithConfig() (context.Context, v1.HeadscaleServiceClient, *g
 ctx, cancel := context.WithTimeout(context.Background(), cfg.CLI.Timeout)

 grpcOptions := []grpc.DialOption{
-grpc.WithBlock(),
+grpc.WithBlock(), //nolint:staticcheck // SA1019: deprecated but supported in 1.x
 }

 address := cfg.CLI.Address
@@ -81,6 +82,7 @@ func newHeadscaleCLIWithConfig() (context.Context, v1.HeadscaleServiceClient, *g
 Msgf("Unable to read/write to headscale socket, do you have the correct permissions?")
 }
 }

 socket.Close()

 grpcOptions = append(
@@ -92,8 +94,9 @@ func newHeadscaleCLIWithConfig() (context.Context, v1.HeadscaleServiceClient, *g
 // If we are not connecting to a local server, require an API key for authentication
 apiKey := cfg.CLI.APIKey
 if apiKey == "" {
-log.Fatal().Caller().Msgf("HEADSCALE_CLI_API_KEY environment variable needs to be set.")
+log.Fatal().Caller().Msgf("HEADSCALE_CLI_API_KEY environment variable needs to be set")
 }

 grpcOptions = append(grpcOptions,
 grpc.WithPerRPCCredentials(tokenAuth{
 token: apiKey,
@@ -118,10 +121,11 @@ func newHeadscaleCLIWithConfig() (context.Context, v1.HeadscaleServiceClient, *g
 }
 }

-log.Trace().Caller().Str("address", address).Msg("Connecting via gRPC")
-conn, err := grpc.DialContext(ctx, address, grpcOptions...)
+log.Trace().Caller().Str(zf.Address, address).Msg("connecting via gRPC")
+
+conn, err := grpc.DialContext(ctx, address, grpcOptions...) //nolint:staticcheck // SA1019: deprecated but supported in 1.x
 if err != nil {
-log.Fatal().Caller().Err(err).Msgf("Could not connect: %v", err)
+log.Fatal().Caller().Err(err).Msgf("could not connect: %v", err)
 os.Exit(-1) // we get here if logging is suppressed (i.e., json output)
 }
@@ -131,23 +135,26 @@ func newHeadscaleCLIWithConfig() (context.Context, v1.HeadscaleServiceClient, *g
 }

 func output(result any, override string, outputFormat string) string {
-var jsonBytes []byte
-var err error
+var (
+jsonBytes []byte
+err error
+)

 switch outputFormat {
 case "json":
 jsonBytes, err = json.MarshalIndent(result, "", "\t")
 if err != nil {
-log.Fatal().Err(err).Msg("failed to unmarshal output")
+log.Fatal().Err(err).Msg("unmarshalling output")
 }
 case "json-line":
 jsonBytes, err = json.Marshal(result)
 if err != nil {
-log.Fatal().Err(err).Msg("failed to unmarshal output")
+log.Fatal().Err(err).Msg("unmarshalling output")
 }
 case "yaml":
 jsonBytes, err = yaml.Marshal(result)
 if err != nil {
-log.Fatal().Err(err).Msg("failed to unmarshal output")
+log.Fatal().Err(err).Msg("unmarshalling output")
 }
 default:
 // nolint
```
```diff
@@ -12,6 +12,7 @@ import (

 func main() {
 var colors bool

 switch l := termcolor.SupportLevel(os.Stderr); l {
 case termcolor.Level16M:
 colors = true
```
```diff
@@ -14,9 +14,7 @@ import (
 )

 func TestConfigFileLoading(t *testing.T) {
-tmpDir, err := os.MkdirTemp("", "headscale")
-require.NoError(t, err)
-defer os.RemoveAll(tmpDir)
+tmpDir := t.TempDir()

 path, err := os.Getwd()
 require.NoError(t, err)
@@ -48,9 +46,7 @@ func TestConfigFileLoading(t *testing.T) {
 }

 func TestConfigLoading(t *testing.T) {
-tmpDir, err := os.MkdirTemp("", "headscale")
-require.NoError(t, err)
-defer os.RemoveAll(tmpDir)
+tmpDir := t.TempDir()

 path, err := os.Getwd()
 require.NoError(t, err)
```
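A brief note on the test change above, with a minimal sketch (not from the diff): `t.TempDir` creates the directory and removes it automatically when the test and its subtests finish, so the manual `MkdirTemp`/`RemoveAll` pair and its error checks disappear.

```go
package cli

import "testing"

func TestTempDirExample(t *testing.T) {
	// Created by the testing package and cleaned up automatically after the test.
	dir := t.TempDir()
	_ = dir
}
```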
```diff
@@ -22,11 +22,11 @@ import (
 func cleanupBeforeTest(ctx context.Context) error {
 err := cleanupStaleTestContainers(ctx)
 if err != nil {
-return fmt.Errorf("failed to clean stale test containers: %w", err)
+return fmt.Errorf("cleaning stale test containers: %w", err)
 }

-if err := pruneDockerNetworks(ctx); err != nil {
-return fmt.Errorf("failed to prune networks: %w", err)
+if err := pruneDockerNetworks(ctx); err != nil { //nolint:noinlineerr
+return fmt.Errorf("pruning networks: %w", err)
 }

 return nil
@@ -39,14 +39,14 @@ func cleanupAfterTest(ctx context.Context, cli *client.Client, containerID, runI
 Force: true,
 })
 if err != nil {
-return fmt.Errorf("failed to remove test container: %w", err)
+return fmt.Errorf("removing test container: %w", err)
 }

 // Clean up integration test containers for this run only
 if runID != "" {
 err := killTestContainersByRunID(ctx, runID)
 if err != nil {
-return fmt.Errorf("failed to clean up containers for run %s: %w", runID, err)
+return fmt.Errorf("cleaning up containers for run %s: %w", runID, err)
 }
 }
@@ -55,9 +55,9 @@ func cleanupAfterTest(ctx context.Context, cli *client.Client, containerID, runI

 // killTestContainers terminates and removes all test containers.
 func killTestContainers(ctx context.Context) error {
-cli, err := createDockerClient()
+cli, err := createDockerClient(ctx)
 if err != nil {
-return fmt.Errorf("failed to create Docker client: %w", err)
+return fmt.Errorf("creating Docker client: %w", err)
 }
 defer cli.Close()
@@ -65,12 +65,14 @@ func killTestContainers(ctx context.Context) error {
 All: true,
 })
 if err != nil {
-return fmt.Errorf("failed to list containers: %w", err)
+return fmt.Errorf("listing containers: %w", err)
 }

 removed := 0

 for _, cont := range containers {
 shouldRemove := false

 for _, name := range cont.Names {
 if strings.Contains(name, "headscale-test-suite") ||
 strings.Contains(name, "hs-") ||
@@ -107,9 +109,9 @@ func killTestContainers(ctx context.Context) error {
 // This function filters containers by the hi.run-id label to only affect containers
 // belonging to the specified test run, leaving other concurrent test runs untouched.
 func killTestContainersByRunID(ctx context.Context, runID string) error {
-cli, err := createDockerClient()
+cli, err := createDockerClient(ctx)
 if err != nil {
-return fmt.Errorf("failed to create Docker client: %w", err)
+return fmt.Errorf("creating Docker client: %w", err)
 }
 defer cli.Close()
@@ -121,7 +123,7 @@ func killTestContainersByRunID(ctx context.Context, runID string) error {
 ),
 })
 if err != nil {
-return fmt.Errorf("failed to list containers for run %s: %w", runID, err)
+return fmt.Errorf("listing containers for run %s: %w", runID, err)
 }

 removed := 0
@@ -149,9 +151,9 @@ func killTestContainersByRunID(ctx context.Context, runID string) error {
 // This is useful for cleaning up leftover containers from previous crashed or interrupted test runs
 // without interfering with currently running concurrent tests.
 func cleanupStaleTestContainers(ctx context.Context) error {
-cli, err := createDockerClient()
+cli, err := createDockerClient(ctx)
 if err != nil {
-return fmt.Errorf("failed to create Docker client: %w", err)
+return fmt.Errorf("creating Docker client: %w", err)
 }
 defer cli.Close()
@@ -164,7 +166,7 @@ func cleanupStaleTestContainers(ctx context.Context) error {
 ),
 })
 if err != nil {
-return fmt.Errorf("failed to list stopped containers: %w", err)
+return fmt.Errorf("listing stopped containers: %w", err)
 }

 removed := 0
@@ -223,15 +225,15 @@ func removeContainerWithRetry(ctx context.Context, cli *client.Client, container

 // pruneDockerNetworks removes unused Docker networks.
 func pruneDockerNetworks(ctx context.Context) error {
-cli, err := createDockerClient()
+cli, err := createDockerClient(ctx)
 if err != nil {
-return fmt.Errorf("failed to create Docker client: %w", err)
+return fmt.Errorf("creating Docker client: %w", err)
 }
 defer cli.Close()

 report, err := cli.NetworksPrune(ctx, filters.Args{})
 if err != nil {
-return fmt.Errorf("failed to prune networks: %w", err)
+return fmt.Errorf("pruning networks: %w", err)
 }

 if len(report.NetworksDeleted) > 0 {
@@ -245,9 +247,9 @@ func pruneDockerNetworks(ctx context.Context) error {

 // cleanOldImages removes test-related and old dangling Docker images.
 func cleanOldImages(ctx context.Context) error {
-cli, err := createDockerClient()
+cli, err := createDockerClient(ctx)
 if err != nil {
-return fmt.Errorf("failed to create Docker client: %w", err)
+return fmt.Errorf("creating Docker client: %w", err)
 }
 defer cli.Close()
@@ -255,12 +257,14 @@ func cleanOldImages(ctx context.Context) error {
 All: true,
 })
 if err != nil {
-return fmt.Errorf("failed to list images: %w", err)
+return fmt.Errorf("listing images: %w", err)
 }

 removed := 0

 for _, img := range images {
 shouldRemove := false

 for _, tag := range img.RepoTags {
 if strings.Contains(tag, "hs-") ||
 strings.Contains(tag, "headscale-integration") ||
@@ -295,18 +299,19 @@ func cleanOldImages(ctx context.Context) error {

 // cleanCacheVolume removes the Docker volume used for Go module cache.
 func cleanCacheVolume(ctx context.Context) error {
-cli, err := createDockerClient()
+cli, err := createDockerClient(ctx)
 if err != nil {
-return fmt.Errorf("failed to create Docker client: %w", err)
+return fmt.Errorf("creating Docker client: %w", err)
 }
 defer cli.Close()

 volumeName := "hs-integration-go-cache"

 err = cli.VolumeRemove(ctx, volumeName, true)
 if err != nil {
-if errdefs.IsNotFound(err) {
+if errdefs.IsNotFound(err) { //nolint:staticcheck // SA1019: deprecated but functional
 fmt.Printf("Go module cache volume not found: %s\n", volumeName)
-} else if errdefs.IsConflict(err) {
+} else if errdefs.IsConflict(err) { //nolint:staticcheck // SA1019: deprecated but functional
 fmt.Printf("Go module cache volume is in use and cannot be removed: %s\n", volumeName)
 } else {
 fmt.Printf("Failed to remove Go module cache volume %s: %v\n", volumeName, err)
@@ -330,7 +335,7 @@ func cleanCacheVolume(ctx context.Context) error {
 func cleanupSuccessfulTestArtifacts(logsDir string, verbose bool) error {
 entries, err := os.ReadDir(logsDir)
 if err != nil {
-return fmt.Errorf("failed to read logs directory: %w", err)
+return fmt.Errorf("reading logs directory: %w", err)
 }

 var (
```
148
cmd/hi/docker.go
148
cmd/hi/docker.go
@@ -22,17 +22,22 @@ import (
|
||||
"github.com/juanfont/headscale/integration/dockertestutil"
|
||||
)
|
||||
|
||||
const defaultDirPerm = 0o755
|
||||
|
||||
var (
|
||||
ErrTestFailed = errors.New("test failed")
|
||||
ErrUnexpectedContainerWait = errors.New("unexpected end of container wait")
|
||||
ErrNoDockerContext = errors.New("no docker context found")
|
||||
ErrMemoryLimitViolations = errors.New("container(s) exceeded memory limits")
|
||||
)
|
||||
|
||||
// runTestContainer executes integration tests in a Docker container.
|
||||
//
|
||||
//nolint:gocyclo // complex test orchestration function
|
||||
func runTestContainer(ctx context.Context, config *RunConfig) error {
|
||||
cli, err := createDockerClient()
|
||||
cli, err := createDockerClient(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create Docker client: %w", err)
|
||||
return fmt.Errorf("creating Docker client: %w", err)
|
||||
}
|
||||
defer cli.Close()
|
||||
|
||||
@@ -48,19 +53,21 @@ func runTestContainer(ctx context.Context, config *RunConfig) error {
|
||||
|
||||
absLogsDir, err := filepath.Abs(logsDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get absolute path for logs directory: %w", err)
|
||||
return fmt.Errorf("getting absolute path for logs directory: %w", err)
|
||||
}
|
||||
|
||||
const dirPerm = 0o755
|
||||
if err := os.MkdirAll(absLogsDir, dirPerm); err != nil {
|
||||
return fmt.Errorf("failed to create logs directory: %w", err)
|
||||
if err := os.MkdirAll(absLogsDir, dirPerm); err != nil { //nolint:noinlineerr
|
||||
return fmt.Errorf("creating logs directory: %w", err)
|
||||
}
|
||||
|
||||
if config.CleanBefore {
|
||||
if config.Verbose {
|
||||
log.Printf("Running pre-test cleanup...")
|
||||
}
|
||||
if err := cleanupBeforeTest(ctx); err != nil && config.Verbose {
|
||||
|
||||
err := cleanupBeforeTest(ctx)
|
||||
if err != nil && config.Verbose {
|
||||
log.Printf("Warning: pre-test cleanup failed: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -71,21 +78,21 @@ func runTestContainer(ctx context.Context, config *RunConfig) error {
|
||||
}
|
||||
|
||||
imageName := "golang:" + config.GoVersion
|
||||
if err := ensureImageAvailable(ctx, cli, imageName, config.Verbose); err != nil {
|
||||
return fmt.Errorf("failed to ensure image availability: %w", err)
|
||||
if err := ensureImageAvailable(ctx, cli, imageName, config.Verbose); err != nil { //nolint:noinlineerr
|
||||
return fmt.Errorf("ensuring image availability: %w", err)
|
||||
}
|
||||
|
||||
resp, err := createGoTestContainer(ctx, cli, config, containerName, absLogsDir, goTestCmd)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create container: %w", err)
|
||||
return fmt.Errorf("creating container: %w", err)
|
||||
}
|
||||
|
||||
if config.Verbose {
|
||||
log.Printf("Created container: %s", resp.ID)
|
||||
}
|
||||
|
||||
if err := cli.ContainerStart(ctx, resp.ID, container.StartOptions{}); err != nil {
|
||||
return fmt.Errorf("failed to start container: %w", err)
|
||||
if err := cli.ContainerStart(ctx, resp.ID, container.StartOptions{}); err != nil { //nolint:noinlineerr
|
||||
return fmt.Errorf("starting container: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("Starting test: %s", config.TestPattern)
|
||||
@@ -95,13 +102,16 @@ func runTestContainer(ctx context.Context, config *RunConfig) error {
|
||||
|
||||
// Start stats collection for container resource monitoring (if enabled)
|
||||
var statsCollector *StatsCollector
|
||||
|
||||
if config.Stats {
|
||||
var err error
|
||||
statsCollector, err = NewStatsCollector()
|
||||
|
||||
statsCollector, err = NewStatsCollector(ctx)
|
||||
if err != nil {
|
||||
if config.Verbose {
|
||||
log.Printf("Warning: failed to create stats collector: %v", err)
|
||||
}
|
||||
|
||||
statsCollector = nil
|
||||
}
|
||||
|
||||
@@ -110,7 +120,8 @@ func runTestContainer(ctx context.Context, config *RunConfig) error {
|
||||
|
||||
// Start stats collection immediately - no need for complex retry logic
|
||||
// The new implementation monitors Docker events and will catch containers as they start
|
||||
if err := statsCollector.StartCollection(ctx, runID, config.Verbose); err != nil {
|
||||
err := statsCollector.StartCollection(ctx, runID, config.Verbose)
|
||||
if err != nil {
|
||||
if config.Verbose {
|
||||
log.Printf("Warning: failed to start stats collection: %v", err)
|
||||
}
|
||||
@@ -122,12 +133,13 @@ func runTestContainer(ctx context.Context, config *RunConfig) error {
|
||||
exitCode, err := streamAndWait(ctx, cli, resp.ID)
|
||||
|
||||
// Ensure all containers have finished and logs are flushed before extracting artifacts
|
||||
if waitErr := waitForContainerFinalization(ctx, cli, resp.ID, config.Verbose); waitErr != nil && config.Verbose {
|
||||
waitErr := waitForContainerFinalization(ctx, cli, resp.ID, config.Verbose)
|
||||
if waitErr != nil && config.Verbose {
|
||||
log.Printf("Warning: failed to wait for container finalization: %v", waitErr)
|
||||
}
|
||||
|
||||
// Extract artifacts from test containers before cleanup
|
||||
if err := extractArtifactsFromContainers(ctx, resp.ID, logsDir, config.Verbose); err != nil && config.Verbose {
|
||||
if err := extractArtifactsFromContainers(ctx, resp.ID, logsDir, config.Verbose); err != nil && config.Verbose { //nolint:noinlineerr
|
||||
log.Printf("Warning: failed to extract artifacts from containers: %v", err)
|
||||
}
|
||||
|
||||
@@ -140,12 +152,13 @@ func runTestContainer(ctx context.Context, config *RunConfig) error {
|
||||
if len(violations) > 0 {
|
||||
log.Printf("MEMORY LIMIT VIOLATIONS DETECTED:")
|
||||
log.Printf("=================================")
|
||||
|
||||
for _, violation := range violations {
|
||||
log.Printf("Container %s exceeded memory limit: %.1f MB > %.1f MB",
|
||||
violation.ContainerName, violation.MaxMemoryMB, violation.LimitMB)
|
||||
}
|
||||
|
||||
return fmt.Errorf("test failed: %d container(s) exceeded memory limits", len(violations))
|
||||
return fmt.Errorf("test failed: %d %w", len(violations), ErrMemoryLimitViolations)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -176,7 +189,7 @@ func runTestContainer(ctx context.Context, config *RunConfig) error {
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("test execution failed: %w", err)
|
||||
return fmt.Errorf("executing test: %w", err)
|
||||
}
|
||||
|
||||
if exitCode != 0 {
|
||||
@@ -210,7 +223,7 @@ func buildGoTestCommand(config *RunConfig) []string {
|
||||
func createGoTestContainer(ctx context.Context, cli *client.Client, config *RunConfig, containerName, logsDir string, goTestCmd []string) (container.CreateResponse, error) {
|
||||
pwd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return container.CreateResponse{}, fmt.Errorf("failed to get working directory: %w", err)
|
||||
return container.CreateResponse{}, fmt.Errorf("getting working directory: %w", err)
|
||||
}
|
||||
|
||||
projectRoot := findProjectRoot(pwd)
|
||||
@@ -312,7 +325,7 @@ func streamAndWait(ctx context.Context, cli *client.Client, containerID string)
|
||||
Follow: true,
|
||||
})
|
||||
if err != nil {
|
||||
return -1, fmt.Errorf("failed to get container logs: %w", err)
|
||||
return -1, fmt.Errorf("getting container logs: %w", err)
|
||||
}
|
||||
defer out.Close()
|
||||
|
||||
@@ -324,7 +337,7 @@ func streamAndWait(ctx context.Context, cli *client.Client, containerID string)
|
||||
select {
|
||||
case err := <-errCh:
|
||||
if err != nil {
|
||||
return -1, fmt.Errorf("error waiting for container: %w", err)
|
||||
return -1, fmt.Errorf("waiting for container: %w", err)
|
||||
}
|
||||
case status := <-statusCh:
|
||||
return int(status.StatusCode), nil
|
||||
@@ -338,7 +351,7 @@ func waitForContainerFinalization(ctx context.Context, cli *client.Client, testC
|
||||
// First, get all related test containers
|
||||
containers, err := cli.ContainerList(ctx, container.ListOptions{All: true})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to list containers: %w", err)
|
||||
return fmt.Errorf("listing containers: %w", err)
|
||||
}
|
||||
|
||||
testContainers := getCurrentTestContainers(containers, testContainerID, verbose)
|
||||
@@ -347,6 +360,7 @@ func waitForContainerFinalization(ctx context.Context, cli *client.Client, testC
|
||||
maxWaitTime := 10 * time.Second
|
||||
checkInterval := 500 * time.Millisecond
|
||||
timeout := time.After(maxWaitTime)
|
||||
|
||||
ticker := time.NewTicker(checkInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
@@ -356,6 +370,7 @@ func waitForContainerFinalization(ctx context.Context, cli *client.Client, testC
|
||||
if verbose {
|
||||
log.Printf("Timeout waiting for container finalization, proceeding with artifact extraction")
|
||||
}
|
||||
|
||||
return nil
|
||||
case <-ticker.C:
|
||||
allFinalized := true
|
||||
@@ -366,12 +381,14 @@ func waitForContainerFinalization(ctx context.Context, cli *client.Client, testC
|
||||
if verbose {
|
||||
log.Printf("Warning: failed to inspect container %s: %v", testCont.name, err)
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if container is in a final state
|
||||
if !isContainerFinalized(inspect.State) {
|
||||
allFinalized = false
|
||||
|
||||
if verbose {
|
||||
log.Printf("Container %s still finalizing (state: %s)", testCont.name, inspect.State.Status)
|
||||
}
|
||||
@@ -384,6 +401,7 @@ func waitForContainerFinalization(ctx context.Context, cli *client.Client, testC
|
||||
if verbose {
|
||||
log.Printf("All test containers finalized, ready for artifact extraction")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@@ -400,13 +418,15 @@ func isContainerFinalized(state *container.State) bool {
|
||||
func findProjectRoot(startPath string) string {
|
||||
current := startPath
|
||||
for {
|
||||
if _, err := os.Stat(filepath.Join(current, "go.mod")); err == nil {
|
||||
if _, err := os.Stat(filepath.Join(current, "go.mod")); err == nil { //nolint:noinlineerr
|
||||
return current
|
||||
}
|
||||
|
||||
parent := filepath.Dir(current)
|
||||
if parent == current {
|
||||
return startPath
|
||||
}
|
||||
|
||||
current = parent
|
||||
}
|
||||
}
|
||||
@@ -416,6 +436,7 @@ func boolToInt(b bool) int {
|
||||
if b {
|
||||
return 1
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
@@ -428,13 +449,14 @@ type DockerContext struct {
|
||||
}
|
||||
|
||||
// createDockerClient creates a Docker client with context detection.
|
||||
func createDockerClient() (*client.Client, error) {
|
||||
contextInfo, err := getCurrentDockerContext()
|
||||
func createDockerClient(ctx context.Context) (*client.Client, error) {
|
||||
contextInfo, err := getCurrentDockerContext(ctx)
|
||||
if err != nil {
|
||||
return client.NewClientWithOpts(client.FromEnv, client.WithAPIVersionNegotiation())
|
||||
}
|
||||
|
||||
var clientOpts []client.Opt
|
||||
|
||||
clientOpts = append(clientOpts, client.WithAPIVersionNegotiation())
|
||||
|
||||
if contextInfo != nil {
|
||||
@@ -444,6 +466,7 @@ func createDockerClient() (*client.Client, error) {
|
||||
if runConfig.Verbose {
|
||||
log.Printf("Using Docker host from context '%s': %s", contextInfo.Name, host)
|
||||
}
|
||||
|
||||
clientOpts = append(clientOpts, client.WithHost(host))
|
||||
}
|
||||
}
|
||||
@@ -458,16 +481,17 @@ func createDockerClient() (*client.Client, error) {
|
||||
}
|
||||
|
||||
// getCurrentDockerContext retrieves the current Docker context information.
|
||||
func getCurrentDockerContext() (*DockerContext, error) {
|
||||
cmd := exec.Command("docker", "context", "inspect")
|
||||
func getCurrentDockerContext(ctx context.Context) (*DockerContext, error) {
|
||||
cmd := exec.CommandContext(ctx, "docker", "context", "inspect")
|
||||
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get docker context: %w", err)
|
||||
return nil, fmt.Errorf("getting docker context: %w", err)
|
||||
}
|
||||
|
||||
var contexts []DockerContext
|
||||
if err := json.Unmarshal(output, &contexts); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse docker context: %w", err)
|
||||
if err := json.Unmarshal(output, &contexts); err != nil { //nolint:noinlineerr
|
||||
return nil, fmt.Errorf("parsing docker context: %w", err)
|
||||
}
|
||||
|
||||
if len(contexts) > 0 {
|
||||
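The change above threads a context into the `docker context inspect` call and then feeds the detected host into the Docker SDK client options. A hedged sketch of that detection flow, falling back to environment-based configuration when no context is available; the JSON field names for the context endpoint are assumptions based on typical `docker context inspect` output, not taken from the diff:

```go
package dockerutil

import (
	"context"
	"encoding/json"
	"os/exec"

	"github.com/docker/docker/client"
)

// dockerContext holds only the fields this sketch needs from `docker context inspect`.
type dockerContext struct {
	Name      string `json:"Name"`
	Endpoints map[string]struct {
		Host string `json:"Host"`
	} `json:"Endpoints"`
}

// NewDockerClient builds a client, preferring the host from the active Docker context.
func NewDockerClient(ctx context.Context) (*client.Client, error) {
	out, err := exec.CommandContext(ctx, "docker", "context", "inspect").Output()
	if err != nil {
		// No usable context information: fall back to the environment.
		return client.NewClientWithOpts(client.FromEnv, client.WithAPIVersionNegotiation())
	}

	var contexts []dockerContext
	if err := json.Unmarshal(out, &contexts); err != nil || len(contexts) == 0 {
		return client.NewClientWithOpts(client.FromEnv, client.WithAPIVersionNegotiation())
	}

	opts := []client.Opt{client.WithAPIVersionNegotiation()}
	if ep, ok := contexts[0].Endpoints["docker"]; ok && ep.Host != "" {
		opts = append(opts, client.WithHost(ep.Host))
	}

	return client.NewClientWithOpts(opts...)
}
```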
@@ -486,12 +510,13 @@ func getDockerSocketPath() string {
|
||||
|
||||
// checkImageAvailableLocally checks if the specified Docker image is available locally.
|
||||
func checkImageAvailableLocally(ctx context.Context, cli *client.Client, imageName string) (bool, error) {
|
||||
_, _, err := cli.ImageInspectWithRaw(ctx, imageName)
|
||||
_, _, err := cli.ImageInspectWithRaw(ctx, imageName) //nolint:staticcheck // SA1019: deprecated but functional
|
||||
if err != nil {
|
||||
if client.IsErrNotFound(err) {
|
||||
if client.IsErrNotFound(err) { //nolint:staticcheck // SA1019: deprecated but functional
|
||||
return false, nil
|
||||
}
|
||||
return false, fmt.Errorf("failed to inspect image %s: %w", imageName, err)
|
||||
|
||||
return false, fmt.Errorf("inspecting image %s: %w", imageName, err)
|
||||
}
|
||||
|
||||
return true, nil
|
||||
@@ -502,13 +527,14 @@ func ensureImageAvailable(ctx context.Context, cli *client.Client, imageName str
|
||||
// First check if image is available locally
|
||||
available, err := checkImageAvailableLocally(ctx, cli, imageName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check local image availability: %w", err)
|
||||
return fmt.Errorf("checking local image availability: %w", err)
|
||||
}
|
||||
|
||||
if available {
|
||||
if verbose {
|
||||
log.Printf("Image %s is available locally", imageName)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -519,20 +545,21 @@ func ensureImageAvailable(ctx context.Context, cli *client.Client, imageName str
|
||||
|
||||
reader, err := cli.ImagePull(ctx, imageName, image.PullOptions{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to pull image %s: %w", imageName, err)
|
||||
return fmt.Errorf("pulling image %s: %w", imageName, err)
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
if verbose {
|
||||
_, err = io.Copy(os.Stdout, reader)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read pull output: %w", err)
|
||||
return fmt.Errorf("reading pull output: %w", err)
|
||||
}
|
||||
} else {
|
||||
_, err = io.Copy(io.Discard, reader)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read pull output: %w", err)
|
||||
return fmt.Errorf("reading pull output: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("Image %s pulled successfully", imageName)
|
||||
}
|
||||
|
||||
@@ -547,9 +574,11 @@ func listControlFiles(logsDir string) {
|
||||
return
|
||||
}
|
||||
|
||||
var logFiles []string
|
||||
var dataFiles []string
|
||||
var dataDirs []string
|
||||
var (
|
||||
logFiles []string
|
||||
dataFiles []string
|
||||
dataDirs []string
|
||||
)
|
||||
|
||||
for _, entry := range entries {
|
||||
name := entry.Name()
|
||||
@@ -578,6 +607,7 @@ func listControlFiles(logsDir string) {
|
||||
|
||||
if len(logFiles) > 0 {
|
||||
log.Printf("Headscale logs:")
|
||||
|
||||
for _, file := range logFiles {
|
||||
log.Printf(" %s", file)
|
||||
}
|
||||
@@ -585,9 +615,11 @@ func listControlFiles(logsDir string) {
|
||||
|
||||
if len(dataFiles) > 0 || len(dataDirs) > 0 {
|
||||
log.Printf("Headscale data:")
|
||||
|
||||
for _, file := range dataFiles {
|
||||
log.Printf(" %s", file)
|
||||
}
|
||||
|
||||
for _, dir := range dataDirs {
|
||||
log.Printf(" %s/", dir)
|
||||
}
|
||||
@@ -596,25 +628,27 @@ func listControlFiles(logsDir string) {
|
||||
|
||||
// extractArtifactsFromContainers collects container logs and files from the specific test run.
|
||||
func extractArtifactsFromContainers(ctx context.Context, testContainerID, logsDir string, verbose bool) error {
|
||||
cli, err := createDockerClient()
|
||||
cli, err := createDockerClient(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create Docker client: %w", err)
|
||||
return fmt.Errorf("creating Docker client: %w", err)
|
||||
}
|
||||
defer cli.Close()
|
||||
|
||||
// List all containers
|
||||
containers, err := cli.ContainerList(ctx, container.ListOptions{All: true})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to list containers: %w", err)
|
||||
return fmt.Errorf("listing containers: %w", err)
|
||||
}
|
||||
|
||||
// Get containers from the specific test run
|
||||
currentTestContainers := getCurrentTestContainers(containers, testContainerID, verbose)
|
||||
|
||||
extractedCount := 0
|
||||
|
||||
for _, cont := range currentTestContainers {
|
||||
// Extract container logs and tar files
|
||||
if err := extractContainerArtifacts(ctx, cli, cont.ID, cont.name, logsDir, verbose); err != nil {
|
||||
err := extractContainerArtifacts(ctx, cli, cont.ID, cont.name, logsDir, verbose)
|
||||
if err != nil {
|
||||
if verbose {
|
||||
log.Printf("Warning: failed to extract artifacts from container %s (%s): %v", cont.name, cont.ID[:12], err)
|
||||
}
|
||||
@@ -622,6 +656,7 @@ func extractArtifactsFromContainers(ctx context.Context, testContainerID, logsDi
|
||||
if verbose {
|
||||
log.Printf("Extracted artifacts from container %s (%s)", cont.name, cont.ID[:12])
|
||||
}
|
||||
|
||||
extractedCount++
|
||||
}
|
||||
}
|
||||
@@ -645,11 +680,13 @@ func getCurrentTestContainers(containers []container.Summary, testContainerID st
|
||||
|
||||
// Find the test container to get its run ID label
|
||||
var runID string
|
||||
|
||||
for _, cont := range containers {
|
||||
if cont.ID == testContainerID {
|
||||
if cont.Labels != nil {
|
||||
runID = cont.Labels["hi.run-id"]
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -690,18 +727,21 @@ func getCurrentTestContainers(containers []container.Summary, testContainerID st
|
||||
// extractContainerArtifacts saves logs and tar files from a container.
|
||||
func extractContainerArtifacts(ctx context.Context, cli *client.Client, containerID, containerName, logsDir string, verbose bool) error {
|
||||
// Ensure the logs directory exists
|
||||
if err := os.MkdirAll(logsDir, 0o755); err != nil {
|
||||
return fmt.Errorf("failed to create logs directory: %w", err)
|
||||
err := os.MkdirAll(logsDir, defaultDirPerm)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating logs directory: %w", err)
|
||||
}
|
||||
|
||||
// Extract container logs
|
||||
if err := extractContainerLogs(ctx, cli, containerID, containerName, logsDir, verbose); err != nil {
|
||||
return fmt.Errorf("failed to extract logs: %w", err)
|
||||
err = extractContainerLogs(ctx, cli, containerID, containerName, logsDir, verbose)
|
||||
if err != nil {
|
||||
return fmt.Errorf("extracting logs: %w", err)
|
||||
}
|
||||
|
||||
// Extract tar files for headscale containers only
|
||||
if strings.HasPrefix(containerName, "hs-") {
|
||||
if err := extractContainerFiles(ctx, cli, containerID, containerName, logsDir, verbose); err != nil {
|
||||
err := extractContainerFiles(ctx, cli, containerID, containerName, logsDir, verbose)
|
||||
if err != nil {
|
||||
if verbose {
|
||||
log.Printf("Warning: failed to extract files from %s: %v", containerName, err)
|
||||
}
|
||||
@@ -723,7 +763,7 @@ func extractContainerLogs(ctx context.Context, cli *client.Client, containerID,
|
||||
Tail: "all",
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get container logs: %w", err)
|
||||
return fmt.Errorf("getting container logs: %w", err)
|
||||
}
|
||||
defer logReader.Close()
|
||||
|
||||
@@ -737,17 +777,17 @@ func extractContainerLogs(ctx context.Context, cli *client.Client, containerID,
|
||||
// Demultiplex the Docker logs stream to separate stdout and stderr
|
||||
_, err = stdcopy.StdCopy(&stdoutBuf, &stderrBuf, logReader)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to demultiplex container logs: %w", err)
|
||||
return fmt.Errorf("demultiplexing container logs: %w", err)
|
||||
}
|
||||
|
||||
// Write stdout logs
|
||||
if err := os.WriteFile(stdoutPath, stdoutBuf.Bytes(), 0o644); err != nil {
|
||||
return fmt.Errorf("failed to write stdout log: %w", err)
|
||||
if err := os.WriteFile(stdoutPath, stdoutBuf.Bytes(), 0o644); err != nil { //nolint:gosec,noinlineerr // log files should be readable
|
||||
return fmt.Errorf("writing stdout log: %w", err)
|
||||
}
|
||||
|
||||
// Write stderr logs
|
||||
if err := os.WriteFile(stderrPath, stderrBuf.Bytes(), 0o644); err != nil {
|
||||
return fmt.Errorf("failed to write stderr log: %w", err)
|
||||
if err := os.WriteFile(stderrPath, stderrBuf.Bytes(), 0o644); err != nil { //nolint:gosec,noinlineerr // log files should be readable
|
||||
return fmt.Errorf("writing stderr log: %w", err)
|
||||
}
|
||||
|
||||
if verbose {
|
||||
|
||||
@@ -38,13 +38,13 @@ func runDoctorCheck(ctx context.Context) error {
|
||||
}
|
||||
|
||||
// Check 3: Go installation
|
||||
results = append(results, checkGoInstallation())
|
||||
results = append(results, checkGoInstallation(ctx))
|
||||
|
||||
// Check 4: Git repository
|
||||
results = append(results, checkGitRepository())
|
||||
results = append(results, checkGitRepository(ctx))
|
||||
|
||||
// Check 5: Required files
|
||||
results = append(results, checkRequiredFiles())
|
||||
results = append(results, checkRequiredFiles(ctx))
|
||||
|
||||
// Display results
|
||||
displayDoctorResults(results)
|
||||
@@ -86,7 +86,7 @@ func checkDockerBinary() DoctorResult {
|
||||
|
||||
// checkDockerDaemon verifies Docker daemon is running and accessible.
|
||||
func checkDockerDaemon(ctx context.Context) DoctorResult {
|
||||
cli, err := createDockerClient()
|
||||
cli, err := createDockerClient(ctx)
|
||||
if err != nil {
|
||||
return DoctorResult{
|
||||
Name: "Docker Daemon",
|
||||
@@ -124,8 +124,8 @@ func checkDockerDaemon(ctx context.Context) DoctorResult {
|
||||
}
|
||||
|
||||
// checkDockerContext verifies Docker context configuration.
|
||||
func checkDockerContext(_ context.Context) DoctorResult {
|
||||
contextInfo, err := getCurrentDockerContext()
|
||||
func checkDockerContext(ctx context.Context) DoctorResult {
|
||||
contextInfo, err := getCurrentDockerContext(ctx)
|
||||
if err != nil {
|
||||
return DoctorResult{
|
||||
Name: "Docker Context",
|
||||
@@ -155,7 +155,7 @@ func checkDockerContext(_ context.Context) DoctorResult {
|
||||
|
||||
// checkDockerSocket verifies Docker socket accessibility.
|
||||
func checkDockerSocket(ctx context.Context) DoctorResult {
|
||||
cli, err := createDockerClient()
|
||||
cli, err := createDockerClient(ctx)
|
||||
if err != nil {
|
||||
return DoctorResult{
|
||||
Name: "Docker Socket",
|
||||
@@ -192,7 +192,7 @@ func checkDockerSocket(ctx context.Context) DoctorResult {
|
||||
|
||||
// checkGolangImage verifies the golang Docker image is available locally or can be pulled.
|
||||
func checkGolangImage(ctx context.Context) DoctorResult {
|
||||
cli, err := createDockerClient()
|
||||
cli, err := createDockerClient(ctx)
|
||||
if err != nil {
|
||||
return DoctorResult{
|
||||
Name: "Golang Image",
|
||||
@@ -251,7 +251,7 @@ func checkGolangImage(ctx context.Context) DoctorResult {
|
||||
}
|
||||
|
||||
// checkGoInstallation verifies Go is installed and working.
|
||||
func checkGoInstallation() DoctorResult {
|
||||
func checkGoInstallation(ctx context.Context) DoctorResult {
|
||||
_, err := exec.LookPath("go")
|
||||
if err != nil {
|
||||
return DoctorResult{
|
||||
@@ -265,7 +265,8 @@ func checkGoInstallation() DoctorResult {
|
||||
}
|
||||
}
|
||||
|
||||
cmd := exec.Command("go", "version")
|
||||
cmd := exec.CommandContext(ctx, "go", "version")
|
||||
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
return DoctorResult{
|
||||
@@ -285,8 +286,9 @@ func checkGoInstallation() DoctorResult {
|
||||
}
|
||||
|
||||
// checkGitRepository verifies we're in a git repository.
|
||||
func checkGitRepository() DoctorResult {
|
||||
cmd := exec.Command("git", "rev-parse", "--git-dir")
|
||||
func checkGitRepository(ctx context.Context) DoctorResult {
|
||||
cmd := exec.CommandContext(ctx, "git", "rev-parse", "--git-dir")
|
||||
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
return DoctorResult{
|
||||
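Several doctor checks above switch from `exec.Command` to `exec.CommandContext` so that external tools inherit cancellation from the caller. A minimal sketch of the difference, using a short timeout (the git command is just an example):

```go
package main

import (
	"context"
	"fmt"
	"os/exec"
	"time"
)

func main() {
	// The command is killed automatically if the context expires first.
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	out, err := exec.CommandContext(ctx, "git", "rev-parse", "--git-dir").Output()
	if err != nil {
		fmt.Println("not inside a git repository (or git missing):", err)
		return
	}

	fmt.Printf("git dir: %s", out)
}
```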
@@ -308,7 +310,7 @@ func checkGitRepository() DoctorResult {
|
||||
}
|
||||
|
||||
// checkRequiredFiles verifies required files exist.
|
||||
func checkRequiredFiles() DoctorResult {
|
||||
func checkRequiredFiles(ctx context.Context) DoctorResult {
|
||||
requiredFiles := []string{
|
||||
"go.mod",
|
||||
"integration/",
|
||||
@@ -316,9 +318,12 @@ func checkRequiredFiles() DoctorResult {
|
||||
}
|
||||
|
||||
var missingFiles []string
|
||||
|
||||
for _, file := range requiredFiles {
|
||||
cmd := exec.Command("test", "-e", file)
|
||||
if err := cmd.Run(); err != nil {
|
||||
cmd := exec.CommandContext(ctx, "test", "-e", file)
|
||||
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
missingFiles = append(missingFiles, file)
|
||||
}
|
||||
}
|
||||
@@ -350,6 +355,7 @@ func displayDoctorResults(results []DoctorResult) {
|
||||
|
||||
for _, result := range results {
|
||||
var icon string
|
||||
|
||||
switch result.Status {
|
||||
case "PASS":
|
||||
icon = "✅"
|
||||
|
||||
@@ -79,13 +79,18 @@ func main() {
|
||||
}
|
||||
|
||||
func cleanAll(ctx context.Context) error {
|
||||
if err := killTestContainers(ctx); err != nil {
|
||||
err := killTestContainers(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := pruneDockerNetworks(ctx); err != nil {
|
||||
|
||||
err = pruneDockerNetworks(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := cleanOldImages(ctx); err != nil {
|
||||
|
||||
err = cleanOldImages(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
@@ -48,7 +48,9 @@ func runIntegrationTest(env *command.Env) error {
|
||||
if runConfig.Verbose {
|
||||
log.Printf("Running pre-flight system checks...")
|
||||
}
|
||||
if err := runDoctorCheck(env.Context()); err != nil {
|
||||
|
||||
err := runDoctorCheck(env.Context())
|
||||
if err != nil {
|
||||
return fmt.Errorf("pre-flight checks failed: %w", err)
|
||||
}
|
||||
|
||||
@@ -66,15 +68,15 @@ func runIntegrationTest(env *command.Env) error {
|
||||
func detectGoVersion() string {
|
||||
goModPath := filepath.Join("..", "..", "go.mod")
|
||||
|
||||
if _, err := os.Stat("go.mod"); err == nil {
|
||||
if _, err := os.Stat("go.mod"); err == nil { //nolint:noinlineerr
|
||||
goModPath = "go.mod"
|
||||
} else if _, err := os.Stat("../../go.mod"); err == nil {
|
||||
} else if _, err := os.Stat("../../go.mod"); err == nil { //nolint:noinlineerr
|
||||
goModPath = "../../go.mod"
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(goModPath)
|
||||
if err != nil {
|
||||
return "1.25"
|
||||
return "1.26rc2"
|
||||
}
|
||||
|
||||
lines := splitLines(string(content))
|
||||
@@ -89,13 +91,15 @@ func detectGoVersion() string {
|
||||
}
|
||||
}
|
||||
|
||||
return "1.25"
|
||||
return "1.26rc2"
|
||||
}
|
||||
|
||||
// splitLines splits a string into lines without using strings.Split.
|
||||
func splitLines(s string) []string {
|
||||
var lines []string
|
||||
var current string
|
||||
var (
|
||||
lines []string
|
||||
current string
|
||||
)
|
||||
|
||||
for _, char := range s {
|
||||
if char == '\n' {
|
||||
|
||||
@@ -18,6 +18,9 @@ import (
|
||||
"github.com/docker/docker/client"
|
||||
)
|
||||
|
||||
// ErrStatsCollectionAlreadyStarted is returned when trying to start stats collection that is already running.
|
||||
var ErrStatsCollectionAlreadyStarted = errors.New("stats collection already started")
|
||||
|
||||
// ContainerStats represents statistics for a single container.
|
||||
type ContainerStats struct {
|
||||
ContainerID string
|
||||
@@ -44,10 +47,10 @@ type StatsCollector struct {
|
||||
}
|
||||
|
||||
// NewStatsCollector creates a new stats collector instance.
|
||||
func NewStatsCollector() (*StatsCollector, error) {
|
||||
cli, err := createDockerClient()
|
||||
func NewStatsCollector(ctx context.Context) (*StatsCollector, error) {
|
||||
cli, err := createDockerClient(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create Docker client: %w", err)
|
||||
return nil, fmt.Errorf("creating Docker client: %w", err)
|
||||
}
|
||||
|
||||
return &StatsCollector{
|
||||
@@ -63,17 +66,19 @@ func (sc *StatsCollector) StartCollection(ctx context.Context, runID string, ver
|
||||
defer sc.mutex.Unlock()
|
||||
|
||||
if sc.collectionStarted {
|
||||
return errors.New("stats collection already started")
|
||||
return ErrStatsCollectionAlreadyStarted
|
||||
}
|
||||
|
||||
sc.collectionStarted = true
|
||||
|
||||
// Start monitoring existing containers
|
||||
sc.wg.Add(1)
|
||||
|
||||
go sc.monitorExistingContainers(ctx, runID, verbose)
|
||||
|
||||
// Start Docker events monitoring for new containers
|
||||
sc.wg.Add(1)
|
||||
|
||||
go sc.monitorDockerEvents(ctx, runID, verbose)
|
||||
|
||||
if verbose {
|
||||
@@ -87,10 +92,12 @@ func (sc *StatsCollector) StartCollection(ctx context.Context, runID string, ver
|
||||
func (sc *StatsCollector) StopCollection() {
|
||||
// Check if already stopped without holding lock
|
||||
sc.mutex.RLock()
|
||||
|
||||
if !sc.collectionStarted {
|
||||
sc.mutex.RUnlock()
|
||||
return
|
||||
}
|
||||
|
||||
sc.mutex.RUnlock()
|
||||
|
||||
// Signal stop to all goroutines
|
||||
@@ -114,6 +121,7 @@ func (sc *StatsCollector) monitorExistingContainers(ctx context.Context, runID s
|
||||
if verbose {
|
||||
log.Printf("Failed to list existing containers: %v", err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -147,13 +155,13 @@ func (sc *StatsCollector) monitorDockerEvents(ctx context.Context, runID string,
|
||||
case event := <-events:
|
||||
if event.Type == "container" && event.Action == "start" {
|
||||
// Get container details
|
||||
containerInfo, err := sc.client.ContainerInspect(ctx, event.ID)
|
||||
containerInfo, err := sc.client.ContainerInspect(ctx, event.ID) //nolint:staticcheck // SA1019: use Actor.ID
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Convert to types.Container format for consistency
|
||||
cont := types.Container{
|
||||
cont := types.Container{ //nolint:staticcheck // SA1019: use container.Summary
|
||||
ID: containerInfo.ID,
|
||||
Names: []string{containerInfo.Name},
|
||||
Labels: containerInfo.Config.Labels,
|
||||
@@ -167,13 +175,14 @@ func (sc *StatsCollector) monitorDockerEvents(ctx context.Context, runID string,
|
||||
if verbose {
|
||||
log.Printf("Error in Docker events stream: %v", err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// shouldMonitorContainer determines if a container should be monitored.
|
||||
func (sc *StatsCollector) shouldMonitorContainer(cont types.Container, runID string) bool {
|
||||
func (sc *StatsCollector) shouldMonitorContainer(cont types.Container, runID string) bool { //nolint:staticcheck // SA1019: use container.Summary
|
||||
// Check if it has the correct run ID label
|
||||
if cont.Labels == nil || cont.Labels["hi.run-id"] != runID {
|
||||
return false
|
||||
@@ -213,6 +222,7 @@ func (sc *StatsCollector) startStatsForContainer(ctx context.Context, containerI
|
||||
}
|
||||
|
||||
sc.wg.Add(1)
|
||||
|
||||
go sc.collectStatsForContainer(ctx, containerID, verbose)
|
||||
}
|
||||
|
||||
@@ -226,12 +236,14 @@ func (sc *StatsCollector) collectStatsForContainer(ctx context.Context, containe
|
||||
if verbose {
|
||||
log.Printf("Failed to get stats stream for container %s: %v", containerID[:12], err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
defer statsResponse.Body.Close()
|
||||
|
||||
decoder := json.NewDecoder(statsResponse.Body)
|
||||
var prevStats *container.Stats
|
||||
|
||||
var prevStats *container.Stats //nolint:staticcheck // SA1019: use StatsResponse
|
||||
|
||||
for {
|
||||
select {
|
||||
@@ -240,12 +252,15 @@ func (sc *StatsCollector) collectStatsForContainer(ctx context.Context, containe
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
var stats container.Stats
|
||||
if err := decoder.Decode(&stats); err != nil {
|
||||
var stats container.Stats //nolint:staticcheck // SA1019: use StatsResponse
|
||||
|
||||
err := decoder.Decode(&stats)
|
||||
if err != nil {
|
||||
// EOF is expected when container stops or stream ends
|
||||
if err.Error() != "EOF" && verbose {
|
||||
log.Printf("Failed to decode stats for container %s: %v", containerID[:12], err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -261,8 +276,10 @@ func (sc *StatsCollector) collectStatsForContainer(ctx context.Context, containe
|
||||
// Store the sample (skip first sample since CPU calculation needs previous stats)
|
||||
if prevStats != nil {
|
||||
// Get container stats reference without holding the main mutex
|
||||
var containerStats *ContainerStats
|
||||
var exists bool
|
||||
var (
|
||||
containerStats *ContainerStats
|
||||
exists bool
|
||||
)
|
||||
|
||||
sc.mutex.RLock()
|
||||
containerStats, exists = sc.containers[containerID]
|
||||
@@ -286,7 +303,7 @@ func (sc *StatsCollector) collectStatsForContainer(ctx context.Context, containe
|
||||
}
|
||||
|
||||
// calculateCPUPercent calculates CPU usage percentage from Docker stats.
|
||||
func calculateCPUPercent(prevStats, stats *container.Stats) float64 {
|
||||
func calculateCPUPercent(prevStats, stats *container.Stats) float64 { //nolint:staticcheck // SA1019: use StatsResponse
|
||||
// CPU calculation based on Docker's implementation
|
||||
cpuDelta := float64(stats.CPUStats.CPUUsage.TotalUsage) - float64(prevStats.CPUStats.CPUUsage.TotalUsage)
|
||||
systemDelta := float64(stats.CPUStats.SystemUsage) - float64(prevStats.CPUStats.SystemUsage)
|
||||
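`calculateCPUPercent` above follows the usual Docker stats arithmetic: the container's CPU-time delta divided by the host's CPU-time delta, scaled by the number of CPUs. A standalone sketch of that calculation with plain numbers; the struct here is a simplified stand-in for the fields read from Docker's stats payload, not the SDK type:

```go
package main

import "fmt"

// cpuSample is a simplified stand-in for the fields used from Docker's stats payload.
type cpuSample struct {
	TotalUsage  uint64  // cumulative container CPU time, nanoseconds
	SystemUsage uint64  // cumulative host CPU time, nanoseconds
	OnlineCPUs  float64 // number of CPUs the container can use
}

func cpuPercent(prev, cur cpuSample) float64 {
	cpuDelta := float64(cur.TotalUsage) - float64(prev.TotalUsage)
	systemDelta := float64(cur.SystemUsage) - float64(prev.SystemUsage)

	if systemDelta <= 0 || cpuDelta < 0 {
		return 0
	}

	return (cpuDelta / systemDelta) * cur.OnlineCPUs * 100.0
}

func main() {
	prev := cpuSample{TotalUsage: 1_000_000_000, SystemUsage: 40_000_000_000, OnlineCPUs: 4}
	cur := cpuSample{TotalUsage: 1_500_000_000, SystemUsage: 44_000_000_000, OnlineCPUs: 4}

	fmt.Printf("%.1f%%\n", cpuPercent(prev, cur)) // 0.5s of CPU over 4s across 4 CPUs = 50.0%
}
```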
@@ -331,10 +348,12 @@ type StatsSummary struct {
|
||||
func (sc *StatsCollector) GetSummary() []ContainerStatsSummary {
|
||||
// Take snapshot of container references without holding main lock long
|
||||
sc.mutex.RLock()
|
||||
|
||||
containerRefs := make([]*ContainerStats, 0, len(sc.containers))
|
||||
for _, containerStats := range sc.containers {
|
||||
containerRefs = append(containerRefs, containerStats)
|
||||
}
|
||||
|
||||
sc.mutex.RUnlock()
|
||||
|
||||
summaries := make([]ContainerStatsSummary, 0, len(containerRefs))
|
||||
@@ -384,23 +403,25 @@ func calculateStatsSummary(values []float64) StatsSummary {
|
||||
return StatsSummary{}
|
||||
}
|
||||
|
||||
min := values[0]
|
||||
max := values[0]
|
||||
minVal := values[0]
|
||||
maxVal := values[0]
|
||||
sum := 0.0
|
||||
|
||||
for _, value := range values {
|
||||
if value < min {
|
||||
min = value
|
||||
if value < minVal {
|
||||
minVal = value
|
||||
}
|
||||
if value > max {
|
||||
max = value
|
||||
|
||||
if value > maxVal {
|
||||
maxVal = value
|
||||
}
|
||||
|
||||
sum += value
|
||||
}
|
||||
|
||||
return StatsSummary{
|
||||
Min: min,
|
||||
Max: max,
|
||||
Min: minVal,
|
||||
Max: maxVal,
|
||||
Average: sum / float64(len(values)),
|
||||
}
|
||||
}
|
||||
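The rename from `min`/`max` to `minVal`/`maxVal` above avoids shadowing the built-in `min` and `max` functions that Go has shipped since 1.21. A compact sketch of the same min/max/average pass over a slice, using the builtins directly:

```go
package main

import "fmt"

type summary struct {
	Min, Max, Average float64
}

func summarize(values []float64) summary {
	if len(values) == 0 {
		return summary{}
	}

	minVal, maxVal, sum := values[0], values[0], 0.0
	for _, v := range values {
		minVal = min(minVal, v) // built-in min/max, Go 1.21+
		maxVal = max(maxVal, v)
		sum += v
	}

	return summary{Min: minVal, Max: maxVal, Average: sum / float64(len(values))}
}

func main() {
	fmt.Printf("%+v\n", summarize([]float64{3, 1, 4, 1, 5}))
}
```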
@@ -434,6 +455,7 @@ func (sc *StatsCollector) CheckMemoryLimits(hsLimitMB, tsLimitMB float64) []Memo
|
||||
}
|
||||
|
||||
summaries := sc.GetSummary()
|
||||
|
||||
var violations []MemoryViolation
|
||||
|
||||
for _, summary := range summaries {
|
||||
|
||||
@@ -2,6 +2,7 @@ package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
@@ -15,7 +16,10 @@ type MapConfig struct {
|
||||
Directory string `flag:"directory,Directory to read map responses from"`
|
||||
}
|
||||
|
||||
var mapConfig MapConfig
|
||||
var (
|
||||
mapConfig MapConfig
|
||||
errDirectoryRequired = errors.New("directory is required")
|
||||
)
|
||||
|
||||
func main() {
|
||||
root := command.C{
|
||||
@@ -40,7 +44,7 @@ func main() {
|
||||
// runIntegrationTest executes the integration test workflow.
|
||||
func runOnline(env *command.Env) error {
|
||||
if mapConfig.Directory == "" {
|
||||
return fmt.Errorf("directory is required")
|
||||
return errDirectoryRequired
|
||||
}
|
||||
|
||||
resps, err := mapper.ReadMapResponsesFromDirectory(mapConfig.Directory)
|
||||
@@ -57,5 +61,6 @@ func runOnline(env *command.Env) error {
|
||||
|
||||
os.Stderr.Write(out)
|
||||
os.Stderr.Write([]byte("\n"))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
flake.nix — 54 changed lines
@@ -26,8 +26,8 @@
       overlays.default = _: prev:
         let
           pkgs = nixpkgs.legacyPackages.${prev.stdenv.hostPlatform.system};
-          buildGo = pkgs.buildGo125Module;
-          vendorHash = "sha256-jkeB9XUTEGt58fPOMpE4/e3+JQoMQTgf0RlthVBmfG0=";
+          buildGo = pkgs.buildGo126Module;
+          vendorHash = "sha256-9BvphYDAxzwooyVokI3l+q1wRuRsWn/qM+NpWUgqJH0=";
         in
         {
           headscale = buildGo {
@@ -94,14 +94,46 @@
|
||||
subPackages = [ "." ];
|
||||
};
|
||||
|
||||
# Upstream does not override buildGoModule properly,
|
||||
# importing a specific module, so comment out for now.
|
||||
# golangci-lint = prev.golangci-lint.override {
|
||||
# buildGoModule = buildGo;
|
||||
# };
|
||||
# golangci-lint-langserver = prev.golangci-lint.override {
|
||||
# buildGoModule = buildGo;
|
||||
# };
|
||||
# Build golangci-lint with Go 1.26 (upstream uses hardcoded Go version)
|
||||
golangci-lint = buildGo rec {
|
||||
pname = "golangci-lint";
|
||||
version = "2.8.0";
|
||||
|
||||
src = pkgs.fetchFromGitHub {
|
||||
owner = "golangci";
|
||||
repo = "golangci-lint";
|
||||
rev = "v${version}";
|
||||
hash = "sha256-w6MAOirj8rPHYbKrW4gJeemXCS64fNtteV6IioqIQTQ=";
|
||||
};
|
||||
|
||||
vendorHash = "sha256-/Vqo/yrmGh6XipELQ9NDtlMEO2a654XykmvnMs0BdrI=";
|
||||
|
||||
subPackages = [ "cmd/golangci-lint" ];
|
||||
|
||||
nativeBuildInputs = [ pkgs.installShellFiles ];
|
||||
|
||||
ldflags = [
|
||||
"-s"
|
||||
"-w"
|
||||
"-X main.version=${version}"
|
||||
"-X main.commit=v${version}"
|
||||
"-X main.date=1970-01-01T00:00:00Z"
|
||||
];
|
||||
|
||||
postInstall = ''
|
||||
for shell in bash zsh fish; do
|
||||
HOME=$TMPDIR $out/bin/golangci-lint completion $shell > golangci-lint.$shell
|
||||
installShellCompletion golangci-lint.$shell
|
||||
done
|
||||
'';
|
||||
|
||||
meta = {
|
||||
description = "Fast linters runner for Go";
|
||||
homepage = "https://golangci-lint.run/";
|
||||
changelog = "https://github.com/golangci/golangci-lint/blob/v${version}/CHANGELOG.md";
|
||||
mainProgram = "golangci-lint";
|
||||
};
|
||||
};
|
||||
|
||||
# The package uses buildGo125Module, not the convention.
|
||||
# goreleaser = prev.goreleaser.override {
|
||||
@@ -132,7 +164,7 @@
             overlays = [ self.overlays.default ];
             inherit system;
           };
-          buildDeps = with pkgs; [ git go_1_25 gnumake ];
+          buildDeps = with pkgs; [ git go_1_26 gnumake ];
           devDeps = with pkgs;
             buildDeps
             ++ [

go.mod — 2 changed lines
@@ -1,6 +1,6 @@
 module github.com/juanfont/headscale

-go 1.25.5
+go 1.26rc2

 require (
 	github.com/arl/statsviz v0.8.0

hscontrol/app.go — 137 changed lines
@@ -115,13 +115,14 @@ var (
|
||||
|
||||
func NewHeadscale(cfg *types.Config) (*Headscale, error) {
|
||||
var err error
|
||||
|
||||
if profilingEnabled {
|
||||
runtime.SetBlockProfileRate(1)
|
||||
}
|
||||
|
||||
noisePrivateKey, err := readOrCreatePrivateKey(cfg.NoisePrivateKeyPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read or create Noise protocol private key: %w", err)
|
||||
return nil, fmt.Errorf("reading or creating Noise protocol private key: %w", err)
|
||||
}
|
||||
|
||||
s, err := state.NewState(cfg)
|
||||
@@ -140,27 +141,30 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
|
||||
ephemeralGC := db.NewEphemeralGarbageCollector(func(ni types.NodeID) {
|
||||
node, ok := app.state.GetNodeByID(ni)
|
||||
if !ok {
|
||||
log.Error().Uint64("node.id", ni.Uint64()).Msg("Ephemeral node deletion failed")
|
||||
log.Debug().Caller().Uint64("node.id", ni.Uint64()).Msg("Ephemeral node deletion failed because node not found in NodeStore")
|
||||
log.Error().Uint64("node.id", ni.Uint64()).Msg("ephemeral node deletion failed")
|
||||
log.Debug().Caller().Uint64("node.id", ni.Uint64()).Msg("ephemeral node deletion failed because node not found in NodeStore")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
policyChanged, err := app.state.DeleteNode(node)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Uint64("node.id", ni.Uint64()).Str("node.name", node.Hostname()).Msg("Ephemeral node deletion failed")
|
||||
log.Error().Err(err).EmbedObject(node).Msg("ephemeral node deletion failed")
|
||||
return
|
||||
}
|
||||
|
||||
app.Change(policyChanged)
|
||||
log.Debug().Caller().Uint64("node.id", ni.Uint64()).Str("node.name", node.Hostname()).Msg("Ephemeral node deleted because garbage collection timeout reached")
|
||||
log.Debug().Caller().EmbedObject(node).Msg("ephemeral node deleted because garbage collection timeout reached")
|
||||
})
|
||||
app.ephemeralGC = ephemeralGC
|
||||
|
||||
var authProvider AuthProvider
|
||||
|
||||
authProvider = NewAuthProviderWeb(cfg.ServerURL)
|
||||
if cfg.OIDC.Issuer != "" {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
oidcProvider, err := NewAuthProviderOIDC(
|
||||
ctx,
|
||||
&app,
|
||||
@@ -177,17 +181,18 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
|
||||
authProvider = oidcProvider
|
||||
}
|
||||
}
|
||||
|
||||
app.authProvider = authProvider
|
||||
|
||||
if app.cfg.TailcfgDNSConfig != nil && app.cfg.TailcfgDNSConfig.Proxied { // if MagicDNS
|
||||
// TODO(kradalby): revisit why this takes a list.
|
||||
|
||||
var magicDNSDomains []dnsname.FQDN
|
||||
if cfg.PrefixV4 != nil {
|
||||
magicDNSDomains = append(
|
||||
magicDNSDomains,
|
||||
util.GenerateIPv4DNSRootDomain(*cfg.PrefixV4)...)
|
||||
}
|
||||
|
||||
if cfg.PrefixV6 != nil {
|
||||
magicDNSDomains = append(
|
||||
magicDNSDomains,
|
||||
@@ -198,6 +203,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
|
||||
if app.cfg.TailcfgDNSConfig.Routes == nil {
|
||||
app.cfg.TailcfgDNSConfig.Routes = make(map[string][]*dnstype.Resolver)
|
||||
}
|
||||
|
||||
for _, d := range magicDNSDomains {
|
||||
app.cfg.TailcfgDNSConfig.Routes[d.WithoutTrailingDot()] = nil
|
||||
}
|
||||
@@ -206,7 +212,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
|
||||
if cfg.DERP.ServerEnabled {
|
||||
derpServerKey, err := readOrCreatePrivateKey(cfg.DERP.ServerPrivateKeyPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read or create DERP server private key: %w", err)
|
||||
return nil, fmt.Errorf("reading or creating DERP server private key: %w", err)
|
||||
}
|
||||
|
||||
if derpServerKey.Equal(*noisePrivateKey) {
|
||||
@@ -232,6 +238,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
app.DERPServer = embeddedDERPServer
|
||||
}
|
||||
|
||||
@@ -251,9 +258,11 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
|
||||
lastExpiryCheck := time.Unix(0, 0)
|
||||
|
||||
derpTickerChan := make(<-chan time.Time)
|
||||
|
||||
if h.cfg.DERP.AutoUpdate && h.cfg.DERP.UpdateFrequency != 0 {
|
||||
derpTicker := time.NewTicker(h.cfg.DERP.UpdateFrequency)
|
||||
defer derpTicker.Stop()
|
||||
|
||||
derpTickerChan = derpTicker.C
|
||||
}
|
||||
|
||||
@@ -271,8 +280,10 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
|
||||
return
|
||||
|
||||
case <-expireTicker.C:
|
||||
var expiredNodeChanges []change.Change
|
||||
var changed bool
|
||||
var (
|
||||
expiredNodeChanges []change.Change
|
||||
changed bool
|
||||
)
|
||||
|
||||
lastExpiryCheck, expiredNodeChanges, changed = h.state.ExpireExpiredNodes(lastExpiryCheck)
|
||||
|
||||
@@ -286,12 +297,14 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
|
||||
}
|
||||
|
||||
case <-derpTickerChan:
|
||||
log.Info().Msg("Fetching DERPMap updates")
|
||||
derpMap, err := backoff.Retry(ctx, func() (*tailcfg.DERPMap, error) {
|
||||
log.Info().Msg("fetching DERPMap updates")
|
||||
|
||||
derpMap, err := backoff.Retry(ctx, func() (*tailcfg.DERPMap, error) { //nolint:contextcheck
|
||||
derpMap, err := derp.GetDERPMap(h.cfg.DERP)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if h.cfg.DERP.ServerEnabled && h.cfg.DERP.AutomaticallyAddEmbeddedDerpRegion {
|
||||
region, _ := h.DERPServer.GenerateRegion()
|
||||
derpMap.Regions[region.RegionID] = ®ion
|
||||
@@ -303,6 +316,7 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
|
||||
log.Error().Err(err).Msg("failed to build new DERPMap, retrying later")
|
||||
continue
|
||||
}
|
||||
|
||||
h.state.SetDERPMap(derpMap)
|
||||
|
||||
h.Change(change.DERPMap())
|
||||
@@ -311,6 +325,7 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
h.cfg.TailcfgDNSConfig.ExtraRecords = records
|
||||
|
||||
h.Change(change.ExtraRecords())
|
||||
@@ -339,7 +354,7 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
|
||||
if !ok {
|
||||
return ctx, status.Errorf(
|
||||
codes.InvalidArgument,
|
||||
"Retrieving metadata is failed",
|
||||
"retrieving metadata",
|
||||
)
|
||||
}
|
||||
|
||||
@@ -347,7 +362,7 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
|
||||
if !ok {
|
||||
return ctx, status.Errorf(
|
||||
codes.Unauthenticated,
|
||||
"Authorization token is not supplied",
|
||||
"authorization token not supplied",
|
||||
)
|
||||
}
|
||||
|
||||
@@ -362,7 +377,7 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
|
||||
|
||||
valid, err := h.state.ValidateAPIKey(strings.TrimPrefix(token, AuthPrefix))
|
||||
if err != nil {
|
||||
return ctx, status.Error(codes.Internal, "failed to validate token")
|
||||
return ctx, status.Error(codes.Internal, "validating token")
|
||||
}
|
||||
|
||||
if !valid {
|
||||
@@ -390,7 +405,8 @@ func (h *Headscale) httpAuthenticationMiddleware(next http.Handler) http.Handler
|
||||
|
||||
writeUnauthorized := func(statusCode int) {
|
||||
writer.WriteHeader(statusCode)
|
||||
if _, err := writer.Write([]byte("Unauthorized")); err != nil {
|
||||
|
||||
if _, err := writer.Write([]byte("Unauthorized")); err != nil { //nolint:noinlineerr
|
||||
log.Error().Err(err).Msg("writing HTTP response failed")
|
||||
}
|
||||
}
|
||||
@@ -401,6 +417,7 @@ func (h *Headscale) httpAuthenticationMiddleware(next http.Handler) http.Handler
|
||||
Str("client_address", req.RemoteAddr).
|
||||
Msg(`missing "Bearer " prefix in "Authorization" header`)
|
||||
writeUnauthorized(http.StatusUnauthorized)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -412,6 +429,7 @@ func (h *Headscale) httpAuthenticationMiddleware(next http.Handler) http.Handler
|
||||
Str("client_address", req.RemoteAddr).
|
||||
Msg("failed to validate token")
|
||||
writeUnauthorized(http.StatusUnauthorized)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -420,6 +438,7 @@ func (h *Headscale) httpAuthenticationMiddleware(next http.Handler) http.Handler
|
||||
Str("client_address", req.RemoteAddr).
|
||||
Msg("invalid token")
|
||||
writeUnauthorized(http.StatusUnauthorized)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -431,7 +450,7 @@ func (h *Headscale) httpAuthenticationMiddleware(next http.Handler) http.Handler
|
||||
// and will remove it if it is not.
|
||||
func (h *Headscale) ensureUnixSocketIsAbsent() error {
|
||||
// File does not exist, all fine
|
||||
if _, err := os.Stat(h.cfg.UnixSocket); errors.Is(err, os.ErrNotExist) {
|
||||
if _, err := os.Stat(h.cfg.UnixSocket); errors.Is(err, os.ErrNotExist) { //nolint:noinlineerr
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -455,6 +474,7 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {
|
||||
if provider, ok := h.authProvider.(*AuthProviderOIDC); ok {
|
||||
router.HandleFunc("/oidc/callback", provider.OIDCCallbackHandler).Methods(http.MethodGet)
|
||||
}
|
||||
|
||||
router.HandleFunc("/apple", h.AppleConfigMessage).Methods(http.MethodGet)
|
||||
router.HandleFunc("/apple/{platform}", h.ApplePlatformConfig).
|
||||
Methods(http.MethodGet)
|
||||
@@ -484,8 +504,11 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {
|
||||
}
|
||||
|
||||
// Serve launches the HTTP and gRPC server service Headscale and the API.
|
||||
//
|
||||
//nolint:gocyclo // complex server startup function
|
||||
func (h *Headscale) Serve() error {
|
||||
var err error
|
||||
|
||||
capver.CanOldCodeBeCleanedUp()
|
||||
|
||||
if profilingEnabled {
|
||||
@@ -506,12 +529,13 @@ func (h *Headscale) Serve() error {
|
||||
}
|
||||
|
||||
versionInfo := types.GetVersionInfo()
|
||||
log.Info().Str("version", versionInfo.Version).Str("commit", versionInfo.Commit).Msg("Starting Headscale")
|
||||
log.Info().Str("version", versionInfo.Version).Str("commit", versionInfo.Commit).Msg("starting headscale")
|
||||
log.Info().
|
||||
Str("minimum_version", capver.TailscaleVersion(capver.MinSupportedCapabilityVersion)).
|
||||
Msg("Clients with a lower minimum version will be rejected")
|
||||
|
||||
h.mapBatcher = mapper.NewBatcherAndMapper(h.cfg, h.state)
|
||||
|
||||
h.mapBatcher.Start()
|
||||
defer h.mapBatcher.Close()
|
||||
|
||||
@@ -526,7 +550,7 @@ func (h *Headscale) Serve() error {
|
||||
|
||||
derpMap, err := derp.GetDERPMap(h.cfg.DERP)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get DERPMap: %w", err)
|
||||
return fmt.Errorf("getting DERPMap: %w", err)
|
||||
}
|
||||
|
||||
if h.cfg.DERP.ServerEnabled && h.cfg.DERP.AutomaticallyAddEmbeddedDerpRegion {
|
||||
@@ -545,6 +569,7 @@ func (h *Headscale) Serve() error {
|
||||
// around between restarts, they will reconnect and the GC will
|
||||
// be cancelled.
|
||||
go h.ephemeralGC.Start()
|
||||
|
||||
ephmNodes := h.state.ListEphemeralNodes()
|
||||
for _, node := range ephmNodes.All() {
|
||||
h.ephemeralGC.Schedule(node.ID(), h.cfg.EphemeralNodeInactivityTimeout)
|
||||
@@ -555,7 +580,9 @@ func (h *Headscale) Serve() error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("setting up extrarecord manager: %w", err)
|
||||
}
|
||||
|
||||
h.cfg.TailcfgDNSConfig.ExtraRecords = h.extraRecordMan.Records()
|
||||
|
||||
go h.extraRecordMan.Run()
|
||||
defer h.extraRecordMan.Close()
|
||||
}
|
||||
@@ -564,6 +591,7 @@ func (h *Headscale) Serve() error {
|
||||
// records updates
|
||||
scheduleCtx, scheduleCancel := context.WithCancel(context.Background())
|
||||
defer scheduleCancel()
|
||||
|
||||
go h.scheduledTasks(scheduleCtx)
|
||||
|
||||
if zl.GlobalLevel() == zl.TraceLevel {
|
||||
@@ -576,6 +604,7 @@ func (h *Headscale) Serve() error {
|
||||
errorGroup := new(errgroup.Group)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
@@ -586,29 +615,30 @@ func (h *Headscale) Serve() error {
|
||||
|
||||
err = h.ensureUnixSocketIsAbsent()
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to remove old socket file: %w", err)
|
||||
return fmt.Errorf("removing old socket file: %w", err)
|
||||
}
|
||||
|
||||
socketDir := filepath.Dir(h.cfg.UnixSocket)
|
||||
|
||||
err = util.EnsureDir(socketDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("setting up unix socket: %w", err)
|
||||
}
|
||||
|
||||
socketListener, err := net.Listen("unix", h.cfg.UnixSocket)
|
||||
socketListener, err := new(net.ListenConfig).Listen(context.Background(), "unix", h.cfg.UnixSocket)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set up gRPC socket: %w", err)
|
||||
return fmt.Errorf("setting up gRPC socket: %w", err)
|
||||
}
|
||||
|
||||
// Change socket permissions
|
||||
if err := os.Chmod(h.cfg.UnixSocket, h.cfg.UnixSocketPermission); err != nil {
|
||||
return fmt.Errorf("failed change permission of gRPC socket: %w", err)
|
||||
if err := os.Chmod(h.cfg.UnixSocket, h.cfg.UnixSocketPermission); err != nil { //nolint:noinlineerr
|
||||
return fmt.Errorf("changing gRPC socket permission: %w", err)
|
||||
}
|
||||
|
||||
grpcGatewayMux := grpcRuntime.NewServeMux()
|
||||
|
||||
// Make the grpc-gateway connect to grpc over socket
|
||||
grpcGatewayConn, err := grpc.Dial(
|
||||
grpcGatewayConn, err := grpc.Dial( //nolint:staticcheck // SA1019: deprecated but supported in 1.x
|
||||
h.cfg.UnixSocket,
|
||||
[]grpc.DialOption{
|
||||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
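Several listeners in the hunk above move from `net.Listen` to `(&net.ListenConfig{}).Listen`, so the listen call takes a context (and can later carry cancellation or a control hook). A minimal sketch of the context-aware form; the address is arbitrary:

```go
package main

import (
	"context"
	"fmt"
	"net"
)

func main() {
	ctx := context.Background()

	// Context-aware form used in the diff; behaves like net.Listen("tcp", addr)
	// but is configurable and cancellable via ListenConfig and ctx.
	ln, err := new(net.ListenConfig).Listen(ctx, "tcp", "127.0.0.1:0")
	if err != nil {
		fmt.Println("listen failed:", err)
		return
	}
	defer ln.Close()

	fmt.Println("listening on", ln.Addr())
}
```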
@@ -659,10 +689,13 @@ func (h *Headscale) Serve() error {
|
||||
// https://github.com/soheilhy/cmux/issues/68
|
||||
// https://github.com/soheilhy/cmux/issues/91
|
||||
|
||||
var grpcServer *grpc.Server
|
||||
var grpcListener net.Listener
|
||||
var (
|
||||
grpcServer *grpc.Server
|
||||
grpcListener net.Listener
|
||||
)
|
||||
|
||||
if tlsConfig != nil || h.cfg.GRPCAllowInsecure {
|
||||
log.Info().Msgf("Enabling remote gRPC at %s", h.cfg.GRPCAddr)
|
||||
log.Info().Msgf("enabling remote gRPC at %s", h.cfg.GRPCAddr)
|
||||
|
||||
grpcOptions := []grpc.ServerOption{
|
||||
grpc.ChainUnaryInterceptor(
|
||||
@@ -685,9 +718,9 @@ func (h *Headscale) Serve() error {
|
||||
v1.RegisterHeadscaleServiceServer(grpcServer, newHeadscaleV1APIServer(h))
|
||||
reflection.Register(grpcServer)
|
||||
|
||||
grpcListener, err = net.Listen("tcp", h.cfg.GRPCAddr)
|
||||
grpcListener, err = new(net.ListenConfig).Listen(context.Background(), "tcp", h.cfg.GRPCAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to bind to TCP address: %w", err)
|
||||
return fmt.Errorf("binding to TCP address: %w", err)
|
||||
}
|
||||
|
||||
errorGroup.Go(func() error { return grpcServer.Serve(grpcListener) })
|
||||
@@ -715,14 +748,16 @@ func (h *Headscale) Serve() error {
|
||||
}
|
||||
|
||||
var httpListener net.Listener
|
||||
|
||||
if tlsConfig != nil {
|
||||
httpServer.TLSConfig = tlsConfig
|
||||
httpListener, err = tls.Listen("tcp", h.cfg.Addr, tlsConfig)
|
||||
} else {
|
||||
httpListener, err = net.Listen("tcp", h.cfg.Addr)
|
||||
httpListener, err = new(net.ListenConfig).Listen(context.Background(), "tcp", h.cfg.Addr)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to bind to TCP address: %w", err)
|
||||
return fmt.Errorf("binding to TCP address: %w", err)
|
||||
}
|
||||
|
||||
errorGroup.Go(func() error { return httpServer.Serve(httpListener) })
|
||||
@@ -738,7 +773,7 @@ func (h *Headscale) Serve() error {
|
||||
if h.cfg.MetricsAddr != "" {
|
||||
debugHTTPListener, err = (&net.ListenConfig{}).Listen(ctx, "tcp", h.cfg.MetricsAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to bind to TCP address: %w", err)
|
||||
return fmt.Errorf("binding to TCP address: %w", err)
|
||||
}
|
||||
|
||||
debugHTTPServer = h.debugHTTPServer()
|
||||
@@ -751,19 +786,24 @@ func (h *Headscale) Serve() error {
|
||||
log.Info().Msg("metrics server disabled (metrics_listen_addr is empty)")
|
||||
}
|
||||
|
||||
|
||||
var tailsqlContext context.Context
|
||||
|
||||
if tailsqlEnabled {
|
||||
if h.cfg.Database.Type != types.DatabaseSqlite {
|
||||
//nolint:gocritic // exitAfterDefer: Fatal exits during initialization before servers start
|
||||
log.Fatal().
|
||||
Str("type", h.cfg.Database.Type).
|
||||
Msgf("tailsql only support %q", types.DatabaseSqlite)
|
||||
}
|
||||
|
||||
if tailsqlTSKey == "" {
|
||||
//nolint:gocritic // exitAfterDefer: Fatal exits during initialization before servers start
|
||||
log.Fatal().Msg("tailsql requires TS_AUTHKEY to be set")
|
||||
}
|
||||
|
||||
tailsqlContext = context.Background()
|
||||
go runTailSQLService(ctx, util.TSLogfWrapper(), tailsqlStateDir, h.cfg.Database.Sqlite.Path)
|
||||
|
||||
go runTailSQLService(ctx, util.TSLogfWrapper(), tailsqlStateDir, h.cfg.Database.Sqlite.Path) //nolint:errcheck
|
||||
}
|
||||
|
||||
// Handle common process-killing signals so we can gracefully shut down:
|
||||
@@ -774,6 +814,7 @@ func (h *Headscale) Serve() error {
|
||||
syscall.SIGTERM,
|
||||
syscall.SIGQUIT,
|
||||
syscall.SIGHUP)
|
||||
|
||||
sigFunc := func(c chan os.Signal) {
|
||||
// Wait for a SIGINT or SIGKILL:
|
||||
for {
|
||||
@@ -798,6 +839,7 @@ func (h *Headscale) Serve() error {
|
||||
|
||||
default:
|
||||
info := func(msg string) { log.Info().Msg(msg) }
|
||||
|
||||
log.Info().
|
||||
Str("signal", sig.String()).
|
||||
Msg("Received signal to stop, shutting down gracefully")
|
||||
@@ -854,6 +896,7 @@ func (h *Headscale) Serve() error {
|
||||
if debugHTTPListener != nil {
|
||||
debugHTTPListener.Close()
|
||||
}
|
||||
|
||||
httpListener.Close()
|
||||
grpcGatewayConn.Close()
|
||||
|
||||
@@ -863,6 +906,7 @@ func (h *Headscale) Serve() error {
|
||||
|
||||
// Close state connections
|
||||
info("closing state and database")
|
||||
|
||||
err = h.state.Close()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("failed to close state")
|
||||
@@ -875,6 +919,7 @@ func (h *Headscale) Serve() error {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
errorGroup.Go(func() error {
|
||||
sigFunc(sigc)
|
||||
|
||||
@@ -886,6 +931,7 @@ func (h *Headscale) Serve() error {
|
||||
|
||||
func (h *Headscale) getTLSSettings() (*tls.Config, error) {
|
||||
var err error
|
||||
|
||||
if h.cfg.TLS.LetsEncrypt.Hostname != "" {
|
||||
if !strings.HasPrefix(h.cfg.ServerURL, "https://") {
|
||||
log.Warn().
|
||||
@@ -918,7 +964,6 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) {
|
||||
// Configuration via autocert with HTTP-01. This requires listening on
|
||||
// port 80 for the certificate validation in addition to the headscale
|
||||
// service, which can be configured to run on any other port.
|
||||
|
||||
server := &http.Server{
|
||||
Addr: h.cfg.TLS.LetsEncrypt.Listen,
|
||||
Handler: certManager.HTTPHandler(http.HandlerFunc(h.redirect)),
|
||||
@@ -940,13 +985,13 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) {
|
||||
}
|
||||
} else if h.cfg.TLS.CertPath == "" {
|
||||
if !strings.HasPrefix(h.cfg.ServerURL, "http://") {
|
||||
log.Warn().Msg("Listening without TLS but ServerURL does not start with http://")
|
||||
log.Warn().Msg("listening without TLS but ServerURL does not start with http://")
|
||||
}
|
||||
|
||||
return nil, err
|
||||
} else {
|
||||
if !strings.HasPrefix(h.cfg.ServerURL, "https://") {
|
||||
log.Warn().Msg("Listening with TLS but ServerURL does not start with https://")
|
||||
log.Warn().Msg("listening with TLS but ServerURL does not start with https://")
|
||||
}
|
||||
|
||||
tlsConfig := &tls.Config{
|
||||
@@ -963,6 +1008,7 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) {
|
||||
|
||||
func readOrCreatePrivateKey(path string) (*key.MachinePrivate, error) {
|
||||
dir := filepath.Dir(path)
|
||||
|
||||
err := util.EnsureDir(dir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ensuring private key directory: %w", err)
|
||||
@@ -970,21 +1016,22 @@ func readOrCreatePrivateKey(path string) (*key.MachinePrivate, error) {
|
||||
|
||||
privateKey, err := os.ReadFile(path)
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
log.Info().Str("path", path).Msg("No private key file at path, creating...")
|
||||
log.Info().Str("path", path).Msg("no private key file at path, creating...")
|
||||
|
||||
machineKey := key.NewMachine()
|
||||
|
||||
machineKeyStr, err := machineKey.MarshalText()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"failed to convert private key to string for saving: %w",
|
||||
"converting private key to string for saving: %w",
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
err = os.WriteFile(path, machineKeyStr, privateKeyFileMode)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"failed to save private key to disk at path %q: %w",
|
||||
"saving private key to disk at path %q: %w",
|
||||
path,
|
||||
err,
|
||||
)
|
||||
@@ -992,14 +1039,14 @@ func readOrCreatePrivateKey(path string) (*key.MachinePrivate, error) {
|
||||
|
||||
return &machineKey, nil
|
||||
} else if err != nil {
|
||||
return nil, fmt.Errorf("failed to read private key file: %w", err)
|
||||
return nil, fmt.Errorf("reading private key file: %w", err)
|
||||
}
|
||||
|
||||
trimmedPrivateKey := strings.TrimSpace(string(privateKey))
|
||||
|
||||
var machineKey key.MachinePrivate
|
||||
if err = machineKey.UnmarshalText([]byte(trimmedPrivateKey)); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse private key: %w", err)
|
||||
if err = machineKey.UnmarshalText([]byte(trimmedPrivateKey)); err != nil { //nolint:noinlineerr
|
||||
return nil, fmt.Errorf("parsing private key: %w", err)
|
||||
}
|
||||
|
||||
return &machineKey, nil
|
||||
@@ -1023,7 +1070,7 @@ type acmeLogger struct {
|
||||
func (l *acmeLogger) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
resp, err := l.rt.RoundTrip(req)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("url", req.URL.String()).Msg("ACME request failed")
|
||||
log.Error().Err(err).Str("url", req.URL.String()).Msg("acme request failed")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -1031,7 +1078,7 @@ func (l *acmeLogger) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
log.Error().Int("status_code", resp.StatusCode).Str("url", req.URL.String()).Bytes("body", body).Msg("ACME request returned error")
|
||||
log.Error().Int("status_code", resp.StatusCode).Str("url", req.URL.String()).Bytes("body", body).Msg("acme request returned error")
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
|
||||
@@ -16,12 +16,11 @@ import (
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/ptr"
|
||||
)
|
||||
|
||||
type AuthProvider interface {
|
||||
RegisterHandler(http.ResponseWriter, *http.Request)
|
||||
AuthURL(types.RegistrationID) string
|
||||
RegisterHandler(w http.ResponseWriter, r *http.Request)
|
||||
AuthURL(regID types.RegistrationID) string
|
||||
}
|
||||
|
||||
func (h *Headscale) handleRegister(
|
||||
@@ -42,8 +41,7 @@ func (h *Headscale) handleRegister(
|
||||
// This is a logout attempt (expiry in the past)
|
||||
if node, ok := h.state.GetNodeByNodeKey(req.NodeKey); ok {
|
||||
log.Debug().
|
||||
Uint64("node.id", node.ID().Uint64()).
|
||||
Str("node.name", node.Hostname()).
|
||||
EmbedObject(node).
|
||||
Bool("is_ephemeral", node.IsEphemeral()).
|
||||
Bool("has_authkey", node.AuthKey().Valid()).
|
||||
Msg("Found existing node for logout, calling handleLogout")
|
||||
@@ -52,6 +50,7 @@ func (h *Headscale) handleRegister(
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("handling logout: %w", err)
|
||||
}
|
||||
|
||||
if resp != nil {
|
||||
return resp, nil
|
||||
}
|
||||
@@ -113,8 +112,7 @@ func (h *Headscale) handleRegister(
|
||||
resp, err := h.handleRegisterWithAuthKey(req, machineKey)
|
||||
if err != nil {
|
||||
// Preserve HTTPError types so they can be handled properly by the HTTP layer
|
||||
var httpErr HTTPError
|
||||
if errors.As(err, &httpErr) {
|
||||
if httpErr, ok := errors.AsType[HTTPError](err); ok {
|
||||
return nil, httpErr
|
||||
}
|
||||
|
||||
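The hunk above swaps the two-step `errors.As` pattern for `errors.AsType`, a generic helper available in newer Go releases (which release introduces it is an assumption here, consistent with the go.mod bump in this comparison). For reference, the classic form the diff removes looks like this, with a hypothetical HTTPError type standing in for headscale's own:

```go
package main

import (
	"errors"
	"fmt"
)

// HTTPError is a hypothetical error type used only for this sketch.
type HTTPError struct {
	Code int
	Msg  string
}

func (e HTTPError) Error() string { return fmt.Sprintf("%d: %s", e.Code, e.Msg) }

func handle(err error) {
	// Classic two-step form: declare a target, then errors.As fills it on match.
	var httpErr HTTPError
	if errors.As(err, &httpErr) {
		fmt.Println("http error with status", httpErr.Code)
		return
	}

	fmt.Println("other error:", err)
}

func main() {
	handle(fmt.Errorf("wrapped: %w", HTTPError{Code: 401, Msg: "unauthorized"}))
}
```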
@@ -133,7 +131,7 @@ func (h *Headscale) handleRegister(
|
||||
}
|
||||
|
||||
// handleLogout checks if the [tailcfg.RegisterRequest] is a
|
||||
// logout attempt from a node. If the node is not attempting to
|
||||
// logout attempt from a node. If the node is not attempting to.
|
||||
func (h *Headscale) handleLogout(
|
||||
node types.NodeView,
|
||||
req tailcfg.RegisterRequest,
|
||||
@@ -155,11 +153,12 @@ func (h *Headscale) handleLogout(
|
||||
// force the client to re-authenticate.
|
||||
// TODO(kradalby): I wonder if this is a path we ever hit?
|
||||
if node.IsExpired() {
|
||||
log.Trace().Str("node.name", node.Hostname()).
|
||||
Uint64("node.id", node.ID().Uint64()).
|
||||
log.Trace().
|
||||
EmbedObject(node).
|
||||
Interface("reg.req", req).
|
||||
Bool("unexpected", true).
|
||||
Msg("Node key expired, forcing re-authentication")
|
||||
|
||||
return &tailcfg.RegisterResponse{
|
||||
NodeKeyExpired: true,
|
||||
MachineAuthorized: false,
|
||||
@@ -182,8 +181,7 @@ func (h *Headscale) handleLogout(
|
||||
// Zero expiry is handled in handleRegister() before calling this function.
|
||||
if req.Expiry.Before(time.Now()) {
|
||||
log.Debug().
|
||||
Uint64("node.id", node.ID().Uint64()).
|
||||
Str("node.name", node.Hostname()).
|
||||
EmbedObject(node).
|
||||
Bool("is_ephemeral", node.IsEphemeral()).
|
||||
Bool("has_authkey", node.AuthKey().Valid()).
|
||||
Time("req.expiry", req.Expiry).
|
||||
@@ -191,8 +189,7 @@ func (h *Headscale) handleLogout(
|
||||
|
||||
if node.IsEphemeral() {
|
||||
log.Info().
|
||||
Uint64("node.id", node.ID().Uint64()).
|
||||
Str("node.name", node.Hostname()).
|
||||
EmbedObject(node).
|
||||
Msg("Deleting ephemeral node during logout")
|
||||
|
||||
c, err := h.state.DeleteNode(node)
|
||||
@@ -209,8 +206,7 @@ func (h *Headscale) handleLogout(
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
Uint64("node.id", node.ID().Uint64()).
|
||||
Str("node.name", node.Hostname()).
|
||||
EmbedObject(node).
|
||||
Msg("Node is not ephemeral, setting expiry instead of deleting")
|
||||
}
|
||||
|
||||
@@ -279,6 +275,7 @@ func (h *Headscale) waitForFollowup(
|
||||
// registration is expired in the cache, instruct the client to try a new registration
|
||||
return h.reqToNewRegisterResponse(req, machineKey)
|
||||
}
|
||||
|
||||
return nodeToRegisterResponse(node.View()), nil
|
||||
}
|
||||
}
|
||||
@@ -316,7 +313,7 @@ func (h *Headscale) reqToNewRegisterResponse(
|
||||
MachineKey: machineKey,
|
||||
NodeKey: req.NodeKey,
|
||||
Hostinfo: hostinfo,
|
||||
LastSeen: ptr.To(time.Now()),
|
||||
LastSeen: new(time.Now()),
|
||||
},
|
||||
)
|
||||
|
||||
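The `LastSeen` change above drops `tailscale.com/types/ptr.To` in favour of `new(...)` with an expression argument, which this branch can rely on because go.mod now targets a Go 1.26 pre-release; on older toolchains the same thing is a one-line generic helper. A sketch of that helper, which is essentially what `ptr.To` provides:

```go
package main

import (
	"fmt"
	"time"
)

// ptrTo returns a pointer to a copy of v; equivalent to tailscale's ptr.To helper
// and to what new(expr) does on toolchains that accept an expression argument.
func ptrTo[T any](v T) *T {
	return &v
}

func main() {
	lastSeen := ptrTo(time.Now())
	fmt.Println(lastSeen.Format(time.RFC3339))
}
```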
@@ -324,7 +321,7 @@ func (h *Headscale) reqToNewRegisterResponse(
|
||||
nodeToRegister.Node.Expiry = &req.Expiry
|
||||
}
|
||||
|
||||
log.Info().Msgf("New followup node registration using key: %s", newRegID)
|
||||
log.Info().Msgf("new followup node registration using key: %s", newRegID)
|
||||
h.state.SetRegistrationCacheEntry(newRegID, nodeToRegister)
|
||||
|
||||
return &tailcfg.RegisterResponse{
|
||||
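These hunks, and several later ones, swap `ptr.To(x)` from tailscale.com/types/ptr for a plain `new(x)` expression, which is why the ptr import disappears further down in the test files. A minimal sketch of the equivalence: the generic helper below mirrors what `ptr.To` does and compiles on any recent Go, while the `new(expr)` spelling assumes a toolchain that accepts the extended expression form used in this changeset.

```go
package main

import (
	"fmt"
	"time"
)

// To returns a pointer to a copy of v, mirroring tailscale.com/types/ptr.To.
func To[T any](v T) *T { return &v }

func main() {
	// Helper form: works on any Go version with generics.
	lastSeen := To(time.Now())

	// &time.Now() would not compile, since & needs an addressable operand;
	// that gap is what both ptr.To and the new(expr) spelling in the diff cover.
	fmt.Println(lastSeen.IsZero())
}
```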
@@ -344,8 +341,8 @@ func (h *Headscale) handleRegisterWithAuthKey(
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, NewHTTPError(http.StatusUnauthorized, "invalid pre auth key", nil)
|
||||
}
|
||||
var perr types.PAKError
|
||||
if errors.As(err, &perr) {
|
||||
|
||||
if perr, ok := errors.AsType[types.PAKError](err); ok {
|
||||
return nil, NewHTTPError(http.StatusUnauthorized, perr.Error(), nil)
|
||||
}
|
||||
|
||||
@@ -355,7 +352,7 @@ func (h *Headscale) handleRegisterWithAuthKey(
|
||||
// If node is not valid, it means an ephemeral node was deleted during logout
|
||||
if !node.Valid() {
|
||||
h.Change(changed)
|
||||
return nil, nil
|
||||
return nil, nil //nolint:nilnil // intentional: no node to return when ephemeral deleted
|
||||
}
|
||||
|
||||
// This is a bit of a back and forth, but we have a bit of a chicken and egg
|
||||
@@ -397,8 +394,7 @@ func (h *Headscale) handleRegisterWithAuthKey(
|
||||
Caller().
|
||||
Interface("reg.resp", resp).
|
||||
Interface("reg.req", req).
|
||||
Str("node.name", node.Hostname()).
|
||||
Uint64("node.id", node.ID().Uint64()).
|
||||
EmbedObject(node).
|
||||
Msg("RegisterResponse")
|
||||
|
||||
return resp, nil
|
||||
@@ -435,6 +431,7 @@ func (h *Headscale) handleRegisterInteractive(
|
||||
Str("generated.hostname", hostname).
|
||||
Msg("Received registration request with empty hostname, generated default")
|
||||
}
|
||||
|
||||
hostinfo.Hostname = hostname
|
||||
|
||||
nodeToRegister := types.NewRegisterNode(
|
||||
@@ -443,7 +440,7 @@ func (h *Headscale) handleRegisterInteractive(
|
||||
MachineKey: machineKey,
|
||||
NodeKey: req.NodeKey,
|
||||
Hostinfo: hostinfo,
|
||||
LastSeen: ptr.To(time.Now()),
|
||||
LastSeen: new(time.Now()),
|
||||
},
|
||||
)
|
||||
|
||||
@@ -456,7 +453,7 @@ func (h *Headscale) handleRegisterInteractive(
|
||||
nodeToRegister,
|
||||
)
|
||||
|
||||
log.Info().Msgf("Starting node registration using key: %s", registrationId)
|
||||
log.Info().Msgf("starting node registration using key: %s", registrationId)
|
||||
|
||||
return &tailcfg.RegisterResponse{
|
||||
AuthURL: h.authProvider.AuthURL(registrationId),
|
||||
|
||||
File diff suppressed because it is too large
@@ -40,6 +40,7 @@ var tailscaleToCapVer = map[string]tailcfg.CapabilityVersion{
"v1.88": 125,
"v1.90": 130,
"v1.92": 131,
"v1.94": 131,
}

var capVerToTailscaleVer = map[tailcfg.CapabilityVersion]string{

@@ -9,10 +9,9 @@ var tailscaleLatestMajorMinorTests = []struct {
stripV bool
expected []string
}{
{3, false, []string{"v1.88", "v1.90", "v1.92"}},
{2, true, []string{"1.90", "1.92"}},
{3, false, []string{"v1.90", "v1.92", "v1.94"}},
{2, true, []string{"1.92", "1.94"}},
{10, true, []string{
"1.74",
"1.76",
"1.78",
"1.80",
@@ -22,6 +21,7 @@ var tailscaleLatestMajorMinorTests = []struct {
"1.88",
"1.90",
"1.92",
"1.94",
}},
{0, false, nil},
}

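The updated test table above bumps the expected versions to include v1.94. As a rough illustration of what a helper exercised by this table presumably does (the function name, signature, and the hard-coded version list below are assumptions for the sketch, not headscale's actual implementation):

```go
package main

import (
	"fmt"
	"strings"
)

// latestMajorMinor returns the last n entries of a sorted version list,
// optionally stripping the leading "v", mirroring the shape of the
// (count, stripV, expected) rows in the test table above.
func latestMajorMinor(sorted []string, n int, stripV bool) []string {
	if n <= 0 || len(sorted) == 0 {
		return nil
	}
	if n > len(sorted) {
		n = len(sorted)
	}
	out := make([]string, 0, n)
	for _, v := range sorted[len(sorted)-n:] {
		if stripV {
			v = strings.TrimPrefix(v, "v")
		}
		out = append(out, v)
	}
	return out
}

func main() {
	versions := []string{"v1.88", "v1.90", "v1.92", "v1.94"} // assumed sample data
	fmt.Println(latestMajorMinor(versions, 3, false))        // [v1.90 v1.92 v1.94]
	fmt.Println(latestMajorMinor(versions, 2, true))         // [1.92 1.94]
}
```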
@@ -77,8 +77,8 @@ func (hsdb *HSDatabase) CreateAPIKey(
Expiration: expiration,
}

if err := hsdb.DB.Save(&key).Error; err != nil {
return "", nil, fmt.Errorf("failed to save API key to database: %w", err)
if err := hsdb.DB.Save(&key).Error; err != nil { //nolint:noinlineerr
return "", nil, fmt.Errorf("saving API key to database: %w", err)
}

return keyStr, &key, nil
@@ -87,7 +87,9 @@ func (hsdb *HSDatabase) CreateAPIKey(
// ListAPIKeys returns the list of ApiKeys for a user.
func (hsdb *HSDatabase) ListAPIKeys() ([]types.APIKey, error) {
keys := []types.APIKey{}
if err := hsdb.DB.Find(&keys).Error; err != nil {

err := hsdb.DB.Find(&keys).Error
if err != nil {
return nil, err
}

@@ -126,7 +128,8 @@ func (hsdb *HSDatabase) DestroyAPIKey(key types.APIKey) error {

// ExpireAPIKey marks a ApiKey as expired.
func (hsdb *HSDatabase) ExpireAPIKey(key *types.APIKey) error {
if err := hsdb.DB.Model(&key).Update("Expiration", time.Now()).Error; err != nil {
err := hsdb.DB.Model(&key).Update("Expiration", time.Now()).Error
if err != nil {
return err
}

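Most hunks in this file, and in the database code below, follow the same mechanical pattern: the inline `if err := ...; err != nil` form is either split into a separate assignment plus check or annotated with `//nolint:noinlineerr`. A tiny before/after sketch of the two styles being distinguished (the `saveKey` helper is a placeholder, not a real headscale function):

```go
package main

import (
	"errors"
	"fmt"
)

// saveKey is a placeholder for a call like hsdb.DB.Save(&key).Error.
func saveKey() error { return errors.New("disk full") }

func inlineStyle() error {
	// Inline form: short, but the assignment is easy to miss when the
	// call chain gets long; this is what //nolint:noinlineerr waves through.
	if err := saveKey(); err != nil {
		return fmt.Errorf("saving API key to database: %w", err)
	}
	return nil
}

func splitStyle() error {
	// Split form preferred in this changeset: assignment and check on separate lines.
	err := saveKey()
	if err != nil {
		return fmt.Errorf("saving API key to database: %w", err)
	}
	return nil
}

func main() {
	fmt.Println(inlineStyle(), splitStyle())
}
```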
@@ -24,7 +24,6 @@ import (
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
"gorm.io/gorm/schema"
|
||||
"tailscale.com/net/tsaddr"
|
||||
"zgo.at/zcache/v2"
|
||||
)
|
||||
|
||||
@@ -53,6 +52,8 @@ type HSDatabase struct {
|
||||
|
||||
// NewHeadscaleDatabase creates a new database connection and runs migrations.
|
||||
// It accepts the full configuration to allow migrations access to policy settings.
|
||||
//
|
||||
//nolint:gocyclo // complex database initialization with many migrations
|
||||
func NewHeadscaleDatabase(
|
||||
cfg *types.Config,
|
||||
regCache *zcache.Cache[types.RegistrationID, types.RegisterNode],
|
||||
@@ -76,7 +77,7 @@ func NewHeadscaleDatabase(
|
||||
ID: "202501221827",
|
||||
Migrate: func(tx *gorm.DB) error {
|
||||
// Remove any invalid routes associated with a node that does not exist.
|
||||
if tx.Migrator().HasTable(&types.Route{}) && tx.Migrator().HasTable(&types.Node{}) {
|
||||
if tx.Migrator().HasTable(&types.Route{}) && tx.Migrator().HasTable(&types.Node{}) { //nolint:staticcheck // SA1019: Route kept for migrations
|
||||
err := tx.Exec("delete from routes where node_id not in (select id from nodes)").Error
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -84,14 +85,14 @@ func NewHeadscaleDatabase(
|
||||
}
|
||||
|
||||
// Remove any invalid routes without a node_id.
|
||||
if tx.Migrator().HasTable(&types.Route{}) {
|
||||
if tx.Migrator().HasTable(&types.Route{}) { //nolint:staticcheck // SA1019: Route kept for migrations
|
||||
err := tx.Exec("delete from routes where node_id is null").Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
err := tx.AutoMigrate(&types.Route{})
|
||||
err := tx.AutoMigrate(&types.Route{}) //nolint:staticcheck // SA1019: Route kept for migrations
|
||||
if err != nil {
|
||||
return fmt.Errorf("automigrating types.Route: %w", err)
|
||||
}
|
||||
@@ -109,6 +110,7 @@ func NewHeadscaleDatabase(
|
||||
if err != nil {
|
||||
return fmt.Errorf("automigrating types.PreAuthKey: %w", err)
|
||||
}
|
||||
|
||||
err = tx.AutoMigrate(&types.Node{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("automigrating types.Node: %w", err)
|
||||
@@ -155,7 +157,8 @@ AND auth_key_id NOT IN (
|
||||
|
||||
nodeRoutes := map[uint64][]netip.Prefix{}
|
||||
|
||||
var routes []types.Route
|
||||
var routes []types.Route //nolint:staticcheck // SA1019: Route kept for migrations
|
||||
|
||||
err = tx.Find(&routes).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("fetching routes: %w", err)
|
||||
@@ -168,10 +171,10 @@ AND auth_key_id NOT IN (
}

for nodeID, routes := range nodeRoutes {
tsaddr.SortPrefixes(routes)
slices.SortFunc(routes, netip.Prefix.Compare)
routes = slices.Compact(routes)

data, err := json.Marshal(routes)
data, _ := json.Marshal(routes)

err = tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("approved_routes", data).Error
if err != nil {
@@ -180,7 +183,7 @@ AND auth_key_id NOT IN (
}

// Drop the old table.
_ = tx.Migrator().DropTable(&types.Route{})
_ = tx.Migrator().DropTable(&types.Route{}) //nolint:staticcheck // SA1019: Route kept for migrations

return nil
},
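The migration above now sorts and de-duplicates the per-node prefixes with the standard library instead of tsaddr.SortPrefixes. A minimal, self-contained sketch of that pattern (the sample prefixes are invented):

```go
package main

import (
	"encoding/json"
	"fmt"
	"net/netip"
	"slices"
)

func main() {
	routes := []netip.Prefix{
		netip.MustParsePrefix("10.0.0.0/24"),
		netip.MustParsePrefix("0.0.0.0/0"),
		netip.MustParsePrefix("10.0.0.0/24"), // duplicate on purpose
	}

	// Sort with the method expression netip.Prefix.Compare, then drop
	// adjacent duplicates: the same two calls used in the migration.
	slices.SortFunc(routes, netip.Prefix.Compare)
	routes = slices.Compact(routes)

	data, err := json.Marshal(routes)
	if err != nil {
		panic(err)
	}
	fmt.Println(string(data)) // ["0.0.0.0/0","10.0.0.0/24"]
}
```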
@@ -245,21 +248,24 @@ AND auth_key_id NOT IN (
|
||||
Migrate: func(tx *gorm.DB) error {
|
||||
// Only run on SQLite
|
||||
if cfg.Database.Type != types.DatabaseSqlite {
|
||||
log.Info().Msg("Skipping schema migration on non-SQLite database")
|
||||
log.Info().Msg("skipping schema migration on non-SQLite database")
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Info().Msg("Starting schema recreation with table renaming")
|
||||
log.Info().Msg("starting schema recreation with table renaming")
|
||||
|
||||
// Rename existing tables to _old versions
|
||||
tablesToRename := []string{"users", "pre_auth_keys", "api_keys", "nodes", "policies"}
|
||||
|
||||
// Check if routes table exists and drop it (should have been migrated already)
|
||||
var routesExists bool
|
||||
|
||||
err := tx.Raw("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='routes'").Row().Scan(&routesExists)
|
||||
if err == nil && routesExists {
|
||||
log.Info().Msg("Dropping leftover routes table")
|
||||
if err := tx.Exec("DROP TABLE routes").Error; err != nil {
|
||||
log.Info().Msg("dropping leftover routes table")
|
||||
|
||||
err := tx.Exec("DROP TABLE routes").Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("dropping routes table: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -281,6 +287,7 @@ AND auth_key_id NOT IN (
|
||||
for _, table := range tablesToRename {
|
||||
// Check if table exists before renaming
|
||||
var exists bool
|
||||
|
||||
err := tx.Raw("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name=?", table).Row().Scan(&exists)
|
||||
if err != nil {
|
||||
return fmt.Errorf("checking if table %s exists: %w", table, err)
|
||||
@@ -291,7 +298,8 @@ AND auth_key_id NOT IN (
|
||||
_ = tx.Exec("DROP TABLE IF EXISTS " + table + "_old").Error
|
||||
|
||||
// Rename current table to _old
|
||||
if err := tx.Exec("ALTER TABLE " + table + " RENAME TO " + table + "_old").Error; err != nil {
|
||||
err := tx.Exec("ALTER TABLE " + table + " RENAME TO " + table + "_old").Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("renaming table %s to %s_old: %w", table, table, err)
|
||||
}
|
||||
}
|
||||
@@ -365,7 +373,8 @@ AND auth_key_id NOT IN (
|
||||
}
|
||||
|
||||
for _, createSQL := range tableCreationSQL {
|
||||
if err := tx.Exec(createSQL).Error; err != nil {
|
||||
err := tx.Exec(createSQL).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating new table: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -394,7 +403,8 @@ AND auth_key_id NOT IN (
|
||||
}
|
||||
|
||||
for _, copySQL := range dataCopySQL {
|
||||
if err := tx.Exec(copySQL).Error; err != nil {
|
||||
err := tx.Exec(copySQL).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("copying data: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -417,19 +427,21 @@ AND auth_key_id NOT IN (
|
||||
}
|
||||
|
||||
for _, indexSQL := range indexes {
|
||||
if err := tx.Exec(indexSQL).Error; err != nil {
|
||||
err := tx.Exec(indexSQL).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating index: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Drop old tables only after everything succeeds
|
||||
for _, table := range tablesToRename {
|
||||
if err := tx.Exec("DROP TABLE IF EXISTS " + table + "_old").Error; err != nil {
|
||||
log.Warn().Str("table", table+"_old").Err(err).Msg("Failed to drop old table, but migration succeeded")
|
||||
err := tx.Exec("DROP TABLE IF EXISTS " + table + "_old").Error
|
||||
if err != nil {
|
||||
log.Warn().Str("table", table+"_old").Err(err).Msg("failed to drop old table, but migration succeeded")
|
||||
}
|
||||
}
|
||||
|
||||
log.Info().Msg("Schema recreation completed successfully")
|
||||
log.Info().Msg("schema recreation completed successfully")
|
||||
|
||||
return nil
|
||||
},
|
||||
@@ -595,12 +607,12 @@ AND auth_key_id NOT IN (
|
||||
// 1. Load policy from file or database based on configuration
|
||||
policyData, err := PolicyBytes(tx, cfg)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to load policy, skipping RequestTags migration (tags will be validated on node reconnect)")
|
||||
log.Warn().Err(err).Msg("failed to load policy, skipping RequestTags migration (tags will be validated on node reconnect)")
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(policyData) == 0 {
|
||||
log.Info().Msg("No policy found, skipping RequestTags migration (tags will be validated on node reconnect)")
|
||||
log.Info().Msg("no policy found, skipping RequestTags migration (tags will be validated on node reconnect)")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -618,7 +630,7 @@ AND auth_key_id NOT IN (
|
||||
// 3. Create PolicyManager (handles HuJSON parsing, groups, nested tags, etc.)
|
||||
polMan, err := policy.NewPolicyManager(policyData, users, nodes.ViewSlice())
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to parse policy, skipping RequestTags migration (tags will be validated on node reconnect)")
|
||||
log.Warn().Err(err).Msg("failed to parse policy, skipping RequestTags migration (tags will be validated on node reconnect)")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -652,8 +664,7 @@ AND auth_key_id NOT IN (
|
||||
if len(validatedTags) == 0 {
|
||||
if len(rejectedTags) > 0 {
|
||||
log.Debug().
|
||||
Uint64("node.id", uint64(node.ID)).
|
||||
Str("node.name", node.Hostname).
|
||||
EmbedObject(node).
|
||||
Strs("rejected_tags", rejectedTags).
|
||||
Msg("RequestTags rejected during migration (not authorized)")
|
||||
}
|
||||
@@ -676,8 +687,7 @@ AND auth_key_id NOT IN (
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Uint64("node.id", uint64(node.ID)).
|
||||
Str("node.name", node.Hostname).
|
||||
EmbedObject(node).
|
||||
Strs("validated_tags", validatedTags).
|
||||
Strs("rejected_tags", rejectedTags).
|
||||
Strs("existing_tags", existingTags).
|
||||
@@ -762,6 +772,7 @@ AND auth_key_id NOT IN (
|
||||
|
||||
// or else it blocks...
|
||||
sqlConn.SetMaxIdleConns(maxIdleConns)
|
||||
|
||||
sqlConn.SetMaxOpenConns(maxOpenConns)
|
||||
defer sqlConn.SetMaxIdleConns(1)
|
||||
defer sqlConn.SetMaxOpenConns(1)
|
||||
@@ -779,7 +790,7 @@ AND auth_key_id NOT IN (
|
||||
},
|
||||
}
|
||||
|
||||
if err := squibble.Validate(ctx, sqlConn, dbSchema, &opts); err != nil {
|
||||
if err := squibble.Validate(ctx, sqlConn, dbSchema, &opts); err != nil { //nolint:noinlineerr
|
||||
return nil, fmt.Errorf("validating schema: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -805,6 +816,7 @@ func openDB(cfg types.DatabaseConfig) (*gorm.DB, error) {
|
||||
switch cfg.Type {
|
||||
case types.DatabaseSqlite:
|
||||
dir := filepath.Dir(cfg.Sqlite.Path)
|
||||
|
||||
err := util.EnsureDir(dir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating directory for sqlite: %w", err)
|
||||
@@ -858,7 +870,7 @@ func openDB(cfg types.DatabaseConfig) (*gorm.DB, error) {
|
||||
Str("path", dbString).
|
||||
Msg("Opening database")
|
||||
|
||||
if sslEnabled, err := strconv.ParseBool(cfg.Postgres.Ssl); err == nil {
|
||||
if sslEnabled, err := strconv.ParseBool(cfg.Postgres.Ssl); err == nil { //nolint:noinlineerr
|
||||
if !sslEnabled {
|
||||
dbString += " sslmode=disable"
|
||||
}
|
||||
@@ -913,7 +925,7 @@ func runMigrations(cfg types.DatabaseConfig, dbConn *gorm.DB, migrations *gormig
|
||||
|
||||
// Get the current foreign key status
|
||||
var fkOriginallyEnabled int
|
||||
if err := dbConn.Raw("PRAGMA foreign_keys").Scan(&fkOriginallyEnabled).Error; err != nil {
|
||||
if err := dbConn.Raw("PRAGMA foreign_keys").Scan(&fkOriginallyEnabled).Error; err != nil { //nolint:noinlineerr
|
||||
return fmt.Errorf("checking foreign key status: %w", err)
|
||||
}
|
||||
|
||||
@@ -937,33 +949,36 @@ func runMigrations(cfg types.DatabaseConfig, dbConn *gorm.DB, migrations *gormig
|
||||
}
|
||||
|
||||
for _, migrationID := range migrationIDs {
|
||||
log.Trace().Caller().Str("migration_id", migrationID).Msg("Running migration")
|
||||
log.Trace().Caller().Str("migration_id", migrationID).Msg("running migration")
|
||||
needsFKDisabled := migrationsRequiringFKDisabled[migrationID]
|
||||
|
||||
if needsFKDisabled {
|
||||
// Disable foreign keys for this migration
|
||||
if err := dbConn.Exec("PRAGMA foreign_keys = OFF").Error; err != nil {
|
||||
err := dbConn.Exec("PRAGMA foreign_keys = OFF").Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("disabling foreign keys for migration %s: %w", migrationID, err)
|
||||
}
|
||||
} else {
|
||||
// Ensure foreign keys are enabled for this migration
|
||||
if err := dbConn.Exec("PRAGMA foreign_keys = ON").Error; err != nil {
|
||||
err := dbConn.Exec("PRAGMA foreign_keys = ON").Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("enabling foreign keys for migration %s: %w", migrationID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Run up to this specific migration (will only run the next pending migration)
|
||||
if err := migrations.MigrateTo(migrationID); err != nil {
|
||||
err := migrations.MigrateTo(migrationID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("running migration %s: %w", migrationID, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := dbConn.Exec("PRAGMA foreign_keys = ON").Error; err != nil {
|
||||
if err := dbConn.Exec("PRAGMA foreign_keys = ON").Error; err != nil { //nolint:noinlineerr
|
||||
return fmt.Errorf("restoring foreign keys: %w", err)
|
||||
}
|
||||
|
||||
// Run the rest of the migrations
|
||||
if err := migrations.Migrate(); err != nil {
|
||||
if err := migrations.Migrate(); err != nil { //nolint:noinlineerr
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -981,16 +996,22 @@ func runMigrations(cfg types.DatabaseConfig, dbConn *gorm.DB, migrations *gormig
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var violation constraintViolation
|
||||
if err := rows.Scan(&violation.Table, &violation.RowID, &violation.Parent, &violation.ConstraintIndex); err != nil {
|
||||
|
||||
err := rows.Scan(&violation.Table, &violation.RowID, &violation.Parent, &violation.ConstraintIndex)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
violatedConstraints = append(violatedConstraints, violation)
|
||||
}
|
||||
_ = rows.Close()
|
||||
|
||||
if err := rows.Err(); err != nil { //nolint:noinlineerr
|
||||
return err
|
||||
}
|
||||
|
||||
if len(violatedConstraints) > 0 {
|
||||
for _, violation := range violatedConstraints {
|
||||
@@ -1005,7 +1026,8 @@ func runMigrations(cfg types.DatabaseConfig, dbConn *gorm.DB, migrations *gormig
|
||||
}
|
||||
} else {
|
||||
// PostgreSQL can run all migrations in one block - no foreign key issues
|
||||
if err := migrations.Migrate(); err != nil {
|
||||
err := migrations.Migrate()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -1016,6 +1038,7 @@ func runMigrations(cfg types.DatabaseConfig, dbConn *gorm.DB, migrations *gormig
|
||||
func (hsdb *HSDatabase) PingDB(ctx context.Context) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, time.Second)
|
||||
defer cancel()
|
||||
|
||||
sqlDB, err := hsdb.DB.DB()
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -1031,7 +1054,7 @@ func (hsdb *HSDatabase) Close() error {
|
||||
}
|
||||
|
||||
if hsdb.cfg.Database.Type == types.DatabaseSqlite && hsdb.cfg.Database.Sqlite.WriteAheadLog {
|
||||
db.Exec("VACUUM")
|
||||
db.Exec("VACUUM") //nolint:errcheck,noctx
|
||||
}
|
||||
|
||||
return db.Close()
|
||||
@@ -1040,12 +1063,14 @@ func (hsdb *HSDatabase) Close() error {
|
||||
func (hsdb *HSDatabase) Read(fn func(rx *gorm.DB) error) error {
|
||||
rx := hsdb.DB.Begin()
|
||||
defer rx.Rollback()
|
||||
|
||||
return fn(rx)
|
||||
}
|
||||
|
||||
func Read[T any](db *gorm.DB, fn func(rx *gorm.DB) (T, error)) (T, error) {
|
||||
rx := db.Begin()
|
||||
defer rx.Rollback()
|
||||
|
||||
ret, err := fn(rx)
|
||||
if err != nil {
|
||||
var no T
|
||||
@@ -1058,7 +1083,9 @@ func Read[T any](db *gorm.DB, fn func(rx *gorm.DB) (T, error)) (T, error) {
|
||||
func (hsdb *HSDatabase) Write(fn func(tx *gorm.DB) error) error {
|
||||
tx := hsdb.DB.Begin()
|
||||
defer tx.Rollback()
|
||||
if err := fn(tx); err != nil {
|
||||
|
||||
err := fn(tx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -1068,6 +1095,7 @@ func (hsdb *HSDatabase) Write(fn func(tx *gorm.DB) error) error {
|
||||
func Write[T any](db *gorm.DB, fn func(tx *gorm.DB) (T, error)) (T, error) {
|
||||
tx := db.Begin()
|
||||
defer tx.Rollback()
|
||||
|
||||
ret, err := fn(tx)
|
||||
if err != nil {
|
||||
var no T
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"os"
|
||||
"os/exec"
|
||||
@@ -44,6 +45,7 @@ func TestSQLiteMigrationAndDataValidation(t *testing.T) {
|
||||
|
||||
// Verify api_keys data preservation
|
||||
var apiKeyCount int
|
||||
|
||||
err = hsdb.DB.Raw("SELECT COUNT(*) FROM api_keys").Scan(&apiKeyCount).Error
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, apiKeyCount, "should preserve all 2 api_keys from original schema")
|
||||
@@ -176,7 +178,7 @@ func createSQLiteFromSQLFile(sqlFilePath, dbPath string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = db.Exec(string(schemaContent))
|
||||
_, err = db.ExecContext(context.Background(), string(schemaContent))
|
||||
|
||||
return err
|
||||
}
|
||||
@@ -186,6 +188,7 @@ func createSQLiteFromSQLFile(sqlFilePath, dbPath string) error {
|
||||
func requireConstraintFailed(t *testing.T, err error) {
|
||||
t.Helper()
|
||||
require.Error(t, err)
|
||||
|
||||
if !strings.Contains(err.Error(), "UNIQUE constraint failed:") && !strings.Contains(err.Error(), "violates unique constraint") {
|
||||
require.Failf(t, "expected error to contain a constraint failure, got: %s", err.Error())
|
||||
}
|
||||
@@ -198,7 +201,7 @@ func TestConstraints(t *testing.T) {
|
||||
}{
|
||||
{
|
||||
name: "no-duplicate-username-if-no-oidc",
|
||||
run: func(t *testing.T, db *gorm.DB) {
|
||||
run: func(t *testing.T, db *gorm.DB) { //nolint:thelper
|
||||
_, err := CreateUser(db, types.User{Name: "user1"})
|
||||
require.NoError(t, err)
|
||||
_, err = CreateUser(db, types.User{Name: "user1"})
|
||||
@@ -207,7 +210,7 @@ func TestConstraints(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "no-oidc-duplicate-username-and-id",
|
||||
run: func(t *testing.T, db *gorm.DB) {
|
||||
run: func(t *testing.T, db *gorm.DB) { //nolint:thelper
|
||||
user := types.User{
|
||||
Model: gorm.Model{ID: 1},
|
||||
Name: "user1",
|
||||
@@ -229,7 +232,7 @@ func TestConstraints(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "no-oidc-duplicate-id",
|
||||
run: func(t *testing.T, db *gorm.DB) {
|
||||
run: func(t *testing.T, db *gorm.DB) { //nolint:thelper
|
||||
user := types.User{
|
||||
Model: gorm.Model{ID: 1},
|
||||
Name: "user1",
|
||||
@@ -251,7 +254,7 @@ func TestConstraints(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "allow-duplicate-username-cli-then-oidc",
|
||||
run: func(t *testing.T, db *gorm.DB) {
|
||||
run: func(t *testing.T, db *gorm.DB) { //nolint:thelper
|
||||
_, err := CreateUser(db, types.User{Name: "user1"}) // Create CLI username
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -266,7 +269,7 @@ func TestConstraints(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "allow-duplicate-username-oidc-then-cli",
|
||||
run: func(t *testing.T, db *gorm.DB) {
|
||||
run: func(t *testing.T, db *gorm.DB) { //nolint:thelper
|
||||
user := types.User{
|
||||
Name: "user1",
|
||||
ProviderIdentifier: sql.NullString{String: "http://test.com/user1", Valid: true},
|
||||
@@ -320,7 +323,7 @@ func TestPostgresMigrationAndDataValidation(t *testing.T) {
|
||||
}
|
||||
|
||||
// Construct the pg_restore command
|
||||
cmd := exec.Command(pgRestorePath, "--verbose", "--if-exists", "--clean", "--no-owner", "--dbname", u.String(), tt.dbPath)
|
||||
cmd := exec.CommandContext(context.Background(), pgRestorePath, "--verbose", "--if-exists", "--clean", "--no-owner", "--dbname", u.String(), tt.dbPath)
|
||||
|
||||
// Set the output streams
|
||||
cmd.Stdout = os.Stdout
|
||||
@@ -401,6 +404,7 @@ func dbForTestWithPath(t *testing.T, sqlFilePath string) *HSDatabase {
|
||||
// skip already-applied migrations and only run new ones.
|
||||
func TestSQLiteAllTestdataMigrations(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
schemas, err := os.ReadDir("testdata/sqlite")
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
@@ -27,13 +27,17 @@ func TestEphemeralGarbageCollectorGoRoutineLeak(t *testing.T) {
t.Logf("Initial number of goroutines: %d", initialGoroutines)

// Basic deletion tracking mechanism
var deletedIDs []types.NodeID
var deleteMutex sync.Mutex
var deletionWg sync.WaitGroup
var (
deletedIDs []types.NodeID
deleteMutex sync.Mutex
deletionWg sync.WaitGroup
)

deleteFunc := func(nodeID types.NodeID) {
deleteMutex.Lock()

deletedIDs = append(deletedIDs, nodeID)

deleteMutex.Unlock()
deletionWg.Done()
}
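The test hunks above and below all share the same tracking shape: a mutex-guarded slice collects the IDs handed to the delete callback, and a WaitGroup (or channel) signals when the expected number of deletions has happened. A stripped-down sketch of that pattern outside the headscale types (the NodeID alias and the fake scheduler below are assumptions for the example, not the real garbage collector):

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

type NodeID uint64 // stand-in for headscale's types.NodeID

func main() {
	var (
		deletedIDs []NodeID
		deleteMu   sync.Mutex
		deletionWg sync.WaitGroup
	)

	deleteFunc := func(id NodeID) {
		deleteMu.Lock()
		deletedIDs = append(deletedIDs, id)
		deleteMu.Unlock()
		deletionWg.Done()
	}

	const numNodes = 5
	deletionWg.Add(numNodes)

	// Stand-in "garbage collector": fire the callback after a short delay,
	// roughly what gc.Schedule(nodeID, expiry) arranges in the real tests.
	for i := 1; i <= numNodes; i++ {
		id := NodeID(i)
		time.AfterFunc(10*time.Millisecond, func() { deleteFunc(id) })
	}

	deletionWg.Wait()
	fmt.Println("deleted nodes:", len(deletedIDs))
}
```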
@@ -43,14 +47,17 @@ func TestEphemeralGarbageCollectorGoRoutineLeak(t *testing.T) {
|
||||
go gc.Start()
|
||||
|
||||
// Schedule several nodes for deletion with short expiry
|
||||
const expiry = fifty
|
||||
const numNodes = 100
|
||||
const (
|
||||
expiry = fifty
|
||||
numNodes = 100
|
||||
)
|
||||
|
||||
// Set up wait group for expected deletions
|
||||
|
||||
deletionWg.Add(numNodes)
|
||||
|
||||
for i := 1; i <= numNodes; i++ {
|
||||
gc.Schedule(types.NodeID(i), expiry)
|
||||
gc.Schedule(types.NodeID(i), expiry) //nolint:gosec // safe conversion in test
|
||||
}
|
||||
|
||||
// Wait for all scheduled deletions to complete
|
||||
@@ -63,7 +70,7 @@ func TestEphemeralGarbageCollectorGoRoutineLeak(t *testing.T) {
|
||||
|
||||
// Schedule and immediately cancel to test that part of the code
|
||||
for i := numNodes + 1; i <= numNodes*2; i++ {
|
||||
nodeID := types.NodeID(i)
|
||||
nodeID := types.NodeID(i) //nolint:gosec // safe conversion in test
|
||||
gc.Schedule(nodeID, time.Hour)
|
||||
gc.Cancel(nodeID)
|
||||
}
|
||||
@@ -87,14 +94,18 @@ func TestEphemeralGarbageCollectorGoRoutineLeak(t *testing.T) {
|
||||
// and then reschedules it with a shorter expiry, and verifies that the node is deleted only once.
|
||||
func TestEphemeralGarbageCollectorReschedule(t *testing.T) {
|
||||
// Deletion tracking mechanism
|
||||
var deletedIDs []types.NodeID
|
||||
var deleteMutex sync.Mutex
|
||||
var (
|
||||
deletedIDs []types.NodeID
|
||||
deleteMutex sync.Mutex
|
||||
)
|
||||
|
||||
deletionNotifier := make(chan types.NodeID, 1)
|
||||
|
||||
deleteFunc := func(nodeID types.NodeID) {
|
||||
deleteMutex.Lock()
|
||||
|
||||
deletedIDs = append(deletedIDs, nodeID)
|
||||
|
||||
deleteMutex.Unlock()
|
||||
|
||||
deletionNotifier <- nodeID
|
||||
@@ -102,11 +113,14 @@ func TestEphemeralGarbageCollectorReschedule(t *testing.T) {
|
||||
|
||||
// Start GC
|
||||
gc := NewEphemeralGarbageCollector(deleteFunc)
|
||||
|
||||
go gc.Start()
|
||||
defer gc.Close()
|
||||
|
||||
const shortExpiry = fifty
|
||||
const longExpiry = 1 * time.Hour
|
||||
const (
|
||||
shortExpiry = fifty
|
||||
longExpiry = 1 * time.Hour
|
||||
)
|
||||
|
||||
nodeID := types.NodeID(1)
|
||||
|
||||
@@ -136,23 +150,31 @@ func TestEphemeralGarbageCollectorReschedule(t *testing.T) {
|
||||
// and verifies that the node is deleted only once.
|
||||
func TestEphemeralGarbageCollectorCancelAndReschedule(t *testing.T) {
|
||||
// Deletion tracking mechanism
|
||||
var deletedIDs []types.NodeID
|
||||
var deleteMutex sync.Mutex
|
||||
var (
|
||||
deletedIDs []types.NodeID
|
||||
deleteMutex sync.Mutex
|
||||
)
|
||||
|
||||
deletionNotifier := make(chan types.NodeID, 1)
|
||||
|
||||
deleteFunc := func(nodeID types.NodeID) {
|
||||
deleteMutex.Lock()
|
||||
|
||||
deletedIDs = append(deletedIDs, nodeID)
|
||||
|
||||
deleteMutex.Unlock()
|
||||
|
||||
deletionNotifier <- nodeID
|
||||
}
|
||||
|
||||
// Start the GC
|
||||
gc := NewEphemeralGarbageCollector(deleteFunc)
|
||||
|
||||
go gc.Start()
|
||||
defer gc.Close()
|
||||
|
||||
nodeID := types.NodeID(1)
|
||||
|
||||
const expiry = fifty
|
||||
|
||||
// Schedule node for deletion
|
||||
@@ -196,14 +218,18 @@ func TestEphemeralGarbageCollectorCancelAndReschedule(t *testing.T) {
|
||||
// It creates a new EphemeralGarbageCollector, schedules a node for deletion, closes the GC, and verifies that the node is not deleted.
|
||||
func TestEphemeralGarbageCollectorCloseBeforeTimerFires(t *testing.T) {
|
||||
// Deletion tracking
|
||||
var deletedIDs []types.NodeID
|
||||
var deleteMutex sync.Mutex
|
||||
var (
|
||||
deletedIDs []types.NodeID
|
||||
deleteMutex sync.Mutex
|
||||
)
|
||||
|
||||
deletionNotifier := make(chan types.NodeID, 1)
|
||||
|
||||
deleteFunc := func(nodeID types.NodeID) {
|
||||
deleteMutex.Lock()
|
||||
|
||||
deletedIDs = append(deletedIDs, nodeID)
|
||||
|
||||
deleteMutex.Unlock()
|
||||
|
||||
deletionNotifier <- nodeID
|
||||
@@ -246,13 +272,18 @@ func TestEphemeralGarbageCollectorScheduleAfterClose(t *testing.T) {
|
||||
t.Logf("Initial number of goroutines: %d", initialGoroutines)
|
||||
|
||||
// Deletion tracking
|
||||
var deletedIDs []types.NodeID
|
||||
var deleteMutex sync.Mutex
|
||||
var (
|
||||
deletedIDs []types.NodeID
|
||||
deleteMutex sync.Mutex
|
||||
)
|
||||
|
||||
nodeDeleted := make(chan struct{})
|
||||
|
||||
deleteFunc := func(nodeID types.NodeID) {
|
||||
deleteMutex.Lock()
|
||||
|
||||
deletedIDs = append(deletedIDs, nodeID)
|
||||
|
||||
deleteMutex.Unlock()
|
||||
close(nodeDeleted) // Signal that deletion happened
|
||||
}
|
||||
@@ -263,10 +294,12 @@ func TestEphemeralGarbageCollectorScheduleAfterClose(t *testing.T) {
|
||||
// Use a WaitGroup to ensure the GC has started
|
||||
var startWg sync.WaitGroup
|
||||
startWg.Add(1)
|
||||
|
||||
go func() {
|
||||
startWg.Done() // Signal that the goroutine has started
|
||||
gc.Start()
|
||||
}()
|
||||
|
||||
startWg.Wait() // Wait for the GC to start
|
||||
|
||||
// Close GC right away
|
||||
@@ -288,7 +321,9 @@ func TestEphemeralGarbageCollectorScheduleAfterClose(t *testing.T) {
|
||||
|
||||
// Check no node was deleted
|
||||
deleteMutex.Lock()
|
||||
|
||||
nodesDeleted := len(deletedIDs)
|
||||
|
||||
deleteMutex.Unlock()
|
||||
assert.Equal(t, 0, nodesDeleted, "No nodes should be deleted when Schedule is called after Close")
|
||||
|
||||
@@ -311,12 +346,16 @@ func TestEphemeralGarbageCollectorConcurrentScheduleAndClose(t *testing.T) {
|
||||
t.Logf("Initial number of goroutines: %d", initialGoroutines)
|
||||
|
||||
// Deletion tracking mechanism
|
||||
var deletedIDs []types.NodeID
|
||||
var deleteMutex sync.Mutex
|
||||
var (
|
||||
deletedIDs []types.NodeID
|
||||
deleteMutex sync.Mutex
|
||||
)
|
||||
|
||||
deleteFunc := func(nodeID types.NodeID) {
|
||||
deleteMutex.Lock()
|
||||
|
||||
deletedIDs = append(deletedIDs, nodeID)
|
||||
|
||||
deleteMutex.Unlock()
|
||||
}
|
||||
|
||||
@@ -325,8 +364,10 @@ func TestEphemeralGarbageCollectorConcurrentScheduleAndClose(t *testing.T) {
|
||||
go gc.Start()
|
||||
|
||||
// Number of concurrent scheduling goroutines
|
||||
const numSchedulers = 10
|
||||
const nodesPerScheduler = 50
|
||||
const (
|
||||
numSchedulers = 10
|
||||
nodesPerScheduler = 50
|
||||
)
|
||||
|
||||
const closeAfterNodes = 25 // Close GC after this many nodes per scheduler
|
||||
|
||||
@@ -353,8 +394,8 @@ func TestEphemeralGarbageCollectorConcurrentScheduleAndClose(t *testing.T) {
|
||||
case <-stopScheduling:
|
||||
return
|
||||
default:
|
||||
nodeID := types.NodeID(baseNodeID + j + 1)
|
||||
gc.Schedule(nodeID, 1*time.Hour) // Long expiry to ensure it doesn't trigger during test
|
||||
nodeID := types.NodeID(baseNodeID + j + 1) //nolint:gosec // safe conversion in test
|
||||
gc.Schedule(nodeID, 1*time.Hour) // Long expiry to ensure it doesn't trigger during test
|
||||
atomic.AddInt64(&scheduledCount, 1)
|
||||
|
||||
// Yield to other goroutines to introduce variability
|
||||
|
||||
@@ -17,7 +17,11 @@ import (
"tailscale.com/net/tsaddr"
)

var errGeneratedIPBytesInvalid = errors.New("generated ip bytes are invalid ip")
var (
errGeneratedIPBytesInvalid = errors.New("generated ip bytes are invalid ip")
errGeneratedIPNotInPrefix = errors.New("generated ip not in prefix")
errIPAllocatorNil = errors.New("ip allocator was nil")
)

// IPAllocator is a singleton responsible for allocating
// IP addresses for nodes and making sure the same
@@ -62,8 +66,10 @@ func NewIPAllocator(
strategy: strategy,
}

var v4s []sql.NullString
var v6s []sql.NullString
var (
v4s []sql.NullString
v6s []sql.NullString
)

if db != nil {
|
||||
err := db.Read(func(rx *gorm.DB) error {
|
||||
@@ -135,15 +141,18 @@ func (i *IPAllocator) Next() (*netip.Addr, *netip.Addr, error) {
|
||||
i.mu.Lock()
|
||||
defer i.mu.Unlock()
|
||||
|
||||
var err error
|
||||
var ret4 *netip.Addr
|
||||
var ret6 *netip.Addr
|
||||
var (
|
||||
err error
|
||||
ret4 *netip.Addr
|
||||
ret6 *netip.Addr
|
||||
)
|
||||
|
||||
if i.prefix4 != nil {
|
||||
ret4, err = i.next(i.prev4, i.prefix4)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("allocating IPv4 address: %w", err)
|
||||
}
|
||||
|
||||
i.prev4 = *ret4
|
||||
}
|
||||
|
||||
@@ -152,6 +161,7 @@ func (i *IPAllocator) Next() (*netip.Addr, *netip.Addr, error) {
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("allocating IPv6 address: %w", err)
|
||||
}
|
||||
|
||||
i.prev6 = *ret6
|
||||
}
|
||||
|
||||
@@ -168,8 +178,10 @@ func (i *IPAllocator) nextLocked(prev netip.Addr, prefix *netip.Prefix) (*netip.
|
||||
}
|
||||
|
||||
func (i *IPAllocator) next(prev netip.Addr, prefix *netip.Prefix) (*netip.Addr, error) {
|
||||
var err error
|
||||
var ip netip.Addr
|
||||
var (
|
||||
err error
|
||||
ip netip.Addr
|
||||
)
|
||||
|
||||
switch i.strategy {
|
||||
case types.IPAllocationStrategySequential:
|
||||
@@ -243,7 +255,8 @@ func randomNext(pfx netip.Prefix) (netip.Addr, error) {

if !pfx.Contains(ip) {
return netip.Addr{}, fmt.Errorf(
"generated ip(%s) not in prefix(%s)",
"%w: ip(%s) not in prefix(%s)",
errGeneratedIPNotInPrefix,
ip.String(),
pfx.String(),
)
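This hunk, together with the new `errGeneratedIPNotInPrefix` and `errIPAllocatorNil` sentinels declared earlier in the file, moves the allocator from ad-hoc error strings to wrapped sentinel errors that callers can test with errors.Is. A minimal sketch of that idiom (the names mirror the diff, but the surrounding program is invented):

```go
package main

import (
	"errors"
	"fmt"
	"net/netip"
)

var errGeneratedIPNotInPrefix = errors.New("generated ip not in prefix")

func checkInPrefix(ip netip.Addr, pfx netip.Prefix) error {
	if !pfx.Contains(ip) {
		// %w keeps the sentinel in the chain while adding context.
		return fmt.Errorf("%w: ip(%s) not in prefix(%s)", errGeneratedIPNotInPrefix, ip, pfx)
	}
	return nil
}

func main() {
	err := checkInPrefix(netip.MustParseAddr("192.168.1.1"), netip.MustParsePrefix("100.64.0.0/10"))

	// Callers can now branch on the sentinel instead of matching strings.
	if errors.Is(err, errGeneratedIPNotInPrefix) {
		fmt.Println("allocator produced an out-of-range address:", err)
	}
}
```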
@@ -268,11 +281,14 @@ func isTailscaleReservedIP(ip netip.Addr) bool {
|
||||
// If a prefix type has been removed (IPv4 or IPv6), it
|
||||
// will remove the IPs in that family from the node.
|
||||
func (db *HSDatabase) BackfillNodeIPs(i *IPAllocator) ([]string, error) {
|
||||
var err error
|
||||
var ret []string
|
||||
var (
|
||||
err error
|
||||
ret []string
|
||||
)
|
||||
|
||||
err = db.Write(func(tx *gorm.DB) error {
|
||||
if i == nil {
|
||||
return errors.New("backfilling IPs: ip allocator was nil")
|
||||
return fmt.Errorf("backfilling IPs: %w", errIPAllocatorNil)
|
||||
}
|
||||
|
||||
log.Trace().Caller().Msgf("starting to backfill IPs")
|
||||
@@ -283,18 +299,19 @@ func (db *HSDatabase) BackfillNodeIPs(i *IPAllocator) ([]string, error) {
|
||||
}
|
||||
|
||||
for _, node := range nodes {
|
||||
log.Trace().Caller().Uint64("node.id", node.ID.Uint64()).Str("node.name", node.Hostname).Msg("IP backfill check started because node found in database")
|
||||
log.Trace().Caller().EmbedObject(node).Msg("ip backfill check started because node found in database")
|
||||
|
||||
changed := false
|
||||
// IPv4 prefix is set, but node ip is missing, alloc
|
||||
if i.prefix4 != nil && node.IPv4 == nil {
|
||||
ret4, err := i.nextLocked(i.prev4, i.prefix4)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to allocate ipv4 for node(%d): %w", node.ID, err)
|
||||
return fmt.Errorf("allocating IPv4 for node(%d): %w", node.ID, err)
|
||||
}
|
||||
|
||||
node.IPv4 = ret4
|
||||
changed = true
|
||||
|
||||
ret = append(ret, fmt.Sprintf("assigned IPv4 %q to Node(%d) %q", ret4.String(), node.ID, node.Hostname))
|
||||
}
|
||||
|
||||
@@ -302,11 +319,12 @@ func (db *HSDatabase) BackfillNodeIPs(i *IPAllocator) ([]string, error) {
|
||||
if i.prefix6 != nil && node.IPv6 == nil {
|
||||
ret6, err := i.nextLocked(i.prev6, i.prefix6)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to allocate ipv6 for node(%d): %w", node.ID, err)
|
||||
return fmt.Errorf("allocating IPv6 for node(%d): %w", node.ID, err)
|
||||
}
|
||||
|
||||
node.IPv6 = ret6
|
||||
changed = true
|
||||
|
||||
ret = append(ret, fmt.Sprintf("assigned IPv6 %q to Node(%d) %q", ret6.String(), node.ID, node.Hostname))
|
||||
}
|
||||
|
||||
|
||||
@@ -13,7 +13,6 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"tailscale.com/net/tsaddr"
"tailscale.com/types/ptr"
)

var mpp = func(pref string) *netip.Prefix {
@@ -21,9 +20,7 @@ var mpp = func(pref string) *netip.Prefix {
return &p
}

var na = func(pref string) netip.Addr {
return netip.MustParseAddr(pref)
}
var na = netip.MustParseAddr

var nap = func(pref string) *netip.Addr {
n := na(pref)
@@ -158,8 +155,10 @@ func TestIPAllocatorSequential(t *testing.T) {
|
||||
types.IPAllocationStrategySequential,
|
||||
)
|
||||
|
||||
var got4s []netip.Addr
|
||||
var got6s []netip.Addr
|
||||
var (
|
||||
got4s []netip.Addr
|
||||
got6s []netip.Addr
|
||||
)
|
||||
|
||||
for range tt.getCount {
|
||||
got4, got6, err := alloc.Next()
|
||||
@@ -175,6 +174,7 @@ func TestIPAllocatorSequential(t *testing.T) {
|
||||
got6s = append(got6s, *got6)
|
||||
}
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.want4, got4s, util.Comparers...); diff != "" {
|
||||
t.Errorf("IPAllocator 4s unexpected result (-want +got):\n%s", diff)
|
||||
}
|
||||
@@ -288,6 +288,7 @@ func TestBackfillIPAddresses(t *testing.T) {
|
||||
fullNodeP := func(i int) *types.Node {
|
||||
v4 := fmt.Sprintf("100.64.0.%d", i)
|
||||
v6 := fmt.Sprintf("fd7a:115c:a1e0::%d", i)
|
||||
|
||||
return &types.Node{
|
||||
IPv4: nap(v4),
|
||||
IPv6: nap(v6),
|
||||
@@ -484,12 +485,13 @@ func TestBackfillIPAddresses(t *testing.T) {
|
||||
func TestIPAllocatorNextNoReservedIPs(t *testing.T) {
|
||||
db, err := newSQLiteTestDB()
|
||||
require.NoError(t, err)
|
||||
|
||||
defer db.Close()
|
||||
|
||||
alloc, err := NewIPAllocator(
|
||||
db,
|
||||
ptr.To(tsaddr.CGNATRange()),
|
||||
ptr.To(tsaddr.TailscaleULARange()),
|
||||
new(tsaddr.CGNATRange()),
|
||||
new(tsaddr.TailscaleULARange()),
|
||||
types.IPAllocationStrategySequential,
|
||||
)
|
||||
if err != nil {
|
||||
@@ -497,17 +499,17 @@ func TestIPAllocatorNextNoReservedIPs(t *testing.T) {
|
||||
}
|
||||
|
||||
// Validate that we do not give out 100.100.100.100
|
||||
nextQuad100, err := alloc.next(na("100.100.100.99"), ptr.To(tsaddr.CGNATRange()))
|
||||
nextQuad100, err := alloc.next(na("100.100.100.99"), new(tsaddr.CGNATRange()))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, na("100.100.100.101"), *nextQuad100)
|
||||
|
||||
// Validate that we do not give out fd7a:115c:a1e0::53
|
||||
nextQuad100v6, err := alloc.next(na("fd7a:115c:a1e0::52"), ptr.To(tsaddr.TailscaleULARange()))
|
||||
nextQuad100v6, err := alloc.next(na("fd7a:115c:a1e0::52"), new(tsaddr.TailscaleULARange()))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, na("fd7a:115c:a1e0::54"), *nextQuad100v6)
|
||||
|
||||
// Validate that we do not give out fd7a:115c:a1e0::53
|
||||
nextChrome, err := alloc.next(na("100.115.91.255"), ptr.To(tsaddr.CGNATRange()))
|
||||
nextChrome, err := alloc.next(na("100.115.91.255"), new(tsaddr.CGNATRange()))
|
||||
t.Logf("chrome: %s", nextChrome.String())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, na("100.115.94.0"), *nextChrome)
|
||||
|
||||
@@ -16,18 +16,24 @@ import (
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/juanfont/headscale/hscontrol/util/zlog/zf"
|
||||
"github.com/rs/zerolog/log"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/ptr"
|
||||
)
|
||||
|
||||
const (
|
||||
NodeGivenNameHashLength = 8
|
||||
NodeGivenNameTrimSize = 2
|
||||
|
||||
// defaultTestNodePrefix is the default hostname prefix for nodes created in tests.
|
||||
defaultTestNodePrefix = "testnode"
|
||||
)
|
||||
|
||||
// ErrNodeNameNotUnique is returned when a node name is not unique.
|
||||
var ErrNodeNameNotUnique = errors.New("node name is not unique")
|
||||
|
||||
var invalidDNSRegex = regexp.MustCompile("[^a-z0-9-.]+")
|
||||
|
||||
var (
|
||||
@@ -51,12 +57,14 @@ func (hsdb *HSDatabase) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID)
// If at least one peer ID is given, only these peer nodes will be returned.
func ListPeers(tx *gorm.DB, nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) {
nodes := types.Nodes{}
if err := tx.

err := tx.
Preload("AuthKey").
Preload("AuthKey.User").
Preload("User").
Where("id <> ?", nodeID).
Where(peerIDs).Find(&nodes).Error; err != nil {
Where(peerIDs).Find(&nodes).Error
if err != nil {
return types.Nodes{}, err
}

@@ -75,11 +83,13 @@ func (hsdb *HSDatabase) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error)
|
||||
// or for the given nodes if at least one node ID is given as parameter.
|
||||
func ListNodes(tx *gorm.DB, nodeIDs ...types.NodeID) (types.Nodes, error) {
|
||||
nodes := types.Nodes{}
|
||||
if err := tx.
|
||||
|
||||
err := tx.
|
||||
Preload("AuthKey").
|
||||
Preload("AuthKey.User").
|
||||
Preload("User").
|
||||
Where(nodeIDs).Find(&nodes).Error; err != nil {
|
||||
Where(nodeIDs).Find(&nodes).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -89,7 +99,9 @@ func ListNodes(tx *gorm.DB, nodeIDs ...types.NodeID) (types.Nodes, error) {
|
||||
func (hsdb *HSDatabase) ListEphemeralNodes() (types.Nodes, error) {
|
||||
return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) {
|
||||
nodes := types.Nodes{}
|
||||
if err := rx.Joins("AuthKey").Where(`"AuthKey"."ephemeral" = true`).Find(&nodes).Error; err != nil {
|
||||
|
||||
err := rx.Joins("AuthKey").Where(`"AuthKey"."ephemeral" = true`).Find(&nodes).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -207,6 +219,7 @@ func SetTags(
|
||||
|
||||
slices.Sort(tags)
|
||||
tags = slices.Compact(tags)
|
||||
|
||||
b, err := json.Marshal(tags)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -220,7 +233,7 @@ func SetTags(
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetTags takes a Node struct pointer and update the forced tags.
|
||||
// SetApprovedRoutes takes a Node struct pointer and updates the approved routes.
|
||||
func SetApprovedRoutes(
|
||||
tx *gorm.DB,
|
||||
nodeID types.NodeID,
|
||||
@@ -228,7 +241,8 @@ func SetApprovedRoutes(
|
||||
) error {
|
||||
if len(routes) == 0 {
|
||||
// if no routes are provided, we remove all
|
||||
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("approved_routes", "[]").Error; err != nil {
|
||||
err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("approved_routes", "[]").Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("removing approved routes: %w", err)
|
||||
}
|
||||
|
||||
@@ -251,7 +265,7 @@ func SetApprovedRoutes(
|
||||
return err
|
||||
}
|
||||
|
||||
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("approved_routes", string(b)).Error; err != nil {
|
||||
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("approved_routes", string(b)).Error; err != nil { //nolint:noinlineerr
|
||||
return fmt.Errorf("updating approved routes: %w", err)
|
||||
}
|
||||
|
||||
@@ -277,22 +291,25 @@ func SetLastSeen(tx *gorm.DB, nodeID types.NodeID, lastSeen time.Time) error {
|
||||
func RenameNode(tx *gorm.DB,
|
||||
nodeID types.NodeID, newName string,
|
||||
) error {
|
||||
if err := util.ValidateHostname(newName); err != nil {
|
||||
err := util.ValidateHostname(newName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("renaming node: %w", err)
|
||||
}
|
||||
|
||||
// Check if the new name is unique
|
||||
var count int64
|
||||
if err := tx.Model(&types.Node{}).Where("given_name = ? AND id != ?", newName, nodeID).Count(&count).Error; err != nil {
|
||||
return fmt.Errorf("failed to check name uniqueness: %w", err)
|
||||
|
||||
err = tx.Model(&types.Node{}).Where("given_name = ? AND id != ?", newName, nodeID).Count(&count).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("checking name uniqueness: %w", err)
|
||||
}
|
||||
|
||||
if count > 0 {
|
||||
return errors.New("name is not unique")
|
||||
return ErrNodeNameNotUnique
|
||||
}
|
||||
|
||||
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("given_name", newName).Error; err != nil {
|
||||
return fmt.Errorf("failed to rename node in the database: %w", err)
|
||||
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("given_name", newName).Error; err != nil { //nolint:noinlineerr
|
||||
return fmt.Errorf("renaming node in database: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -323,7 +340,8 @@ func DeleteNode(tx *gorm.DB,
|
||||
node *types.Node,
|
||||
) error {
|
||||
// Unscoped causes the node to be fully removed from the database.
|
||||
if err := tx.Unscoped().Delete(&types.Node{}, node.ID).Error; err != nil {
|
||||
err := tx.Unscoped().Delete(&types.Node{}, node.ID).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -337,9 +355,11 @@ func (hsdb *HSDatabase) DeleteEphemeralNode(
|
||||
nodeID types.NodeID,
|
||||
) error {
|
||||
return hsdb.Write(func(tx *gorm.DB) error {
|
||||
if err := tx.Unscoped().Delete(&types.Node{}, nodeID).Error; err != nil {
|
||||
err := tx.Unscoped().Delete(&types.Node{}, nodeID).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
@@ -352,19 +372,19 @@ func RegisterNodeForTest(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *n
|
||||
}
|
||||
|
||||
logEvent := log.Debug().
|
||||
Str("node", node.Hostname).
|
||||
Str("machine_key", node.MachineKey.ShortString()).
|
||||
Str("node_key", node.NodeKey.ShortString())
|
||||
Str(zf.NodeHostname, node.Hostname).
|
||||
Str(zf.MachineKey, node.MachineKey.ShortString()).
|
||||
Str(zf.NodeKey, node.NodeKey.ShortString())
|
||||
|
||||
if node.User != nil {
|
||||
logEvent = logEvent.Str("user", node.User.Username())
|
||||
logEvent = logEvent.Str(zf.UserName, node.User.Username())
|
||||
} else if node.UserID != nil {
|
||||
logEvent = logEvent.Uint("user_id", *node.UserID)
|
||||
logEvent = logEvent.Uint(zf.UserID, *node.UserID)
|
||||
} else {
|
||||
logEvent = logEvent.Str("user", "none")
|
||||
logEvent = logEvent.Str(zf.UserName, "none")
|
||||
}
|
||||
|
||||
logEvent.Msg("Registering test node")
|
||||
logEvent.Msg("registering test node")
|
||||
|
||||
// If the a new node is registered with the same machine key, to the same user,
|
||||
// update the existing node.
|
||||
@@ -379,6 +399,7 @@ func RegisterNodeForTest(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *n
|
||||
if ipv4 == nil {
|
||||
ipv4 = oldNode.IPv4
|
||||
}
|
||||
|
||||
if ipv6 == nil {
|
||||
ipv6 = oldNode.IPv6
|
||||
}
|
||||
@@ -388,16 +409,17 @@ func RegisterNodeForTest(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *n
|
||||
// so we store the node.Expire and node.Nodekey that has been set when
|
||||
// adding it to the registrationCache
|
||||
if node.IPv4 != nil || node.IPv6 != nil {
|
||||
if err := tx.Save(&node).Error; err != nil {
|
||||
return nil, fmt.Errorf("failed register existing node in the database: %w", err)
|
||||
err := tx.Save(&node).Error
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("registering existing node in database: %w", err)
|
||||
}
|
||||
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("node", node.Hostname).
|
||||
Str("machine_key", node.MachineKey.ShortString()).
|
||||
Str("node_key", node.NodeKey.ShortString()).
|
||||
Str("user", node.User.Username()).
|
||||
Str(zf.NodeHostname, node.Hostname).
|
||||
Str(zf.MachineKey, node.MachineKey.ShortString()).
|
||||
Str(zf.NodeKey, node.NodeKey.ShortString()).
|
||||
Str(zf.UserName, node.User.Username()).
|
||||
Msg("Test node authorized again")
|
||||
|
||||
return &node, nil
|
||||
@@ -407,29 +429,30 @@ func RegisterNodeForTest(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *n
|
||||
node.IPv6 = ipv6
|
||||
|
||||
var err error
|
||||
|
||||
node.Hostname, err = util.NormaliseHostname(node.Hostname)
|
||||
if err != nil {
|
||||
newHostname := util.InvalidString()
|
||||
log.Info().Err(err).Str("invalid-hostname", node.Hostname).Str("new-hostname", newHostname).Msgf("Invalid hostname, replacing")
|
||||
log.Info().Err(err).Str(zf.InvalidHostname, node.Hostname).Str(zf.NewHostname, newHostname).Msgf("invalid hostname, replacing")
|
||||
node.Hostname = newHostname
|
||||
}
|
||||
|
||||
if node.GivenName == "" {
|
||||
givenName, err := EnsureUniqueGivenName(tx, node.Hostname)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to ensure unique given name: %w", err)
|
||||
return nil, fmt.Errorf("ensuring unique given name: %w", err)
|
||||
}
|
||||
|
||||
node.GivenName = givenName
|
||||
}
|
||||
|
||||
if err := tx.Save(&node).Error; err != nil {
|
||||
return nil, fmt.Errorf("failed register(save) node in the database: %w", err)
|
||||
if err := tx.Save(&node).Error; err != nil { //nolint:noinlineerr
|
||||
return nil, fmt.Errorf("saving node to database: %w", err)
|
||||
}
|
||||
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("node", node.Hostname).
|
||||
Str(zf.NodeHostname, node.Hostname).
|
||||
Msg("Test node registered with the database")
|
||||
|
||||
return &node, nil
|
||||
@@ -491,8 +514,10 @@ func generateGivenName(suppliedName string, randomSuffix bool) (string, error) {
|
||||
|
||||
func isUniqueName(tx *gorm.DB, name string) (bool, error) {
|
||||
nodes := types.Nodes{}
|
||||
if err := tx.
|
||||
Where("given_name = ?", name).Find(&nodes).Error; err != nil {
|
||||
|
||||
err := tx.
|
||||
Where("given_name = ?", name).Find(&nodes).Error
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
@@ -646,7 +671,7 @@ func (hsdb *HSDatabase) CreateNodeForTest(user *types.User, hostname ...string)
|
||||
panic("CreateNodeForTest requires a valid user")
|
||||
}
|
||||
|
||||
nodeName := "testnode"
|
||||
nodeName := defaultTestNodePrefix
|
||||
if len(hostname) > 0 && hostname[0] != "" {
|
||||
nodeName = hostname[0]
|
||||
}
|
||||
@@ -668,7 +693,7 @@ func (hsdb *HSDatabase) CreateNodeForTest(user *types.User, hostname ...string)
|
||||
Hostname: nodeName,
|
||||
UserID: &user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: ptr.To(pak.ID),
|
||||
AuthKeyID: new(pak.ID),
|
||||
}
|
||||
|
||||
err = hsdb.DB.Save(node).Error
|
||||
@@ -694,9 +719,12 @@ func (hsdb *HSDatabase) CreateRegisteredNodeForTest(user *types.User, hostname .
|
||||
}
|
||||
|
||||
var registeredNode *types.Node
|
||||
|
||||
err = hsdb.DB.Transaction(func(tx *gorm.DB) error {
|
||||
var err error
|
||||
|
||||
registeredNode, err = RegisterNodeForTest(tx, *node, ipv4, ipv6)
|
||||
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
@@ -715,7 +743,7 @@ func (hsdb *HSDatabase) CreateNodesForTest(user *types.User, count int, hostname
|
||||
panic("CreateNodesForTest requires a valid user")
|
||||
}
|
||||
|
||||
prefix := "testnode"
|
||||
prefix := defaultTestNodePrefix
|
||||
if len(hostnamePrefix) > 0 && hostnamePrefix[0] != "" {
|
||||
prefix = hostnamePrefix[0]
|
||||
}
|
||||
@@ -738,7 +766,7 @@ func (hsdb *HSDatabase) CreateRegisteredNodesForTest(user *types.User, count int
|
||||
panic("CreateRegisteredNodesForTest requires a valid user")
|
||||
}
|
||||
|
||||
prefix := "testnode"
|
||||
prefix := defaultTestNodePrefix
|
||||
if len(hostnamePrefix) > 0 && hostnamePrefix[0] != "" {
|
||||
prefix = hostnamePrefix[0]
|
||||
}
|
||||
|
||||
@@ -22,7 +22,6 @@ import (
|
||||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/ptr"
|
||||
)
|
||||
|
||||
func TestGetNode(t *testing.T) {
|
||||
@@ -115,7 +114,7 @@ func TestExpireNode(t *testing.T) {
|
||||
Hostname: "testnode",
|
||||
UserID: &user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: ptr.To(pak.ID),
|
||||
AuthKeyID: new(pak.ID),
|
||||
Expiry: &time.Time{},
|
||||
}
|
||||
db.DB.Save(node)
|
||||
@@ -159,7 +158,7 @@ func TestSetTags(t *testing.T) {
|
||||
Hostname: "testnode",
|
||||
UserID: &user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: ptr.To(pak.ID),
|
||||
AuthKeyID: new(pak.ID),
|
||||
}
|
||||
|
||||
trx := db.DB.Save(node)
|
||||
@@ -187,6 +186,7 @@ func TestHeadscale_generateGivenName(t *testing.T) {
|
||||
suppliedName string
|
||||
randomSuffix bool
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
@@ -443,7 +443,7 @@ func TestAutoApproveRoutes(t *testing.T) {
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
RoutableIPs: tt.routes,
|
||||
},
|
||||
IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")),
|
||||
IPv4: new(netip.MustParseAddr("100.64.0.1")),
|
||||
}
|
||||
|
||||
err = adb.DB.Save(&node).Error
|
||||
@@ -460,17 +460,17 @@ func TestAutoApproveRoutes(t *testing.T) {
|
||||
RoutableIPs: tt.routes,
|
||||
},
|
||||
Tags: []string{"tag:exit"},
|
||||
IPv4: ptr.To(netip.MustParseAddr("100.64.0.2")),
|
||||
IPv4: new(netip.MustParseAddr("100.64.0.2")),
|
||||
}
|
||||
|
||||
err = adb.DB.Save(&nodeTagged).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
users, err := adb.ListUsers()
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
nodes, err := adb.ListNodes()
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
pm, err := pmf(users, nodes.ViewSlice())
|
||||
require.NoError(t, err)
|
||||
@@ -498,6 +498,7 @@ func TestAutoApproveRoutes(t *testing.T) {
|
||||
if len(expectedRoutes1) == 0 {
|
||||
expectedRoutes1 = nil
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(expectedRoutes1, node1ByID.AllApprovedRoutes(), util.Comparers...); diff != "" {
|
||||
t.Errorf("unexpected enabled routes (-want +got):\n%s", diff)
|
||||
}
|
||||
@@ -509,6 +510,7 @@ func TestAutoApproveRoutes(t *testing.T) {
|
||||
if len(expectedRoutes2) == 0 {
|
||||
expectedRoutes2 = nil
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(expectedRoutes2, node2ByID.AllApprovedRoutes(), util.Comparers...); diff != "" {
|
||||
t.Errorf("unexpected enabled routes (-want +got):\n%s", diff)
|
||||
}
|
||||
@@ -520,6 +522,7 @@ func TestAutoApproveRoutes(t *testing.T) {
|
||||
func TestEphemeralGarbageCollectorOrder(t *testing.T) {
|
||||
want := []types.NodeID{1, 3}
|
||||
got := []types.NodeID{}
|
||||
|
||||
var mu sync.Mutex
|
||||
|
||||
deletionCount := make(chan struct{}, 10)
|
||||
@@ -527,6 +530,7 @@ func TestEphemeralGarbageCollectorOrder(t *testing.T) {
|
||||
e := NewEphemeralGarbageCollector(func(ni types.NodeID) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
got = append(got, ni)
|
||||
|
||||
deletionCount <- struct{}{}
|
||||
@@ -576,8 +580,10 @@ func TestEphemeralGarbageCollectorOrder(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestEphemeralGarbageCollectorLoads(t *testing.T) {
|
||||
var got []types.NodeID
|
||||
var mu sync.Mutex
|
||||
var (
|
||||
got []types.NodeID
|
||||
mu sync.Mutex
|
||||
)
|
||||
|
||||
want := 1000
|
||||
|
||||
@@ -589,6 +595,7 @@ func TestEphemeralGarbageCollectorLoads(t *testing.T) {
|
||||
|
||||
// Yield to other goroutines to introduce variability
|
||||
runtime.Gosched()
|
||||
|
||||
got = append(got, ni)
|
||||
|
||||
atomic.AddInt64(&deletedCount, 1)
|
||||
@@ -616,9 +623,12 @@ func TestEphemeralGarbageCollectorLoads(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func generateRandomNumber(t *testing.T, max int64) int64 {
|
||||
//nolint:unused
|
||||
func generateRandomNumber(t *testing.T, maxVal int64) int64 {
|
||||
t.Helper()
|
||||
maxB := big.NewInt(max)
|
||||
|
||||
maxB := big.NewInt(maxVal)
|
||||
|
||||
n, err := rand.Int(rand.Reader, maxB)
|
||||
if err != nil {
|
||||
t.Fatalf("getting random number: %s", err)
|
||||
@@ -649,7 +659,7 @@ func TestListEphemeralNodes(t *testing.T) {
|
||||
Hostname: "test",
|
||||
UserID: &user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: ptr.To(pak.ID),
|
||||
AuthKeyID: new(pak.ID),
|
||||
}
|
||||
|
||||
nodeEph := types.Node{
|
||||
@@ -659,7 +669,7 @@ func TestListEphemeralNodes(t *testing.T) {
|
||||
Hostname: "ephemeral",
|
||||
UserID: &user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: ptr.To(pakEph.ID),
|
||||
AuthKeyID: new(pakEph.ID),
|
||||
}
|
||||
|
||||
err = db.DB.Save(&node).Error
|
||||
@@ -722,7 +732,7 @@ func TestNodeNaming(t *testing.T) {
|
||||
nodeInvalidHostname := types.Node{
|
||||
MachineKey: key.NewMachine().Public(),
|
||||
NodeKey: key.NewNode().Public(),
|
||||
Hostname: "我的电脑",
|
||||
Hostname: "我的电脑", //nolint:gosmopolitan // intentional i18n test data
|
||||
UserID: &user2.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
}
|
||||
@@ -746,12 +756,15 @@ func TestNodeNaming(t *testing.T) {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = RegisterNodeForTest(tx, node2, nil, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = RegisterNodeForTest(tx, nodeInvalidHostname, ptr.To(mpp("100.64.0.66/32").Addr()), nil)
|
||||
_, err = RegisterNodeForTest(tx, nodeShortHostname, ptr.To(mpp("100.64.0.67/32").Addr()), nil)
|
||||
|
||||
_, _ = RegisterNodeForTest(tx, nodeInvalidHostname, new(mpp("100.64.0.66/32").Addr()), nil)
|
||||
_, err = RegisterNodeForTest(tx, nodeShortHostname, new(mpp("100.64.0.67/32").Addr()), nil)
|
||||
|
||||
return err
|
||||
})
|
||||
require.NoError(t, err)
|
||||
@@ -810,25 +823,25 @@ func TestNodeNaming(t *testing.T) {
|
||||
err = db.Write(func(tx *gorm.DB) error {
|
||||
return RenameNode(tx, nodes[0].ID, "test")
|
||||
})
|
||||
assert.ErrorContains(t, err, "name is not unique")
|
||||
require.ErrorContains(t, err, "name is not unique")
|
||||
|
||||
// Rename invalid chars
|
||||
err = db.Write(func(tx *gorm.DB) error {
|
||||
return RenameNode(tx, nodes[2].ID, "我的电脑")
|
||||
return RenameNode(tx, nodes[2].ID, "我的电脑") //nolint:gosmopolitan // intentional i18n test data
|
||||
})
|
||||
assert.ErrorContains(t, err, "invalid characters")
|
||||
require.ErrorContains(t, err, "invalid characters")
|
||||
|
||||
// Rename too short
|
||||
err = db.Write(func(tx *gorm.DB) error {
|
||||
return RenameNode(tx, nodes[3].ID, "a")
|
||||
})
|
||||
assert.ErrorContains(t, err, "at least 2 characters")
|
||||
require.ErrorContains(t, err, "at least 2 characters")
|
||||
|
||||
// Rename with emoji
|
||||
err = db.Write(func(tx *gorm.DB) error {
|
||||
return RenameNode(tx, nodes[0].ID, "hostname-with-💩")
|
||||
})
|
||||
assert.ErrorContains(t, err, "invalid characters")
|
||||
require.ErrorContains(t, err, "invalid characters")
|
||||
|
||||
// Rename with only emoji
|
||||
err = db.Write(func(tx *gorm.DB) error {
|
||||
@@ -896,12 +909,12 @@ func TestRenameNodeComprehensive(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "chinese_chars_with_dash_rejected",
|
||||
newName: "server-北京-01",
|
||||
newName: "server-北京-01", //nolint:gosmopolitan // intentional i18n test data
|
||||
wantErr: "invalid characters",
|
||||
},
|
||||
{
|
||||
name: "chinese_only_rejected",
|
||||
newName: "我的电脑",
|
||||
newName: "我的电脑", //nolint:gosmopolitan // intentional i18n test data
|
||||
wantErr: "invalid characters",
|
||||
},
|
||||
{
|
||||
@@ -911,7 +924,7 @@ func TestRenameNodeComprehensive(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "mixed_chinese_emoji_rejected",
|
||||
newName: "测试💻机器",
|
||||
newName: "测试💻机器", //nolint:gosmopolitan // intentional i18n test data
|
||||
wantErr: "invalid characters",
|
||||
},
|
||||
{
|
||||
@@ -1000,6 +1013,7 @@ func TestListPeers(t *testing.T) {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = RegisterNodeForTest(tx, node2, nil, nil)
|
||||
|
||||
return err
|
||||
@@ -1085,6 +1099,7 @@ func TestListNodes(t *testing.T) {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = RegisterNodeForTest(tx, node2, nil, nil)
|
||||
|
||||
return err
|
||||
|
||||
@@ -17,7 +17,8 @@ func (hsdb *HSDatabase) SetPolicy(policy string) (*types.Policy, error) {
|
||||
Data: policy,
|
||||
}
|
||||
|
||||
if err := hsdb.DB.Clauses(clause.Returning{}).Create(&p).Error; err != nil {
|
||||
err := hsdb.DB.Clauses(clause.Returning{}).Create(&p).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
@@ -138,8 +138,8 @@ func CreatePreAuthKey(
Hash: hash, // Store hash
}

if err := tx.Save(&key).Error; err != nil {
return nil, fmt.Errorf("failed to create key in the database: %w", err)
if err := tx.Save(&key).Error; err != nil { //nolint:noinlineerr
return nil, fmt.Errorf("creating key in database: %w", err)
}

return &types.PreAuthKeyNew{
@@ -155,9 +155,7 @@ func CreatePreAuthKey(
}

func (hsdb *HSDatabase) ListPreAuthKeys() ([]types.PreAuthKey, error) {
return Read(hsdb.DB, func(rx *gorm.DB) ([]types.PreAuthKey, error) {
return ListPreAuthKeys(rx)
})
return Read(hsdb.DB, ListPreAuthKeys)
}

// ListPreAuthKeys returns all PreAuthKeys in the database.
@@ -296,7 +294,7 @@ func DestroyPreAuthKey(tx *gorm.DB, id uint64) error {
|
||||
Where("auth_key_id = ?", id).
|
||||
Update("auth_key_id", nil).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to clear auth_key_id on nodes: %w", err)
|
||||
return fmt.Errorf("clearing auth_key_id on nodes: %w", err)
|
||||
}
|
||||
|
||||
// Then delete the pre-auth key
|
||||
@@ -325,14 +323,15 @@ func (hsdb *HSDatabase) DeletePreAuthKey(id uint64) error {
|
||||
func UsePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
|
||||
err := tx.Model(k).Update("used", true).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update key used status in the database: %w", err)
|
||||
return fmt.Errorf("updating key used status in database: %w", err)
|
||||
}
|
||||
|
||||
k.Used = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarkExpirePreAuthKey marks a PreAuthKey as expired.
|
||||
// ExpirePreAuthKey marks a PreAuthKey as expired.
|
||||
func ExpirePreAuthKey(tx *gorm.DB, id uint64) error {
|
||||
now := time.Now()
|
||||
return tx.Model(&types.PreAuthKey{}).Where("id = ?", id).Update("expiration", now).Error
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"tailscale.com/types/ptr"
|
||||
)
|
||||
|
||||
func TestCreatePreAuthKey(t *testing.T) {
|
||||
@@ -24,7 +23,7 @@ func TestCreatePreAuthKey(t *testing.T) {
|
||||
test: func(t *testing.T, db *HSDatabase) {
|
||||
t.Helper()
|
||||
|
||||
_, err := db.CreatePreAuthKey(ptr.To(types.UserID(12345)), true, false, nil, nil)
|
||||
_, err := db.CreatePreAuthKey(new(types.UserID(12345)), true, false, nil, nil)
|
||||
assert.Error(t, err)
|
||||
},
|
||||
},
|
||||
@@ -127,7 +126,7 @@ func TestCannotDeleteAssignedPreAuthKey(t *testing.T) {
|
||||
Hostname: "testest",
|
||||
UserID: &user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: ptr.To(key.ID),
|
||||
AuthKeyID: new(key.ID),
|
||||
}
|
||||
db.DB.Save(&node)
|
||||
|
||||
|
||||
@@ -362,7 +362,8 @@ func (c *Config) Validate() error {
// ToURL builds a properly encoded SQLite connection string using _pragma parameters
// compatible with modernc.org/sqlite driver.
func (c *Config) ToURL() (string, error) {
if err := c.Validate(); err != nil {
err := c.Validate()
if err != nil {
return "", fmt.Errorf("invalid config: %w", err)
}

@@ -372,18 +373,23 @@ func (c *Config) ToURL() (string, error) {
if c.BusyTimeout > 0 {
pragmas = append(pragmas, fmt.Sprintf("busy_timeout=%d", c.BusyTimeout))
}

if c.JournalMode != "" {
pragmas = append(pragmas, fmt.Sprintf("journal_mode=%s", c.JournalMode))
}

if c.AutoVacuum != "" {
pragmas = append(pragmas, fmt.Sprintf("auto_vacuum=%s", c.AutoVacuum))
}

if c.WALAutocheckpoint >= 0 {
pragmas = append(pragmas, fmt.Sprintf("wal_autocheckpoint=%d", c.WALAutocheckpoint))
}

if c.Synchronous != "" {
pragmas = append(pragmas, fmt.Sprintf("synchronous=%s", c.Synchronous))
}

if c.ForeignKeys {
pragmas = append(pragmas, "foreign_keys=ON")
}

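Note: the ToURL hunks above assemble SQLite pragmas for the modernc.org/sqlite driver. A rough, standalone sketch of how such a connection string could be built by repeating a _pragma query parameter; the helper name and the exact pragma value syntax expected by the driver are assumptions, not taken from this commit:

```go
package main

import (
	"fmt"
	"net/url"
	"strings"
)

// buildSQLiteDSN is a simplified stand-in for Config.ToURL: each pragma is
// added as its own _pragma query parameter on a file: DSN.
func buildSQLiteDSN(path string, pragmas []string) string {
	params := make([]string, 0, len(pragmas))
	for _, p := range pragmas {
		params = append(params, "_pragma="+url.QueryEscape(p))
	}

	return "file:" + path + "?" + strings.Join(params, "&")
}

func main() {
	fmt.Println(buildSQLiteDSN("/var/lib/headscale/db.sqlite", []string{
		"busy_timeout=10000",
		"journal_mode=WAL",
		"foreign_keys=ON",
	}))
}
```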
@@ -294,6 +294,7 @@ func TestConfigToURL(t *testing.T) {
|
||||
t.Errorf("Config.ToURL() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if got != tt.want {
|
||||
t.Errorf("Config.ToURL() = %q, want %q", got, tt.want)
|
||||
}
|
||||
@@ -306,6 +307,7 @@ func TestConfigToURLInvalid(t *testing.T) {
|
||||
Path: "",
|
||||
BusyTimeout: -1,
|
||||
}
|
||||
|
||||
_, err := config.ToURL()
|
||||
if err == nil {
|
||||
t.Error("Config.ToURL() with invalid config should return error")
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package sqliteconfig
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@@ -101,7 +102,10 @@ func TestSQLiteDriverPragmaIntegration(t *testing.T) {
|
||||
defer db.Close()
|
||||
|
||||
// Test connection
|
||||
if err := db.Ping(); err != nil {
|
||||
ctx := context.Background()
|
||||
|
||||
err = db.PingContext(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to ping database: %v", err)
|
||||
}
|
||||
|
||||
@@ -109,8 +113,10 @@ func TestSQLiteDriverPragmaIntegration(t *testing.T) {
|
||||
for pragma, expectedValue := range tt.expected {
|
||||
t.Run("pragma_"+pragma, func(t *testing.T) {
|
||||
var actualValue any
|
||||
|
||||
query := "PRAGMA " + pragma
|
||||
err := db.QueryRow(query).Scan(&actualValue)
|
||||
|
||||
err := db.QueryRowContext(ctx, query).Scan(&actualValue)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to query %s: %v", query, err)
|
||||
}
|
||||
@@ -163,6 +169,8 @@ func TestForeignKeyConstraintEnforcement(t *testing.T) {
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create test tables with foreign key relationship
|
||||
schema := `
|
||||
CREATE TABLE parent (
|
||||
@@ -178,23 +186,25 @@ func TestForeignKeyConstraintEnforcement(t *testing.T) {
|
||||
);
|
||||
`
|
||||
|
||||
if _, err := db.Exec(schema); err != nil {
|
||||
_, err = db.ExecContext(ctx, schema)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create schema: %v", err)
|
||||
}
|
||||
|
||||
// Insert parent record
|
||||
if _, err := db.Exec("INSERT INTO parent (id, name) VALUES (1, 'Parent 1')"); err != nil {
|
||||
_, err = db.ExecContext(ctx, "INSERT INTO parent (id, name) VALUES (1, 'Parent 1')")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to insert parent: %v", err)
|
||||
}
|
||||
|
||||
// Test 1: Valid foreign key should work
|
||||
_, err = db.Exec("INSERT INTO child (id, parent_id, name) VALUES (1, 1, 'Child 1')")
|
||||
_, err = db.ExecContext(ctx, "INSERT INTO child (id, parent_id, name) VALUES (1, 1, 'Child 1')")
|
||||
if err != nil {
|
||||
t.Fatalf("Valid foreign key insert failed: %v", err)
|
||||
}
|
||||
|
||||
// Test 2: Invalid foreign key should fail
|
||||
_, err = db.Exec("INSERT INTO child (id, parent_id, name) VALUES (2, 999, 'Child 2')")
|
||||
_, err = db.ExecContext(ctx, "INSERT INTO child (id, parent_id, name) VALUES (2, 999, 'Child 2')")
|
||||
if err == nil {
|
||||
t.Error("Expected foreign key constraint violation, but insert succeeded")
|
||||
} else if !contains(err.Error(), "FOREIGN KEY constraint failed") {
|
||||
@@ -204,7 +214,7 @@ func TestForeignKeyConstraintEnforcement(t *testing.T) {
|
||||
}
|
||||
|
||||
// Test 3: Deleting referenced parent should fail
|
||||
_, err = db.Exec("DELETE FROM parent WHERE id = 1")
|
||||
_, err = db.ExecContext(ctx, "DELETE FROM parent WHERE id = 1")
|
||||
if err == nil {
|
||||
t.Error("Expected foreign key constraint violation when deleting referenced parent")
|
||||
} else if !contains(err.Error(), "FOREIGN KEY constraint failed") {
|
||||
@@ -249,7 +259,8 @@ func TestJournalModeValidation(t *testing.T) {
|
||||
defer db.Close()
|
||||
|
||||
var actualMode string
|
||||
err = db.QueryRow("PRAGMA journal_mode").Scan(&actualMode)
|
||||
|
||||
err = db.QueryRowContext(context.Background(), "PRAGMA journal_mode").Scan(&actualMode)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to query journal_mode: %v", err)
|
||||
}
|
||||
|
||||
@@ -53,16 +53,19 @@ func newPostgresDBForTest(t *testing.T) *url.URL {
|
||||
t.Helper()
|
||||
|
||||
ctx := t.Context()
|
||||
|
||||
srv, err := postgrestest.Start(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Cleanup(srv.Cleanup)
|
||||
|
||||
u, err := srv.CreateDatabase(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Logf("created local postgres: %s", u)
|
||||
pu, _ := url.Parse(u)
|
||||
|
||||
|
||||
@@ -3,12 +3,19 @@ package db
import (
"context"
"encoding"
"errors"
"fmt"
"reflect"

"gorm.io/gorm/schema"
)

var (
errUnmarshalTextValue = errors.New("unmarshalling text value")
errUnsupportedType = errors.New("unsupported type")
errTextMarshalerOnly = errors.New("only encoding.TextMarshaler is supported")
)

// Got from https://github.com/xdg-go/strum/blob/main/types.go
var textUnmarshalerType = reflect.TypeFor[encoding.TextUnmarshaler]()

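Note: the TextSerialiser changes below only require that a field implement encoding.TextMarshaler and encoding.TextUnmarshaler. A minimal round-trip of that interface pair, independent of GORM; the Level type is invented for illustration:

```go
package main

import (
	"fmt"
	"strings"
)

// Level is a toy type implementing both text interfaces, which is the only
// contract a text-based serializer relies on.
type Level int

func (l Level) MarshalText() ([]byte, error) {
	return []byte(fmt.Sprintf("level-%d", int(l))), nil
}

func (l *Level) UnmarshalText(b []byte) error {
	var n int
	if _, err := fmt.Sscanf(strings.TrimSpace(string(b)), "level-%d", &n); err != nil {
		return err
	}
	*l = Level(n)

	return nil
}

func main() {
	text, _ := Level(3).MarshalText()

	var back Level
	if err := back.UnmarshalText(text); err != nil {
		panic(err)
	}
	fmt.Println(string(text), back) // level-3 3
}
```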
@@ -24,7 +31,7 @@ func maybeInstantiatePtr(rv reflect.Value) {
|
||||
}
|
||||
|
||||
func decodingError(name string, err error) error {
|
||||
return fmt.Errorf("error decoding to %s: %w", name, err)
|
||||
return fmt.Errorf("decoding to %s: %w", name, err)
|
||||
}
|
||||
|
||||
// TextSerialiser implements the Serialiser interface for fields that
|
||||
@@ -42,22 +49,26 @@ func (TextSerialiser) Scan(ctx context.Context, field *schema.Field, dst reflect
|
||||
|
||||
if dbValue != nil {
|
||||
var bytes []byte
|
||||
|
||||
switch v := dbValue.(type) {
|
||||
case []byte:
|
||||
bytes = v
|
||||
case string:
|
||||
bytes = []byte(v)
|
||||
default:
|
||||
return fmt.Errorf("failed to unmarshal text value: %#v", dbValue)
|
||||
return fmt.Errorf("%w: %#v", errUnmarshalTextValue, dbValue)
|
||||
}
|
||||
|
||||
if isTextUnmarshaler(fieldValue) {
|
||||
maybeInstantiatePtr(fieldValue)
|
||||
f := fieldValue.MethodByName("UnmarshalText")
|
||||
args := []reflect.Value{reflect.ValueOf(bytes)}
|
||||
|
||||
ret := f.Call(args)
|
||||
if !ret[0].IsNil() {
|
||||
return decodingError(field.Name, ret[0].Interface().(error))
|
||||
if err, ok := ret[0].Interface().(error); ok {
|
||||
return decodingError(field.Name, err)
|
||||
}
|
||||
}
|
||||
|
||||
// If the underlying field is to a pointer type, we need to
|
||||
@@ -73,7 +84,7 @@ func (TextSerialiser) Scan(ctx context.Context, field *schema.Field, dst reflect
|
||||
|
||||
return nil
|
||||
} else {
|
||||
return fmt.Errorf("unsupported type: %T", fieldValue.Interface())
|
||||
return fmt.Errorf("%w: %T", errUnsupportedType, fieldValue.Interface())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -87,8 +98,9 @@ func (TextSerialiser) Value(ctx context.Context, field *schema.Field, dst reflec
|
||||
// always comparable, particularly when reflection is involved:
|
||||
// https://dev.to/arxeiss/in-go-nil-is-not-equal-to-nil-sometimes-jn8
|
||||
if v == nil || (reflect.ValueOf(v).Kind() == reflect.Ptr && reflect.ValueOf(v).IsNil()) {
|
||||
return nil, nil
|
||||
return nil, nil //nolint:nilnil // intentional: nil value for GORM serializer
|
||||
}
|
||||
|
||||
b, err := v.MarshalText()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -96,6 +108,6 @@ func (TextSerialiser) Value(ctx context.Context, field *schema.Field, dst reflec
|
||||
|
||||
return string(b), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("only encoding.TextMarshaler is supported, got %t", v)
|
||||
return nil, fmt.Errorf("%w, got %T", errTextMarshalerOnly, v)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,9 +12,11 @@ import (
)

var (
ErrUserExists = errors.New("user already exists")
ErrUserNotFound = errors.New("user not found")
ErrUserStillHasNodes = errors.New("user not empty: node(s) found")
ErrUserExists = errors.New("user already exists")
ErrUserNotFound = errors.New("user not found")
ErrUserStillHasNodes = errors.New("user not empty: node(s) found")
ErrUserWhereInvalidCount = errors.New("expect 0 or 1 where User structs")
ErrUserNotUnique = errors.New("expected exactly one user")
)

func (hsdb *HSDatabase) CreateUser(user types.User) (*types.User, error) {
@@ -26,10 +28,13 @@ func (hsdb *HSDatabase) CreateUser(user types.User) (*types.User, error) {
// CreateUser creates a new User. Returns error if could not be created
// or another user already exists.
func CreateUser(tx *gorm.DB, user types.User) (*types.User, error) {
if err := util.ValidateHostname(user.Name); err != nil {
err := util.ValidateHostname(user.Name)
if err != nil {
return nil, err
}
if err := tx.Create(&user).Error; err != nil {

err = tx.Create(&user).Error
if err != nil {
return nil, fmt.Errorf("creating user: %w", err)
}

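Note: the hunks above move this file to package-level sentinel errors that are wrapped with %w rather than flattened into strings. A small standalone example of how callers can still match such a sentinel after wrapping; only ErrUserStillHasNodes is taken from the commit, the rest is illustrative:

```go
package main

import (
	"errors"
	"fmt"
)

var ErrUserStillHasNodes = errors.New("user not empty: node(s) found")

// destroyUser is a stand-in for DestroyUser: it wraps the sentinel so callers
// can still match it with errors.Is after context has been added.
func destroyUser(name string, nodeCount int) error {
	if nodeCount > 0 {
		return fmt.Errorf("destroying user %q: %w", name, ErrUserStillHasNodes)
	}

	return nil
}

func main() {
	err := destroyUser("alice", 2)
	fmt.Println(errors.Is(err, ErrUserStillHasNodes)) // true
}
```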
@@ -54,6 +59,7 @@ func DestroyUser(tx *gorm.DB, uid types.UserID) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(nodes) > 0 {
|
||||
return ErrUserStillHasNodes
|
||||
}
|
||||
@@ -62,6 +68,7 @@ func DestroyUser(tx *gorm.DB, uid types.UserID) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, key := range keys {
|
||||
err = DestroyPreAuthKey(tx, key.ID)
|
||||
if err != nil {
|
||||
@@ -88,11 +95,13 @@ var ErrCannotChangeOIDCUser = errors.New("cannot edit OIDC user")
|
||||
// not exist or if another User exists with the new name.
|
||||
func RenameUser(tx *gorm.DB, uid types.UserID, newName string) error {
|
||||
var err error
|
||||
|
||||
oldUser, err := GetUserByID(tx, uid)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err = util.ValidateHostname(newName); err != nil {
|
||||
|
||||
if err = util.ValidateHostname(newName); err != nil { //nolint:noinlineerr
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -151,7 +160,7 @@ func (hsdb *HSDatabase) ListUsers(where ...*types.User) ([]types.User, error) {
|
||||
// ListUsers gets all the existing users.
|
||||
func ListUsers(tx *gorm.DB, where ...*types.User) ([]types.User, error) {
|
||||
if len(where) > 1 {
|
||||
return nil, fmt.Errorf("expect 0 or 1 where User structs, got %d", len(where))
|
||||
return nil, fmt.Errorf("%w, got %d", ErrUserWhereInvalidCount, len(where))
|
||||
}
|
||||
|
||||
var user *types.User
|
||||
@@ -160,7 +169,9 @@ func ListUsers(tx *gorm.DB, where ...*types.User) ([]types.User, error) {
|
||||
}
|
||||
|
||||
users := []types.User{}
|
||||
if err := tx.Where(user).Find(&users).Error; err != nil {
|
||||
|
||||
err := tx.Where(user).Find(&users).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -180,7 +191,7 @@ func (hsdb *HSDatabase) GetUserByName(name string) (*types.User, error) {
|
||||
}
|
||||
|
||||
if len(users) != 1 {
|
||||
return nil, fmt.Errorf("expected exactly one user, found %d", len(users))
|
||||
return nil, fmt.Errorf("%w, found %d", ErrUserNotUnique, len(users))
|
||||
}
|
||||
|
||||
return &users[0], nil
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/types/ptr"
|
||||
)
|
||||
|
||||
func TestCreateAndDestroyUser(t *testing.T) {
|
||||
@@ -79,7 +78,7 @@ func TestDestroyUserErrors(t *testing.T) {
|
||||
Hostname: "testnode",
|
||||
UserID: &user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: ptr.To(pak.ID),
|
||||
AuthKeyID: new(pak.ID),
|
||||
}
|
||||
trx := db.DB.Save(&node)
|
||||
require.NoError(t, trx.Error)
|
||||
|
||||
@@ -25,34 +25,39 @@ func (h *Headscale) debugHTTPServer() *http.Server {
|
||||
|
||||
if wantsJSON {
|
||||
overview := h.state.DebugOverviewJSON()
|
||||
|
||||
overviewJSON, err := json.MarshalIndent(overview, "", " ")
|
||||
if err != nil {
|
||||
httpError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(overviewJSON)
|
||||
_, _ = w.Write(overviewJSON)
|
||||
} else {
|
||||
// Default to text/plain for backward compatibility
|
||||
overview := h.state.DebugOverview()
|
||||
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(overview))
|
||||
_, _ = w.Write([]byte(overview))
|
||||
}
|
||||
}))
|
||||
|
||||
// Configuration endpoint
|
||||
debug.Handle("config", "Current configuration", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
config := h.state.DebugConfig()
|
||||
|
||||
configJSON, err := json.MarshalIndent(config, "", " ")
|
||||
if err != nil {
|
||||
httpError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(configJSON)
|
||||
_, _ = w.Write(configJSON)
|
||||
}))
|
||||
|
||||
// Policy endpoint
|
||||
@@ -70,8 +75,9 @@ func (h *Headscale) debugHTTPServer() *http.Server {
|
||||
} else {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(policy))
|
||||
_, _ = w.Write([]byte(policy))
|
||||
}))
|
||||
|
||||
// Filter rules endpoint
|
||||
@@ -81,27 +87,31 @@ func (h *Headscale) debugHTTPServer() *http.Server {
|
||||
httpError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
filterJSON, err := json.MarshalIndent(filter, "", " ")
|
||||
if err != nil {
|
||||
httpError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(filterJSON)
|
||||
_, _ = w.Write(filterJSON)
|
||||
}))
|
||||
|
||||
// SSH policies endpoint
|
||||
debug.Handle("ssh", "SSH policies per node", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
sshPolicies := h.state.DebugSSHPolicies()
|
||||
|
||||
sshJSON, err := json.MarshalIndent(sshPolicies, "", " ")
|
||||
if err != nil {
|
||||
httpError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(sshJSON)
|
||||
_, _ = w.Write(sshJSON)
|
||||
}))
|
||||
|
||||
// DERP map endpoint
|
||||
@@ -112,20 +122,23 @@ func (h *Headscale) debugHTTPServer() *http.Server {
|
||||
|
||||
if wantsJSON {
|
||||
derpInfo := h.state.DebugDERPJSON()
|
||||
|
||||
derpJSON, err := json.MarshalIndent(derpInfo, "", " ")
|
||||
if err != nil {
|
||||
httpError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(derpJSON)
|
||||
_, _ = w.Write(derpJSON)
|
||||
} else {
|
||||
// Default to text/plain for backward compatibility
|
||||
derpInfo := h.state.DebugDERPMap()
|
||||
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(derpInfo))
|
||||
_, _ = w.Write([]byte(derpInfo))
|
||||
}
|
||||
}))
|
||||
|
||||
@@ -137,34 +150,39 @@ func (h *Headscale) debugHTTPServer() *http.Server {
|
||||
|
||||
if wantsJSON {
|
||||
nodeStoreNodes := h.state.DebugNodeStoreJSON()
|
||||
|
||||
nodeStoreJSON, err := json.MarshalIndent(nodeStoreNodes, "", " ")
|
||||
if err != nil {
|
||||
httpError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(nodeStoreJSON)
|
||||
_, _ = w.Write(nodeStoreJSON)
|
||||
} else {
|
||||
// Default to text/plain for backward compatibility
|
||||
nodeStoreInfo := h.state.DebugNodeStore()
|
||||
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(nodeStoreInfo))
|
||||
_, _ = w.Write([]byte(nodeStoreInfo))
|
||||
}
|
||||
}))
|
||||
|
||||
// Registration cache endpoint
|
||||
debug.Handle("registration-cache", "Registration cache information", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
cacheInfo := h.state.DebugRegistrationCache()
|
||||
|
||||
cacheJSON, err := json.MarshalIndent(cacheInfo, "", " ")
|
||||
if err != nil {
|
||||
httpError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(cacheJSON)
|
||||
_, _ = w.Write(cacheJSON)
|
||||
}))
|
||||
|
||||
// Routes endpoint
|
||||
@@ -175,20 +193,23 @@ func (h *Headscale) debugHTTPServer() *http.Server {
|
||||
|
||||
if wantsJSON {
|
||||
routes := h.state.DebugRoutes()
|
||||
|
||||
routesJSON, err := json.MarshalIndent(routes, "", " ")
|
||||
if err != nil {
|
||||
httpError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(routesJSON)
|
||||
_, _ = w.Write(routesJSON)
|
||||
} else {
|
||||
// Default to text/plain for backward compatibility
|
||||
routes := h.state.DebugRoutesString()
|
||||
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(routes))
|
||||
_, _ = w.Write([]byte(routes))
|
||||
}
|
||||
}))
|
||||
|
||||
@@ -200,20 +221,23 @@ func (h *Headscale) debugHTTPServer() *http.Server {
|
||||
|
||||
if wantsJSON {
|
||||
policyManagerInfo := h.state.DebugPolicyManagerJSON()
|
||||
|
||||
policyManagerJSON, err := json.MarshalIndent(policyManagerInfo, "", " ")
|
||||
if err != nil {
|
||||
httpError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(policyManagerJSON)
|
||||
_, _ = w.Write(policyManagerJSON)
|
||||
} else {
|
||||
// Default to text/plain for backward compatibility
|
||||
policyManagerInfo := h.state.DebugPolicyManager()
|
||||
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(policyManagerInfo))
|
||||
_, _ = w.Write([]byte(policyManagerInfo))
|
||||
}
|
||||
}))
|
||||
|
||||
@@ -226,7 +250,8 @@ func (h *Headscale) debugHTTPServer() *http.Server {
|
||||
|
||||
if res == nil {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("HEADSCALE_DEBUG_DUMP_MAPRESPONSE_PATH not set"))
|
||||
_, _ = w.Write([]byte("HEADSCALE_DEBUG_DUMP_MAPRESPONSE_PATH not set"))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -235,9 +260,10 @@ func (h *Headscale) debugHTTPServer() *http.Server {
|
||||
httpError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(resJSON)
|
||||
_, _ = w.Write(resJSON)
|
||||
}))
|
||||
|
||||
// Batcher endpoint
|
||||
@@ -257,14 +283,14 @@ func (h *Headscale) debugHTTPServer() *http.Server {
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(batcherJSON)
|
||||
_, _ = w.Write(batcherJSON)
|
||||
} else {
|
||||
// Default to text/plain for backward compatibility
|
||||
batcherInfo := h.debugBatcher()
|
||||
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(batcherInfo))
|
||||
_, _ = w.Write([]byte(batcherInfo))
|
||||
}
|
||||
}))
|
||||
|
||||
@@ -313,6 +339,7 @@ func (h *Headscale) debugBatcher() string {
|
||||
activeConnections: info.ActiveConnections,
|
||||
})
|
||||
totalNodes++
|
||||
|
||||
if info.Connected {
|
||||
connectedCount++
|
||||
}
|
||||
@@ -327,9 +354,11 @@ func (h *Headscale) debugBatcher() string {
|
||||
activeConnections: 0,
|
||||
})
|
||||
totalNodes++
|
||||
|
||||
if connected {
|
||||
connectedCount++
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
}
|
||||
@@ -400,6 +429,7 @@ func (h *Headscale) debugBatcherJSON() DebugBatcherInfo {
|
||||
ActiveConnections: 0,
|
||||
}
|
||||
info.TotalNodes++
|
||||
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
@@ -28,11 +28,14 @@ func loadDERPMapFromPath(path string) (*tailcfg.DERPMap, error) {
|
||||
return nil, err
|
||||
}
|
||||
defer derpFile.Close()
|
||||
|
||||
var derpMap tailcfg.DERPMap
|
||||
|
||||
b, err := io.ReadAll(derpFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = yaml.Unmarshal(b, &derpMap)
|
||||
|
||||
return &derpMap, err
|
||||
@@ -57,12 +60,14 @@ func loadDERPMapFromURL(addr url.URL) (*tailcfg.DERPMap, error) {
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var derpMap tailcfg.DERPMap
|
||||
|
||||
err = json.Unmarshal(body, &derpMap)
|
||||
|
||||
return &derpMap, err
|
||||
@@ -134,6 +139,7 @@ func shuffleDERPMap(dm *tailcfg.DERPMap) {
|
||||
for id := range dm.Regions {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
|
||||
slices.Sort(ids)
|
||||
|
||||
for _, id := range ids {
|
||||
@@ -160,16 +166,18 @@ func derpRandom() *rand.Rand {

derpRandomOnce.Do(func() {
seed := cmp.Or(viper.GetString("dns.base_domain"), time.Now().String())
rnd := rand.New(rand.NewSource(0))
rnd.Seed(int64(crc64.Checksum([]byte(seed), crc64Table)))
rnd := rand.New(rand.NewSource(0)) //nolint:gosec // weak random is fine for DERP scrambling
rnd.Seed(int64(crc64.Checksum([]byte(seed), crc64Table))) //nolint:gosec // safe conversion
derpRandomInst = rnd
})

return derpRandomInst
}

func resetDerpRandomForTesting() {
derpRandomMu.Lock()
defer derpRandomMu.Unlock()

derpRandomOnce = sync.Once{}
derpRandomInst = nil
}

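Note: derpRandom above seeds math/rand from a CRC-64 checksum of a configuration string so the DERP shuffle is reproducible. A self-contained sketch of that seeding trick; the table constant and seed string here are placeholders:

```go
package main

import (
	"fmt"
	"hash/crc64"
	"math/rand"
)

var crc64Table = crc64.MakeTable(crc64.ISO)

// seededRand returns a math/rand generator whose sequence is a pure function
// of the seed string, which is what makes the shuffle deterministic.
func seededRand(seed string) *rand.Rand {
	return rand.New(rand.NewSource(int64(crc64.Checksum([]byte(seed), crc64Table)))) //nolint:gosec
}

func main() {
	a := seededRand("example.com")
	b := seededRand("example.com")
	fmt.Println(a.Intn(100) == b.Intn(100)) // true: same seed, same sequence
}
```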
@@ -242,7 +242,9 @@ func TestShuffleDERPMapDeterministic(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
viper.Set("dns.base_domain", tt.baseDomain)
|
||||
|
||||
defer viper.Reset()
|
||||
|
||||
resetDerpRandomForTesting()
|
||||
|
||||
testMap := tt.derpMap.View().AsStruct()
|
||||
|
||||
@@ -54,7 +54,7 @@ func NewDERPServer(
|
||||
derpKey key.NodePrivate,
|
||||
cfg *types.DERPConfig,
|
||||
) (*DERPServer, error) {
|
||||
log.Trace().Caller().Msg("Creating new embedded DERP server")
|
||||
log.Trace().Caller().Msg("creating new embedded DERP server")
|
||||
server := derpserver.New(derpKey, util.TSLogfWrapper()) // nolint // zerolinter complains
|
||||
|
||||
if cfg.ServerVerifyClients {
|
||||
@@ -75,9 +75,12 @@ func (d *DERPServer) GenerateRegion() (tailcfg.DERPRegion, error) {
|
||||
if err != nil {
|
||||
return tailcfg.DERPRegion{}, err
|
||||
}
|
||||
var host string
|
||||
var port int
|
||||
var portStr string
|
||||
|
||||
var (
|
||||
host string
|
||||
port int
|
||||
portStr string
|
||||
)
|
||||
|
||||
// Extract hostname and port from URL
|
||||
host, portStr, err = net.SplitHostPort(serverURL.Host)
|
||||
@@ -98,13 +101,13 @@ func (d *DERPServer) GenerateRegion() (tailcfg.DERPRegion, error) {
|
||||
|
||||
// If debug flag is set, resolve hostname to IP address
|
||||
if debugUseDERPIP {
|
||||
ips, err := net.LookupIP(host)
|
||||
ips, err := new(net.Resolver).LookupIPAddr(context.Background(), host)
|
||||
if err != nil {
|
||||
log.Error().Caller().Err(err).Msgf("Failed to resolve DERP hostname %s to IP, using hostname", host)
|
||||
log.Error().Caller().Err(err).Msgf("failed to resolve DERP hostname %s to IP, using hostname", host)
|
||||
} else if len(ips) > 0 {
|
||||
// Use the first IP address
|
||||
ipStr := ips[0].String()
|
||||
log.Info().Caller().Msgf("HEADSCALE_DEBUG_DERP_USE_IP: Resolved %s to %s", host, ipStr)
|
||||
ipStr := ips[0].IP.String()
|
||||
log.Info().Caller().Msgf("HEADSCALE_DEBUG_DERP_USE_IP: resolved %s to %s", host, ipStr)
|
||||
host = ipStr
|
||||
}
|
||||
}
|
||||
@@ -130,14 +133,16 @@ func (d *DERPServer) GenerateRegion() (tailcfg.DERPRegion, error) {
|
||||
if err != nil {
|
||||
return tailcfg.DERPRegion{}, err
|
||||
}
|
||||
|
||||
portSTUN, err := strconv.Atoi(portSTUNStr)
|
||||
if err != nil {
|
||||
return tailcfg.DERPRegion{}, err
|
||||
}
|
||||
|
||||
localDERPregion.Nodes[0].STUNPort = portSTUN
|
||||
|
||||
log.Info().Caller().Msgf("DERP region: %+v", localDERPregion)
|
||||
log.Info().Caller().Msgf("DERP Nodes[0]: %+v", localDERPregion.Nodes[0])
|
||||
log.Info().Caller().Msgf("derp region: %+v", localDERPregion)
|
||||
log.Info().Caller().Msgf("derp nodes[0]: %+v", localDERPregion.Nodes[0])
|
||||
|
||||
return localDERPregion, nil
|
||||
}
|
||||
@@ -155,8 +160,10 @@ func (d *DERPServer) DERPHandler(
|
||||
Caller().
|
||||
Msg("No Upgrade header in DERP server request. If headscale is behind a reverse proxy, make sure it is configured to pass WebSockets through.")
|
||||
}
|
||||
|
||||
writer.Header().Set("Content-Type", "text/plain")
|
||||
writer.WriteHeader(http.StatusUpgradeRequired)
|
||||
|
||||
_, err := writer.Write([]byte("DERP requires connection upgrade"))
|
||||
if err != nil {
|
||||
log.Error().
|
||||
@@ -206,6 +213,7 @@ func (d *DERPServer) serveWebsocket(writer http.ResponseWriter, req *http.Reques
|
||||
return
|
||||
}
|
||||
defer websocketConn.Close(websocket.StatusInternalError, "closing")
|
||||
|
||||
if websocketConn.Subprotocol() != "derp" {
|
||||
websocketConn.Close(websocket.StatusPolicyViolation, "client must speak the derp subprotocol")
|
||||
|
||||
@@ -222,9 +230,10 @@ func (d *DERPServer) servePlain(writer http.ResponseWriter, req *http.Request) {
|
||||
|
||||
hijacker, ok := writer.(http.Hijacker)
|
||||
if !ok {
|
||||
log.Error().Caller().Msg("DERP requires Hijacker interface from Gin")
|
||||
log.Error().Caller().Msg("derp requires Hijacker interface from Gin")
|
||||
writer.Header().Set("Content-Type", "text/plain")
|
||||
writer.WriteHeader(http.StatusInternalServerError)
|
||||
|
||||
_, err := writer.Write([]byte("HTTP does not support general TCP support"))
|
||||
if err != nil {
|
||||
log.Error().
|
||||
@@ -238,9 +247,10 @@ func (d *DERPServer) servePlain(writer http.ResponseWriter, req *http.Request) {
|
||||
|
||||
netConn, conn, err := hijacker.Hijack()
|
||||
if err != nil {
|
||||
log.Error().Caller().Err(err).Msgf("Hijack failed")
|
||||
log.Error().Caller().Err(err).Msgf("hijack failed")
|
||||
writer.Header().Set("Content-Type", "text/plain")
|
||||
writer.WriteHeader(http.StatusInternalServerError)
|
||||
|
||||
_, err = writer.Write([]byte("HTTP does not support general TCP support"))
|
||||
if err != nil {
|
||||
log.Error().
|
||||
@@ -251,7 +261,8 @@ func (d *DERPServer) servePlain(writer http.ResponseWriter, req *http.Request) {
|
||||
|
||||
return
|
||||
}
|
||||
log.Trace().Caller().Msgf("Hijacked connection from %v", req.RemoteAddr)
|
||||
|
||||
log.Trace().Caller().Msgf("hijacked connection from %v", req.RemoteAddr)
|
||||
|
||||
if !fastStart {
|
||||
pubKey := d.key.Public()
|
||||
@@ -280,6 +291,7 @@ func DERPProbeHandler(
|
||||
writer.WriteHeader(http.StatusOK)
|
||||
default:
|
||||
writer.WriteHeader(http.StatusMethodNotAllowed)
|
||||
|
||||
_, err := writer.Write([]byte("bogus probe method"))
|
||||
if err != nil {
|
||||
log.Error().
|
||||
@@ -309,9 +321,11 @@ func DERPBootstrapDNSHandler(
|
||||
|
||||
resolvCtx, cancel := context.WithTimeout(req.Context(), time.Minute)
|
||||
defer cancel()
|
||||
|
||||
var resolver net.Resolver
|
||||
for _, region := range derpMap.Regions().All() {
|
||||
for _, node := range region.Nodes().All() { // we don't care if we override some nodes
|
||||
|
||||
for _, region := range derpMap.Regions().All() { //nolint:unqueryvet // not SQLBoiler, tailcfg iterator
|
||||
for _, node := range region.Nodes().All() { //nolint:unqueryvet // not SQLBoiler, tailcfg iterator
|
||||
addrs, err := resolver.LookupIP(resolvCtx, "ip", node.HostName())
|
||||
if err != nil {
|
||||
log.Trace().
|
||||
@@ -321,11 +335,14 @@ func DERPBootstrapDNSHandler(
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
dnsEntries[node.HostName()] = addrs
|
||||
}
|
||||
}
|
||||
|
||||
writer.Header().Set("Content-Type", "application/json")
|
||||
writer.WriteHeader(http.StatusOK)
|
||||
|
||||
err := json.NewEncoder(writer).Encode(dnsEntries)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
@@ -338,33 +355,37 @@ func DERPBootstrapDNSHandler(
|
||||
|
||||
// ServeSTUN starts a STUN server on the configured addr.
|
||||
func (d *DERPServer) ServeSTUN() {
|
||||
packetConn, err := net.ListenPacket("udp", d.cfg.STUNAddr)
|
||||
packetConn, err := new(net.ListenConfig).ListenPacket(context.Background(), "udp", d.cfg.STUNAddr)
|
||||
if err != nil {
|
||||
log.Fatal().Msgf("failed to open STUN listener: %v", err)
|
||||
}
|
||||
log.Info().Msgf("STUN server started at %s", packetConn.LocalAddr())
|
||||
|
||||
log.Info().Msgf("stun server started at %s", packetConn.LocalAddr())
|
||||
|
||||
udpConn, ok := packetConn.(*net.UDPConn)
|
||||
if !ok {
|
||||
log.Fatal().Msg("STUN listener is not a UDP listener")
|
||||
log.Fatal().Msg("stun listener is not a UDP listener")
|
||||
}
|
||||
|
||||
serverSTUNListener(context.Background(), udpConn)
|
||||
}
|
||||
|
||||
func serverSTUNListener(ctx context.Context, packetConn *net.UDPConn) {
|
||||
var buf [64 << 10]byte
|
||||
var (
|
||||
buf [64 << 10]byte
|
||||
bytesRead int
|
||||
udpAddr *net.UDPAddr
|
||||
err error
|
||||
)
|
||||
|
||||
for {
|
||||
bytesRead, udpAddr, err = packetConn.ReadFromUDP(buf[:])
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
log.Error().Caller().Err(err).Msgf("STUN ReadFrom")
|
||||
|
||||
log.Error().Caller().Err(err).Msgf("stun ReadFrom")
|
||||
|
||||
// Rate limit error logging - wait before retrying, but respect context cancellation
|
||||
select {
|
||||
@@ -375,25 +396,29 @@ func serverSTUNListener(ctx context.Context, packetConn *net.UDPConn) {
|
||||
|
||||
continue
|
||||
}
|
||||
log.Trace().Caller().Msgf("STUN request from %v", udpAddr)
|
||||
|
||||
log.Trace().Caller().Msgf("stun request from %v", udpAddr)
|
||||
|
||||
pkt := buf[:bytesRead]
|
||||
if !stun.Is(pkt) {
|
||||
log.Trace().Caller().Msgf("UDP packet is not STUN")
|
||||
log.Trace().Caller().Msgf("udp packet is not stun")
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
txid, err := stun.ParseBindingRequest(pkt)
|
||||
if err != nil {
|
||||
log.Trace().Caller().Err(err).Msgf("STUN parse error")
|
||||
log.Trace().Caller().Err(err).Msgf("stun parse error")
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
addr, _ := netip.AddrFromSlice(udpAddr.IP)
|
||||
res := stun.Response(txid, netip.AddrPortFrom(addr, uint16(udpAddr.Port)))
|
||||
res := stun.Response(txid, netip.AddrPortFrom(addr, uint16(udpAddr.Port))) //nolint:gosec // port is always <=65535
|
||||
|
||||
_, err = packetConn.WriteTo(res, udpAddr)
|
||||
if err != nil {
|
||||
log.Trace().Caller().Err(err).Msgf("Issue writing to UDP")
|
||||
log.Trace().Caller().Err(err).Msgf("issue writing to UDP")
|
||||
|
||||
continue
|
||||
}
|
||||
@@ -412,8 +437,10 @@ type DERPVerifyTransport struct {
|
||||
|
||||
func (t *DERPVerifyTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
buf := new(bytes.Buffer)
|
||||
if err := t.handleVerifyRequest(req, buf); err != nil {
|
||||
log.Error().Caller().Err(err).Msg("Failed to handle client verify request: ")
|
||||
|
||||
err := t.handleVerifyRequest(req, buf)
|
||||
if err != nil {
|
||||
log.Error().Caller().Err(err).Msg("failed to handle client verify request")
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"sync"
|
||||
@@ -15,6 +16,9 @@ import (
|
||||
"tailscale.com/util/set"
|
||||
)
|
||||
|
||||
// ErrPathIsDirectory is returned when a directory path is provided where a file is expected.
|
||||
var ErrPathIsDirectory = errors.New("path is a directory, only file is supported")
|
||||
|
||||
type ExtraRecordsMan struct {
|
||||
mu sync.RWMutex
|
||||
records set.Set[tailcfg.DNSRecord]
|
||||
@@ -39,7 +43,7 @@ func NewExtraRecordsManager(path string) (*ExtraRecordsMan, error) {
|
||||
}
|
||||
|
||||
if fi.IsDir() {
|
||||
return nil, fmt.Errorf("path is a directory, only file is supported: %s", path)
|
||||
return nil, fmt.Errorf("%w: %s", ErrPathIsDirectory, path)
|
||||
}
|
||||
|
||||
records, hash, err := readExtraRecordsFromPath(path)
|
||||
@@ -85,19 +89,22 @@ func (e *ExtraRecordsMan) Run() {
|
||||
log.Error().Caller().Msgf("file watcher event channel closing")
|
||||
return
|
||||
}
|
||||
|
||||
switch event.Op {
|
||||
case fsnotify.Create, fsnotify.Write, fsnotify.Chmod:
|
||||
log.Trace().Caller().Str("path", event.Name).Str("op", event.Op.String()).Msg("extra records received filewatch event")
|
||||
|
||||
if event.Name != e.path {
|
||||
continue
|
||||
}
|
||||
|
||||
e.updateRecords()
|
||||
|
||||
// If a file is removed or renamed, fsnotify will loose track of it
|
||||
// and not watch it. We will therefore attempt to re-add it with a backoff.
|
||||
case fsnotify.Remove, fsnotify.Rename:
|
||||
_, err := backoff.Retry(context.Background(), func() (struct{}, error) {
|
||||
if _, err := os.Stat(e.path); err != nil {
|
||||
if _, err := os.Stat(e.path); err != nil { //nolint:noinlineerr
|
||||
return struct{}{}, err
|
||||
}
|
||||
|
||||
@@ -123,6 +130,7 @@ func (e *ExtraRecordsMan) Run() {
|
||||
log.Error().Caller().Msgf("file watcher error channel closing")
|
||||
return
|
||||
}
|
||||
|
||||
log.Error().Caller().Err(err).Msgf("extra records filewatcher returned error: %q", err)
|
||||
}
|
||||
}
|
||||
@@ -165,6 +173,7 @@ func (e *ExtraRecordsMan) updateRecords() {
|
||||
e.hashes[e.path] = newHash
|
||||
|
||||
log.Trace().Caller().Interface("records", e.records).Msgf("extra records updated from path, count old: %d, new: %d", oldCount, e.records.Len())
|
||||
|
||||
e.updateCh <- e.records.Slice()
|
||||
}
|
||||
|
||||
@@ -183,6 +192,7 @@ func readExtraRecordsFromPath(path string) ([]tailcfg.DNSRecord, [32]byte, error
|
||||
}
|
||||
|
||||
var records []tailcfg.DNSRecord
|
||||
|
||||
err = json.Unmarshal(b, &records)
|
||||
if err != nil {
|
||||
return nil, [32]byte{}, fmt.Errorf("unmarshalling records, content: %q: %w", string(b), err)
|
||||
|
||||
@@ -29,6 +29,7 @@ import (
|
||||
"github.com/juanfont/headscale/hscontrol/state"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/juanfont/headscale/hscontrol/util/zlog/zf"
|
||||
)
|
||||
|
||||
type headscaleV1APIServer struct { // v1.HeadscaleServiceServer
|
||||
@@ -54,7 +55,7 @@ func (api headscaleV1APIServer) CreateUser(
|
||||
}
|
||||
user, policyChanged, err := api.h.state.CreateUser(newUser)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to create user: %s", err)
|
||||
return nil, status.Errorf(codes.Internal, "creating user: %s", err)
|
||||
}
|
||||
|
||||
// CreateUser returns a policy change response if the user creation affected policy.
|
||||
@@ -235,16 +236,16 @@ func (api headscaleV1APIServer) RegisterNode(
|
||||
// Generate ephemeral registration key for tracking this registration flow in logs
|
||||
registrationKey, err := util.GenerateRegistrationKey()
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to generate registration key")
|
||||
log.Warn().Err(err).Msg("failed to generate registration key")
|
||||
registrationKey = "" // Continue without key if generation fails
|
||||
}
|
||||
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("user", request.GetUser()).
|
||||
Str("registration_id", request.GetKey()).
|
||||
Str("registration_key", registrationKey).
|
||||
Msg("Registering node")
|
||||
Str(zf.UserName, request.GetUser()).
|
||||
Str(zf.RegistrationID, request.GetKey()).
|
||||
Str(zf.RegistrationKey, registrationKey).
|
||||
Msg("registering node")
|
||||
|
||||
registrationId, err := types.RegistrationIDFromString(request.GetKey())
|
||||
if err != nil {
|
||||
@@ -264,17 +265,16 @@ func (api headscaleV1APIServer) RegisterNode(
|
||||
)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("registration_key", registrationKey).
|
||||
Str(zf.RegistrationKey, registrationKey).
|
||||
Err(err).
|
||||
Msg("Failed to register node")
|
||||
Msg("failed to register node")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Str("registration_key", registrationKey).
|
||||
Str("node_id", fmt.Sprintf("%d", node.ID())).
|
||||
Str("hostname", node.Hostname()).
|
||||
Msg("Node registered successfully")
|
||||
Str(zf.RegistrationKey, registrationKey).
|
||||
EmbedObject(node).
|
||||
Msg("node registered successfully")
|
||||
|
||||
// This is a bit of a back and forth, but we have a bit of a chicken and egg
|
||||
// dependency here.
|
||||
@@ -355,9 +355,9 @@ func (api headscaleV1APIServer) SetTags(
|
||||
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("node", node.Hostname()).
|
||||
EmbedObject(node).
|
||||
Strs("tags", request.GetTags()).
|
||||
Msg("Changing tags of node")
|
||||
Msg("changing tags of node")
|
||||
|
||||
return &v1.SetTagsResponse{Node: node.Proto()}, nil
|
||||
}
|
||||
@@ -368,7 +368,7 @@ func (api headscaleV1APIServer) SetApprovedRoutes(
|
||||
) (*v1.SetApprovedRoutesResponse, error) {
|
||||
log.Debug().
|
||||
Caller().
|
||||
Uint64("node.id", request.GetNodeId()).
|
||||
Uint64(zf.NodeID, request.GetNodeId()).
|
||||
Strs("requestedRoutes", request.GetRoutes()).
|
||||
Msg("gRPC SetApprovedRoutes called")
|
||||
|
||||
@@ -387,7 +387,7 @@ func (api headscaleV1APIServer) SetApprovedRoutes(
|
||||
newApproved = append(newApproved, prefix)
|
||||
}
|
||||
}
|
||||
tsaddr.SortPrefixes(newApproved)
|
||||
slices.SortFunc(newApproved, netip.Prefix.Compare)
|
||||
newApproved = slices.Compact(newApproved)
|
||||
|
||||
node, nodeChange, err := api.h.state.SetApprovedRoutes(types.NodeID(request.GetNodeId()), newApproved)
|
||||
@@ -406,7 +406,7 @@ func (api headscaleV1APIServer) SetApprovedRoutes(
|
||||
|
||||
log.Debug().
|
||||
Caller().
|
||||
Uint64("node.id", node.ID().Uint64()).
|
||||
EmbedObject(node).
|
||||
Strs("approvedRoutes", util.PrefixesToString(node.ApprovedRoutes().AsSlice())).
|
||||
Strs("primaryRoutes", util.PrefixesToString(primaryRoutes)).
|
||||
Strs("finalSubnetRoutes", proto.SubnetRoutes).
|
||||
@@ -423,7 +423,7 @@ func validateTag(tag string) error {
|
||||
return errors.New("tag should be lowercase")
|
||||
}
|
||||
if len(strings.Fields(tag)) > 1 {
|
||||
return errors.New("tag should not contains space")
|
||||
return errors.New("tags must not contain spaces")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -466,8 +466,8 @@ func (api headscaleV1APIServer) ExpireNode(
|
||||
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("node", node.Hostname()).
|
||||
Time("expiry", *node.AsStruct().Expiry).
|
||||
EmbedObject(node).
|
||||
Time(zf.ExpiresAt, *node.AsStruct().Expiry).
|
||||
Msg("node expired")
|
||||
|
||||
return &v1.ExpireNodeResponse{Node: node.Proto()}, nil
|
||||
@@ -487,8 +487,8 @@ func (api headscaleV1APIServer) RenameNode(
|
||||
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("node", node.Hostname()).
|
||||
Str("new_name", request.GetNewName()).
|
||||
EmbedObject(node).
|
||||
Str(zf.NewName, request.GetNewName()).
|
||||
Msg("node renamed")
|
||||
|
||||
return &v1.RenameNodeResponse{Node: node.Proto()}, nil
|
||||
@@ -546,7 +546,7 @@ func (api headscaleV1APIServer) BackfillNodeIPs(
|
||||
ctx context.Context,
|
||||
request *v1.BackfillNodeIPsRequest,
|
||||
) (*v1.BackfillNodeIPsResponse, error) {
|
||||
log.Trace().Caller().Msg("Backfill called")
|
||||
log.Trace().Caller().Msg("backfill called")
|
||||
|
||||
if !request.Confirmed {
|
||||
return nil, errors.New("not confirmed, aborting")
|
||||
@@ -817,13 +817,13 @@ func (api headscaleV1APIServer) Health(
|
||||
response := &v1.HealthResponse{}
|
||||
|
||||
if err := api.h.state.PingDB(ctx); err != nil {
|
||||
healthErr = fmt.Errorf("database ping failed: %w", err)
|
||||
healthErr = fmt.Errorf("pinging database: %w", err)
|
||||
} else {
|
||||
response.DatabaseConnectivity = true
|
||||
}
|
||||
|
||||
if healthErr != nil {
|
||||
log.Error().Err(healthErr).Msg("Health check failed")
|
||||
log.Error().Err(healthErr).Msg("health check failed")
|
||||
}
|
||||
|
||||
return response, healthErr
|
||||
|
||||
@@ -17,6 +17,7 @@ func Test_validateTag(t *testing.T) {
|
||||
type args struct {
|
||||
tag string
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
@@ -45,7 +46,8 @@ func Test_validateTag(t *testing.T) {
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := validateTag(tt.args.tag); (err != nil) != tt.wantErr {
|
||||
err := validateTag(tt.args.tag)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("validateTag() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -20,7 +20,7 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
// The CapabilityVersion is used by Tailscale clients to indicate
|
||||
// NoiseCapabilityVersion is used by Tailscale clients to indicate
|
||||
// their codebase version. Tailscale clients can communicate over TS2021
|
||||
// from CapabilityVersion 28, but we only have good support for it
|
||||
// since https://github.com/tailscale/tailscale/pull/4323 (Noise in any HTTPS port).
|
||||
@@ -36,8 +36,7 @@ const (

// httpError logs an error and sends an HTTP error response with the given.
func httpError(w http.ResponseWriter, err error) {
var herr HTTPError
if errors.As(err, &herr) {
if herr, ok := errors.AsType[HTTPError](err); ok {
http.Error(w, herr.Msg, herr.Code)
log.Error().Err(herr.Err).Int("code", herr.Code).Msgf("user msg: %s", herr.Msg)
} else {
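Note: the hunk above replaces the two-step errors.As target pattern with a generic errors.AsType lookup. For reference, the long-standing errors.As form being replaced looks roughly like this; the HTTPError fields mirror the methods shown just below, the rest is illustrative:

```go
package main

import (
	"errors"
	"fmt"
	"net/http"
)

type HTTPError struct {
	Code int
	Msg  string
	Err  error
}

func (e HTTPError) Error() string { return fmt.Sprintf("http error[%d]: %s, %s", e.Code, e.Msg, e.Err) }
func (e HTTPError) Unwrap() error { return e.Err }

// classify shows the pre-generics errors.As dance that the diff is replacing.
func classify(err error) (int, string) {
	var herr HTTPError
	if errors.As(err, &herr) {
		return herr.Code, herr.Msg
	}

	return http.StatusInternalServerError, "internal server error"
}

func main() {
	err := fmt.Errorf("wrapped: %w", HTTPError{Code: 400, Msg: "bad request"})
	fmt.Println(classify(err)) // 400 bad request
}
```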
@@ -56,7 +55,7 @@ type HTTPError struct {
|
||||
func (e HTTPError) Error() string { return fmt.Sprintf("http error[%d]: %s, %s", e.Code, e.Msg, e.Err) }
|
||||
func (e HTTPError) Unwrap() error { return e.Err }
|
||||
|
||||
// Error returns an HTTPError containing the given information.
|
||||
// NewHTTPError returns an HTTPError containing the given information.
|
||||
func NewHTTPError(code int, msg string, err error) HTTPError {
|
||||
return HTTPError{Code: code, Msg: msg, Err: err}
|
||||
}
|
||||
@@ -64,7 +63,7 @@ func NewHTTPError(code int, msg string, err error) HTTPError {
|
||||
var errMethodNotAllowed = NewHTTPError(http.StatusMethodNotAllowed, "method not allowed", nil)
|
||||
|
||||
var ErrRegisterMethodCLIDoesNotSupportExpire = errors.New(
|
||||
"machines registered with CLI does not support expire",
|
||||
"machines registered with CLI do not support expiry",
|
||||
)
|
||||
|
||||
func parseCapabilityVersion(req *http.Request) (tailcfg.CapabilityVersion, error) {
|
||||
@@ -76,7 +75,7 @@ func parseCapabilityVersion(req *http.Request) (tailcfg.CapabilityVersion, error
|
||||
|
||||
clientCapabilityVersion, err := strconv.Atoi(clientCapabilityStr)
|
||||
if err != nil {
|
||||
return 0, NewHTTPError(http.StatusBadRequest, "invalid capability version", fmt.Errorf("failed to parse capability version: %w", err))
|
||||
return 0, NewHTTPError(http.StatusBadRequest, "invalid capability version", fmt.Errorf("parsing capability version: %w", err))
|
||||
}
|
||||
|
||||
return tailcfg.CapabilityVersion(clientCapabilityVersion), nil
|
||||
@@ -88,12 +87,12 @@ func (h *Headscale) handleVerifyRequest(
|
||||
) error {
|
||||
body, err := io.ReadAll(req.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot read request body: %w", err)
|
||||
return fmt.Errorf("reading request body: %w", err)
|
||||
}
|
||||
|
||||
var derpAdmitClientRequest tailcfg.DERPAdmitClientRequest
|
||||
if err := json.Unmarshal(body, &derpAdmitClientRequest); err != nil {
|
||||
return NewHTTPError(http.StatusBadRequest, "Bad Request: invalid JSON", fmt.Errorf("cannot parse derpAdmitClientRequest: %w", err))
|
||||
if err := json.Unmarshal(body, &derpAdmitClientRequest); err != nil { //nolint:noinlineerr
|
||||
return NewHTTPError(http.StatusBadRequest, "Bad Request: invalid JSON", fmt.Errorf("parsing DERP client request: %w", err))
|
||||
}
|
||||
|
||||
nodes := h.state.ListNodes()
|
||||
@@ -155,7 +154,11 @@ func (h *Headscale) KeyHandler(
|
||||
}
|
||||
|
||||
writer.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(writer).Encode(resp)
|
||||
|
||||
err := json.NewEncoder(writer).Encode(resp)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("failed to encode public key response")
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
@@ -180,8 +183,12 @@ func (h *Headscale) HealthHandler(
|
||||
res.Status = "fail"
|
||||
}
|
||||
|
||||
json.NewEncoder(writer).Encode(res)
|
||||
encErr := json.NewEncoder(writer).Encode(res)
|
||||
if encErr != nil {
|
||||
log.Error().Err(encErr).Msg("failed to encode health response")
|
||||
}
|
||||
}
|
||||
|
||||
err := h.state.PingDB(req.Context())
|
||||
if err != nil {
|
||||
respond(err)
|
||||
@@ -218,6 +225,7 @@ func (h *Headscale) VersionHandler(
|
||||
writer.WriteHeader(http.StatusOK)
|
||||
|
||||
versionInfo := types.GetVersionInfo()
|
||||
|
||||
err := json.NewEncoder(writer).Encode(versionInfo)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
@@ -244,7 +252,7 @@ func (a *AuthProviderWeb) AuthURL(registrationId types.RegistrationID) string {
|
||||
registrationId.String())
|
||||
}
|
||||
|
||||
// RegisterWebAPI shows a simple message in the browser to point to the CLI
|
||||
// RegisterHandler shows a simple message in the browser to point to the CLI
|
||||
// Listens in /register/:registration_id.
|
||||
//
|
||||
// This is not part of the Tailscale control API, as we could send whatever URL
|
||||
@@ -267,7 +275,11 @@ func (a *AuthProviderWeb) RegisterHandler(
|
||||
|
||||
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusOK)
|
||||
writer.Write([]byte(templates.RegisterWeb(registrationId).Render()))
|
||||
|
||||
_, err = writer.Write([]byte(templates.RegisterWeb(registrationId).Render()))
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("failed to write register response")
|
||||
}
|
||||
}
|
||||
|
||||
func FaviconHandler(writer http.ResponseWriter, req *http.Request) {
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
"github.com/juanfont/headscale/hscontrol/state"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/types/change"
"github.com/juanfont/headscale/hscontrol/util/zlog/zf"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"github.com/puzpuzpuz/xsync/v4"
@@ -15,6 +16,14 @@ import (
"tailscale.com/tailcfg"
)

// Mapper errors.
var (
ErrInvalidNodeID = errors.New("invalid nodeID")
ErrMapperNil = errors.New("mapper is nil")
ErrNodeConnectionNil = errors.New("nodeConnection is nil")
ErrNodeNotFoundMapper = errors.New("node not found")
)

var mapResponseGenerated = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: "headscale",
Name: "mapresponse_generated_total",
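Note: mapResponseGenerated above is declared with promauto, which registers the collector with the default Prometheus registry at declaration time. A minimal standalone example of declaring and incrementing such a CounterVec; the metric and label names here are made up:

```go
package main

import (
	"fmt"

	"github.com/prometheus/client_golang/prometheus"
	"github.com/prometheus/client_golang/prometheus/promauto"
)

// promauto registers the collector with the default registry when the
// variable is initialized, so no separate MustRegister call is needed.
var mapResponsesDemo = promauto.NewCounterVec(prometheus.CounterOpts{
	Namespace: "headscale",
	Name:      "demo_mapresponse_generated_total",
	Help:      "Example counter in the style of mapresponse_generated_total.",
}, []string{"reason"})

func main() {
	mapResponsesDemo.WithLabelValues("node-change").Inc()
	fmt.Println("incremented demo counter")
}
```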
@@ -80,11 +89,11 @@ func generateMapResponse(nc nodeConnection, mapper *mapper, r change.Change) (*t
|
||||
}
|
||||
|
||||
if nodeID == 0 {
|
||||
return nil, fmt.Errorf("invalid nodeID: %d", nodeID)
|
||||
return nil, fmt.Errorf("%w: %d", ErrInvalidNodeID, nodeID)
|
||||
}
|
||||
|
||||
if mapper == nil {
|
||||
return nil, fmt.Errorf("mapper is nil for nodeID %d", nodeID)
|
||||
return nil, fmt.Errorf("%w for nodeID %d", ErrMapperNil, nodeID)
|
||||
}
|
||||
|
||||
// Handle self-only responses
|
||||
@@ -135,12 +144,12 @@ func generateMapResponse(nc nodeConnection, mapper *mapper, r change.Change) (*t
|
||||
// handleNodeChange generates and sends a [tailcfg.MapResponse] for a given node and [change.Change].
|
||||
func handleNodeChange(nc nodeConnection, mapper *mapper, r change.Change) error {
|
||||
if nc == nil {
|
||||
return errors.New("nodeConnection is nil")
|
||||
return ErrNodeConnectionNil
|
||||
}
|
||||
|
||||
nodeID := nc.nodeID()
|
||||
|
||||
log.Debug().Caller().Uint64("node.id", nodeID.Uint64()).Str("reason", r.Reason).Msg("Node change processing started because change notification received")
|
||||
log.Debug().Caller().Uint64(zf.NodeID, nodeID.Uint64()).Str(zf.Reason, r.Reason).Msg("node change processing started")
|
||||
|
||||
data, err := generateMapResponse(nc, mapper, r)
|
||||
if err != nil {
|
||||
|
||||
@@ -2,6 +2,7 @@ package mapper

import (
"crypto/rand"
"encoding/hex"
"errors"
"fmt"
"sync"
@@ -10,13 +11,20 @@ import (

"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/types/change"
"github.com/juanfont/headscale/hscontrol/util/zlog/zf"
"github.com/puzpuzpuz/xsync/v4"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"tailscale.com/tailcfg"
"tailscale.com/types/ptr"
)

var errConnectionClosed = errors.New("connection channel already closed")
// LockFreeBatcher errors.
var (
errConnectionClosed = errors.New("connection channel already closed")
ErrInitialMapSendTimeout = errors.New("sending initial map: timeout")
ErrBatcherShuttingDown = errors.New("batcher shutting down")
ErrConnectionSendTimeout = errors.New("timeout sending to channel (likely stale connection)")
)

// LockFreeBatcher uses atomic operations and concurrent maps to eliminate mutex contention.
type LockFreeBatcher struct {
@@ -48,6 +56,7 @@ type LockFreeBatcher struct {
// and notifies other nodes that this node has come online.
func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion) error {
addNodeStart := time.Now()
nlog := log.With().Uint64(zf.NodeID, id.Uint64()).Logger()

// Generate connection ID
connID := generateConnectionID()
@@ -76,9 +85,10 @@ func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse
// Use the worker pool for controlled concurrency instead of direct generation
initialMap, err := b.MapResponseFromChange(id, change.FullSelf(id))
if err != nil {
log.Error().Uint64("node.id", id.Uint64()).Err(err).Msg("Initial map generation failed")
nlog.Error().Err(err).Msg("initial map generation failed")
nodeConn.removeConnectionByChannel(c)
return fmt.Errorf("failed to generate initial map for node %d: %w", id, err)

return fmt.Errorf("generating initial map for node %d: %w", id, err)
}

// Use a blocking send with timeout for initial map since the channel should be ready
@@ -86,12 +96,13 @@ func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse
select {
case c <- initialMap:
// Success
case <-time.After(5 * time.Second):
log.Error().Uint64("node.id", id.Uint64()).Err(fmt.Errorf("timeout")).Msg("Initial map send timeout")
log.Debug().Caller().Uint64("node.id", id.Uint64()).Dur("timeout.duration", 5*time.Second).
Msg("Initial map send timed out because channel was blocked or receiver not ready")
case <-time.After(5 * time.Second): //nolint:mnd
nlog.Error().Err(ErrInitialMapSendTimeout).Msg("initial map send timeout")
nlog.Debug().Caller().Dur("timeout.duration", 5*time.Second). //nolint:mnd
Msg("initial map send timed out because channel was blocked or receiver not ready")
nodeConn.removeConnectionByChannel(c)
return fmt.Errorf("failed to send initial map to node %d: timeout", id)

return fmt.Errorf("%w for node %d", ErrInitialMapSendTimeout, id)
}

// Update connection status
@@ -100,9 +111,9 @@ func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse
// Node will automatically receive updates through the normal flow
// The initial full map already contains all current state

log.Debug().Caller().Uint64("node.id", id.Uint64()).Dur("total.duration", time.Since(addNodeStart)).
nlog.Debug().Caller().Dur(zf.TotalDuration, time.Since(addNodeStart)).
Int("active.connections", nodeConn.getActiveConnectionCount()).
Msg("Node connection established in batcher because AddNode completed successfully")
Msg("node connection established in batcher")

return nil
}
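The initial-map delivery above uses a `select` that either sends on the client channel or gives up after a fixed timeout. A runnable sketch of that send-with-timeout idiom; the channel, payload, and timeout value are assumptions for illustration only:

```go
package main

import (
	"errors"
	"fmt"
	"time"
)

var errSendTimeout = errors.New("sending initial payload: timeout")

// sendWithTimeout delivers v on ch, or fails if no receiver picks it up
// within d. This mirrors how AddNode treats a blocked channel as a stale
// or not-yet-ready connection instead of waiting forever.
func sendWithTimeout(ch chan<- string, v string, d time.Duration) error {
	select {
	case ch <- v:
		return nil
	case <-time.After(d):
		return fmt.Errorf("%w after %s", errSendTimeout, d)
	}
}

func main() {
	ch := make(chan string) // unbuffered: a send blocks until someone receives

	if err := sendWithTimeout(ch, "initial map", 100*time.Millisecond); err != nil {
		fmt.Println(err) // no receiver, so the timeout branch fires
	}
}
```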
@@ -112,31 +123,34 @@ func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse
// and keeps the node entry alive for rapid reconnections instead of aggressive deletion.
// Reports if the node still has active connections after removal.
func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse) bool {
nlog := log.With().Uint64(zf.NodeID, id.Uint64()).Logger()

nodeConn, exists := b.nodes.Load(id)
if !exists {
log.Debug().Caller().Uint64("node.id", id.Uint64()).Msg("RemoveNode called for non-existent node because node not found in batcher")
nlog.Debug().Caller().Msg("removeNode called for non-existent node")
return false
}

// Remove specific connection
removed := nodeConn.removeConnectionByChannel(c)
if !removed {
log.Debug().Caller().Uint64("node.id", id.Uint64()).Msg("RemoveNode: channel not found because connection already removed or invalid")
nlog.Debug().Caller().Msg("removeNode: channel not found, connection already removed or invalid")
return false
}

// Check if node has any remaining active connections
if nodeConn.hasActiveConnections() {
log.Debug().Caller().Uint64("node.id", id.Uint64()).
nlog.Debug().Caller().
Int("active.connections", nodeConn.getActiveConnectionCount()).
Msg("Node connection removed but keeping online because other connections remain")
Msg("node connection removed but keeping online, other connections remain")

return true // Node still has active connections
}

// No active connections - keep the node entry alive for rapid reconnections
// The node will get a fresh full map when it reconnects
log.Debug().Caller().Uint64("node.id", id.Uint64()).Msg("Node disconnected from batcher because all connections removed, keeping entry for rapid reconnection")
b.connected.Store(id, ptr.To(time.Now()))
nlog.Debug().Caller().Msg("node disconnected from batcher, keeping entry for rapid reconnection")
b.connected.Store(id, new(time.Now()))

return false
}
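Both AddNode and RemoveNode now build a per-call zerolog sub-logger once via `log.With()` instead of repeating the `node.id` field at every call site. A small sketch of that zerolog pattern; the field name and values are illustrative:

```go
package main

import (
	"os"

	"github.com/rs/zerolog"
	"github.com/rs/zerolog/log"
)

func main() {
	log.Logger = zerolog.New(os.Stdout).With().Timestamp().Logger()

	nodeID := uint64(42)

	// Build a sub-logger that carries the node ID on every event,
	// instead of repeating the field at each log call.
	nlog := log.With().Uint64("node.id", nodeID).Logger()

	nlog.Debug().Msg("node connection established in batcher")
	nlog.Error().Msg("initial map generation failed")
}
```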
@@ -196,11 +210,13 @@ func (b *LockFreeBatcher) doWork() {
}

func (b *LockFreeBatcher) worker(workerID int) {
wlog := log.With().Int(zf.WorkerID, workerID).Logger()

for {
select {
case w, ok := <-b.workCh:
if !ok {
log.Debug().Int("worker.id", workerID).Msgf("worker channel closing, shutting down worker %d", workerID)
wlog.Debug().Msg("worker channel closing, shutting down")
return
}

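The worker above loops over a `select` that drains a work channel until it is closed or a done signal arrives. A compact sketch of that shutdown-aware worker loop; the job type and channel names are assumptions, not headscale's real types:

```go
package main

import (
	"fmt"
	"sync"
)

// worker consumes jobs until workCh is closed or done is signalled,
// mirroring the two exit paths in LockFreeBatcher.worker.
func worker(id int, workCh <-chan string, done <-chan struct{}, wg *sync.WaitGroup) {
	defer wg.Done()

	for {
		select {
		case w, ok := <-workCh:
			if !ok {
				fmt.Printf("worker %d: work channel closed, shutting down\n", id)
				return
			}

			fmt.Printf("worker %d: processing %q\n", id, w)
		case <-done:
			fmt.Printf("worker %d: done signalled, exiting\n", id)
			return
		}
	}
}

func main() {
	workCh := make(chan string, 4)
	done := make(chan struct{})

	var wg sync.WaitGroup
	for i := range 2 {
		wg.Add(1)
		go worker(i, workCh, done, &wg)
	}

	workCh <- "map response for node 1"
	workCh <- "map response for node 2"
	close(workCh) // triggers the graceful-shutdown path
	wg.Wait()
}
```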
@@ -212,29 +228,29 @@ func (b *LockFreeBatcher) worker(workerID int) {
|
||||
// This is used for synchronous map generation.
|
||||
if w.resultCh != nil {
|
||||
var result workResult
|
||||
|
||||
if nc, exists := b.nodes.Load(w.nodeID); exists {
|
||||
var err error
|
||||
|
||||
result.mapResponse, err = generateMapResponse(nc, b.mapper, w.c)
|
||||
|
||||
result.err = err
|
||||
if result.err != nil {
|
||||
b.workErrors.Add(1)
|
||||
log.Error().Err(result.err).
|
||||
Int("worker.id", workerID).
|
||||
Uint64("node.id", w.nodeID.Uint64()).
|
||||
Str("reason", w.c.Reason).
|
||||
wlog.Error().Err(result.err).
|
||||
Uint64(zf.NodeID, w.nodeID.Uint64()).
|
||||
Str(zf.Reason, w.c.Reason).
|
||||
Msg("failed to generate map response for synchronous work")
|
||||
} else if result.mapResponse != nil {
|
||||
// Update peer tracking for synchronous responses too
|
||||
nc.updateSentPeers(result.mapResponse)
|
||||
}
|
||||
} else {
|
||||
result.err = fmt.Errorf("node %d not found", w.nodeID)
|
||||
result.err = fmt.Errorf("%w: %d", ErrNodeNotFoundMapper, w.nodeID)
|
||||
|
||||
b.workErrors.Add(1)
|
||||
log.Error().Err(result.err).
|
||||
Int("worker.id", workerID).
|
||||
Uint64("node.id", w.nodeID.Uint64()).
|
||||
wlog.Error().Err(result.err).
|
||||
Uint64(zf.NodeID, w.nodeID.Uint64()).
|
||||
Msg("node not found for synchronous work")
|
||||
}
|
||||
|
||||
@@ -257,15 +273,14 @@ func (b *LockFreeBatcher) worker(workerID int) {
|
||||
err := nc.change(w.c)
|
||||
if err != nil {
|
||||
b.workErrors.Add(1)
|
||||
log.Error().Err(err).
|
||||
Int("worker.id", workerID).
|
||||
Uint64("node.id", w.nodeID.Uint64()).
|
||||
Str("reason", w.c.Reason).
|
||||
wlog.Error().Err(err).
|
||||
Uint64(zf.NodeID, w.nodeID.Uint64()).
|
||||
Str(zf.Reason, w.c.Reason).
|
||||
Msg("failed to apply change")
|
||||
}
|
||||
}
|
||||
case <-b.done:
|
||||
log.Debug().Int("worker.id", workerID).Msg("batcher shutting down, exiting worker")
|
||||
wlog.Debug().Msg("batcher shutting down, exiting worker")
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -310,8 +325,8 @@ func (b *LockFreeBatcher) addToBatch(changes ...change.Change) {
|
||||
if _, existed := b.nodes.LoadAndDelete(removedID); existed {
|
||||
b.totalNodes.Add(-1)
|
||||
log.Debug().
|
||||
Uint64("node.id", removedID.Uint64()).
|
||||
Msg("Removed deleted node from batcher")
|
||||
Uint64(zf.NodeID, removedID.Uint64()).
|
||||
Msg("removed deleted node from batcher")
|
||||
}
|
||||
|
||||
b.connected.Delete(removedID)
|
||||
@@ -398,14 +413,15 @@ func (b *LockFreeBatcher) cleanupOfflineNodes() {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
// Clean up the identified nodes
|
||||
for _, nodeID := range nodesToCleanup {
|
||||
log.Info().Uint64("node.id", nodeID.Uint64()).
|
||||
log.Info().Uint64(zf.NodeID, nodeID.Uint64()).
|
||||
Dur("offline_duration", cleanupThreshold).
|
||||
Msg("Cleaning up node that has been offline for too long")
|
||||
Msg("cleaning up node that has been offline for too long")
|
||||
|
||||
b.nodes.Delete(nodeID)
|
||||
b.connected.Delete(nodeID)
|
||||
@@ -413,8 +429,8 @@ func (b *LockFreeBatcher) cleanupOfflineNodes() {
|
||||
}
|
||||
|
||||
if len(nodesToCleanup) > 0 {
|
||||
log.Info().Int("cleaned_nodes", len(nodesToCleanup)).
|
||||
Msg("Completed cleanup of long-offline nodes")
|
||||
log.Info().Int(zf.CleanedNodes, len(nodesToCleanup)).
|
||||
Msg("completed cleanup of long-offline nodes")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -450,6 +466,7 @@ func (b *LockFreeBatcher) ConnectedMap() *xsync.Map[types.NodeID, bool] {
|
||||
if nodeConn.hasActiveConnections() {
|
||||
ret.Store(id, true)
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
@@ -465,6 +482,7 @@ func (b *LockFreeBatcher) ConnectedMap() *xsync.Map[types.NodeID, bool] {
|
||||
ret.Store(id, false)
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
@@ -484,7 +502,7 @@ func (b *LockFreeBatcher) MapResponseFromChange(id types.NodeID, ch change.Chang
|
||||
case result := <-resultCh:
|
||||
return result.mapResponse, result.err
|
||||
case <-b.done:
|
||||
return nil, fmt.Errorf("batcher shutting down while generating map response for node %d", id)
|
||||
return nil, fmt.Errorf("%w while generating map response for node %d", ErrBatcherShuttingDown, id)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -502,6 +520,7 @@ type connectionEntry struct {
type multiChannelNodeConn struct {
id types.NodeID
mapper *mapper
log zerolog.Logger

mutex sync.RWMutex
connections []*connectionEntry
@@ -518,8 +537,9 @@ type multiChannelNodeConn struct {
// generateConnectionID generates a unique connection identifier.
func generateConnectionID() string {
bytes := make([]byte, 8)
rand.Read(bytes)
return fmt.Sprintf("%x", bytes)
_, _ = rand.Read(bytes)

return hex.EncodeToString(bytes)
}

// newMultiChannelNodeConn creates a new multi-channel node connection.
@@ -528,6 +548,7 @@ func newMultiChannelNodeConn(id types.NodeID, mapper *mapper) *multiChannelNodeC
id: id,
mapper: mapper,
lastSentPeers: xsync.NewMap[tailcfg.NodeID, struct{}](),
log: log.With().Uint64(zf.NodeID, id.Uint64()).Logger(),
}
}

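generateConnectionID now builds the identifier with `hex.EncodeToString` over bytes from `crypto/rand`. A standalone sketch of the same technique; the 8-byte length matches the hunk, but here the read error is surfaced instead of discarded:

```go
package main

import (
	"crypto/rand"
	"encoding/hex"
	"fmt"
)

// newID returns a random 16-character hex identifier from 8 random bytes.
func newID() (string, error) {
	buf := make([]byte, 8)
	if _, err := rand.Read(buf); err != nil {
		// crypto/rand failures are rare but worth surfacing.
		return "", fmt.Errorf("reading random bytes: %w", err)
	}

	return hex.EncodeToString(buf), nil
}

func main() {
	id, err := newID()
	if err != nil {
		panic(err)
	}

	fmt.Println("connection id:", id)
}
```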
@@ -546,18 +567,21 @@ func (mc *multiChannelNodeConn) close() {
|
||||
// addConnection adds a new connection.
|
||||
func (mc *multiChannelNodeConn) addConnection(entry *connectionEntry) {
|
||||
mutexWaitStart := time.Now()
|
||||
log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", entry.c)).Str("conn.id", entry.id).
|
||||
|
||||
mc.log.Debug().Caller().Str(zf.Chan, fmt.Sprintf("%p", entry.c)).Str(zf.ConnID, entry.id).
|
||||
Msg("addConnection: waiting for mutex - POTENTIAL CONTENTION POINT")
|
||||
|
||||
mc.mutex.Lock()
|
||||
|
||||
mutexWaitDur := time.Since(mutexWaitStart)
|
||||
|
||||
defer mc.mutex.Unlock()
|
||||
|
||||
mc.connections = append(mc.connections, entry)
|
||||
log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", entry.c)).Str("conn.id", entry.id).
|
||||
mc.log.Debug().Caller().Str(zf.Chan, fmt.Sprintf("%p", entry.c)).Str(zf.ConnID, entry.id).
|
||||
Int("total_connections", len(mc.connections)).
|
||||
Dur("mutex_wait_time", mutexWaitDur).
|
||||
Msg("Successfully added connection after mutex wait")
|
||||
Msg("successfully added connection after mutex wait")
|
||||
}
|
||||
|
||||
// removeConnectionByChannel removes a connection by matching channel pointer.
|
||||
@@ -569,12 +593,14 @@ func (mc *multiChannelNodeConn) removeConnectionByChannel(c chan<- *tailcfg.MapR
|
||||
if entry.c == c {
|
||||
// Remove this connection
|
||||
mc.connections = append(mc.connections[:i], mc.connections[i+1:]...)
|
||||
log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", c)).
|
||||
mc.log.Debug().Caller().Str(zf.Chan, fmt.Sprintf("%p", c)).
|
||||
Int("remaining_connections", len(mc.connections)).
|
||||
Msg("Successfully removed connection")
|
||||
Msg("successfully removed connection")
|
||||
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -606,36 +632,41 @@ func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error {
|
||||
if len(mc.connections) == 0 {
|
||||
// During rapid reconnection, nodes may temporarily have no active connections
|
||||
// This is not an error - the node will receive a full map when it reconnects
|
||||
log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).
|
||||
mc.log.Debug().Caller().
|
||||
Msg("send: skipping send to node with no active connections (likely rapid reconnection)")
|
||||
|
||||
return nil // Return success instead of error
|
||||
}
|
||||
|
||||
log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).
|
||||
mc.log.Debug().Caller().
|
||||
Int("total_connections", len(mc.connections)).
|
||||
Msg("send: broadcasting to all connections")
|
||||
|
||||
var lastErr error
|
||||
|
||||
successCount := 0
|
||||
|
||||
var failedConnections []int // Track failed connections for removal
|
||||
|
||||
// Send to all connections
|
||||
for i, conn := range mc.connections {
|
||||
log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", conn.c)).
|
||||
Str("conn.id", conn.id).Int("connection_index", i).
|
||||
mc.log.Debug().Caller().Str(zf.Chan, fmt.Sprintf("%p", conn.c)).
|
||||
Str(zf.ConnID, conn.id).Int(zf.ConnectionIndex, i).
|
||||
Msg("send: attempting to send to connection")
|
||||
|
||||
if err := conn.send(data); err != nil {
|
||||
err := conn.send(data)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
|
||||
failedConnections = append(failedConnections, i)
|
||||
log.Warn().Err(err).
|
||||
Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", conn.c)).
|
||||
Str("conn.id", conn.id).Int("connection_index", i).
|
||||
mc.log.Warn().Err(err).Str(zf.Chan, fmt.Sprintf("%p", conn.c)).
|
||||
Str(zf.ConnID, conn.id).Int(zf.ConnectionIndex, i).
|
||||
Msg("send: connection send failed")
|
||||
} else {
|
||||
successCount++
|
||||
log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", conn.c)).
|
||||
Str("conn.id", conn.id).Int("connection_index", i).
|
||||
|
||||
mc.log.Debug().Caller().Str(zf.Chan, fmt.Sprintf("%p", conn.c)).
|
||||
Str(zf.ConnID, conn.id).Int(zf.ConnectionIndex, i).
|
||||
Msg("send: successfully sent to connection")
|
||||
}
|
||||
}
|
||||
@@ -643,15 +674,15 @@ func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error {
|
||||
// Remove failed connections (in reverse order to maintain indices)
|
||||
for i := len(failedConnections) - 1; i >= 0; i-- {
|
||||
idx := failedConnections[i]
|
||||
log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).
|
||||
Str("conn.id", mc.connections[idx].id).
|
||||
mc.log.Debug().Caller().
|
||||
Str(zf.ConnID, mc.connections[idx].id).
|
||||
Msg("send: removing failed connection")
|
||||
mc.connections = append(mc.connections[:idx], mc.connections[idx+1:]...)
|
||||
}
|
||||
|
||||
mc.updateCount.Add(1)
|
||||
|
||||
log.Debug().Uint64("node.id", mc.id.Uint64()).
|
||||
mc.log.Debug().
|
||||
Int("successful_sends", successCount).
|
||||
Int("failed_connections", len(failedConnections)).
|
||||
Int("remaining_connections", len(mc.connections)).
|
||||
@@ -688,7 +719,7 @@ func (entry *connectionEntry) send(data *tailcfg.MapResponse) error {
case <-time.After(50 * time.Millisecond):
// Connection is likely stale - client isn't reading from channel
// This catches the case where Docker containers are killed but channels remain open
return fmt.Errorf("connection %s: timeout sending to channel (likely stale connection)", entry.id)
return fmt.Errorf("connection %s: %w", entry.id, ErrConnectionSendTimeout)
}
}

@@ -798,6 +829,7 @@ func (b *LockFreeBatcher) Debug() map[types.NodeID]DebugNodeInfo {
|
||||
Connected: connected,
|
||||
ActiveConnections: activeConnCount,
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
@@ -812,6 +844,7 @@ func (b *LockFreeBatcher) Debug() map[types.NodeID]DebugNodeInfo {
|
||||
ActiveConnections: 0,
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
|
||||
@@ -35,6 +35,7 @@ type batcherTestCase struct {
|
||||
// that would normally be sent by poll.go in production.
|
||||
type testBatcherWrapper struct {
|
||||
Batcher
|
||||
|
||||
state *state.State
|
||||
}
|
||||
|
||||
@@ -80,12 +81,7 @@ func (t *testBatcherWrapper) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapRe
|
||||
}
|
||||
|
||||
// Finally remove from the real batcher
|
||||
removed := t.Batcher.RemoveNode(id, c)
|
||||
if !removed {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
return t.Batcher.RemoveNode(id, c)
|
||||
}
|
||||
|
||||
// wrapBatcherForTest wraps a batcher with test-specific behavior.
|
||||
@@ -129,8 +125,6 @@ const (
|
||||
SMALL_BUFFER_SIZE = 3
|
||||
TINY_BUFFER_SIZE = 1 // For maximum contention
|
||||
LARGE_BUFFER_SIZE = 200
|
||||
|
||||
reservedResponseHeaderSize = 4
|
||||
)
|
||||
|
||||
// TestData contains all test entities created for a test scenario.
|
||||
@@ -241,8 +235,8 @@ func setupBatcherWithTestData(
|
||||
}
|
||||
|
||||
derpMap, err := derp.GetDERPMap(cfg.DERP)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, derpMap)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, derpMap)
|
||||
|
||||
state.SetDERPMap(derpMap)
|
||||
|
||||
@@ -319,6 +313,8 @@ func (ut *updateTracker) recordUpdate(nodeID types.NodeID, updateSize int) {
|
||||
}
|
||||
|
||||
// getStats returns a copy of the statistics for a node.
|
||||
//
|
||||
//nolint:unused
|
||||
func (ut *updateTracker) getStats(nodeID types.NodeID) UpdateStats {
|
||||
ut.mu.RLock()
|
||||
defer ut.mu.RUnlock()
|
||||
@@ -386,16 +382,14 @@ type UpdateInfo struct {
|
||||
}
|
||||
|
||||
// parseUpdateAndAnalyze parses an update and returns detailed information.
|
||||
func parseUpdateAndAnalyze(resp *tailcfg.MapResponse) (UpdateInfo, error) {
|
||||
info := UpdateInfo{
|
||||
func parseUpdateAndAnalyze(resp *tailcfg.MapResponse) UpdateInfo {
|
||||
return UpdateInfo{
|
||||
PeerCount: len(resp.Peers),
|
||||
PatchCount: len(resp.PeersChangedPatch),
|
||||
IsFull: len(resp.Peers) > 0,
|
||||
IsPatch: len(resp.PeersChangedPatch) > 0,
|
||||
IsDERP: resp.DERPMap != nil,
|
||||
}
|
||||
|
||||
return info, nil
|
||||
}
|
||||
|
||||
// start begins consuming updates from the node's channel and tracking stats.
|
||||
@@ -417,7 +411,8 @@ func (n *node) start() {
|
||||
atomic.AddInt64(&n.updateCount, 1)
|
||||
|
||||
// Parse update and track detailed stats
|
||||
if info, err := parseUpdateAndAnalyze(data); err == nil {
|
||||
info := parseUpdateAndAnalyze(data)
|
||||
{
|
||||
// Track update types
|
||||
if info.IsFull {
|
||||
atomic.AddInt64(&n.fullCount, 1)
|
||||
@@ -548,7 +543,7 @@ func TestEnhancedTrackingWithBatcher(t *testing.T) {
|
||||
testNode.start()
|
||||
|
||||
// Connect the node to the batcher
|
||||
batcher.AddNode(testNode.n.ID, testNode.ch, tailcfg.CapabilityVersion(100))
|
||||
_ = batcher.AddNode(testNode.n.ID, testNode.ch, tailcfg.CapabilityVersion(100))
|
||||
|
||||
// Wait for connection to be established
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
@@ -657,7 +652,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
|
||||
|
||||
for i := range allNodes {
|
||||
node := &allNodes[i]
|
||||
batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
|
||||
_ = batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
|
||||
|
||||
// Issue full update after each join to ensure connectivity
|
||||
batcher.AddWork(change.FullUpdate())
|
||||
@@ -676,6 +671,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
|
||||
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
connectedCount := 0
|
||||
|
||||
for i := range allNodes {
|
||||
node := &allNodes[i]
|
||||
|
||||
@@ -693,6 +689,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
|
||||
}, 5*time.Minute, 5*time.Second, "waiting for full connectivity")
|
||||
|
||||
t.Logf("✅ All nodes achieved full connectivity!")
|
||||
|
||||
totalTime := time.Since(startTime)
|
||||
|
||||
// Disconnect all nodes
|
||||
@@ -820,11 +817,11 @@ func TestBatcherBasicOperations(t *testing.T) {
|
||||
defer cleanup()
|
||||
|
||||
batcher := testData.Batcher
|
||||
tn := testData.Nodes[0]
|
||||
tn2 := testData.Nodes[1]
|
||||
tn := &testData.Nodes[0]
|
||||
tn2 := &testData.Nodes[1]
|
||||
|
||||
// Test AddNode with real node ID
|
||||
batcher.AddNode(tn.n.ID, tn.ch, 100)
|
||||
_ = batcher.AddNode(tn.n.ID, tn.ch, 100)
|
||||
|
||||
if !batcher.IsConnected(tn.n.ID) {
|
||||
t.Error("Node should be connected after AddNode")
|
||||
@@ -842,10 +839,10 @@ func TestBatcherBasicOperations(t *testing.T) {
|
||||
}
|
||||
|
||||
// Drain any initial messages from first node
|
||||
drainChannelTimeout(tn.ch, "first node before second", 100*time.Millisecond)
|
||||
drainChannelTimeout(tn.ch, 100*time.Millisecond)
|
||||
|
||||
// Add the second node and verify update message
|
||||
batcher.AddNode(tn2.n.ID, tn2.ch, 100)
|
||||
_ = batcher.AddNode(tn2.n.ID, tn2.ch, 100)
|
||||
assert.True(t, batcher.IsConnected(tn2.n.ID))
|
||||
|
||||
// First node should get an update that second node has connected.
|
||||
@@ -911,18 +908,14 @@ func TestBatcherBasicOperations(t *testing.T) {
|
||||
}
}

func drainChannelTimeout(ch <-chan *tailcfg.MapResponse, name string, timeout time.Duration) {
count := 0

func drainChannelTimeout(ch <-chan *tailcfg.MapResponse, timeout time.Duration) {
timer := time.NewTimer(timeout)
defer timer.Stop()

for {
select {
case data := <-ch:
count++
// Optional: add debug output if needed
_ = data
case <-ch:
// Drain message
case <-timer.C:
return
}
@@ -1050,7 +1043,7 @@ func TestBatcherWorkQueueBatching(t *testing.T) {
|
||||
testNodes := testData.Nodes
|
||||
|
||||
ch := make(chan *tailcfg.MapResponse, 10)
|
||||
batcher.AddNode(testNodes[0].n.ID, ch, tailcfg.CapabilityVersion(100))
|
||||
_ = batcher.AddNode(testNodes[0].n.ID, ch, tailcfg.CapabilityVersion(100))
|
||||
|
||||
// Track update content for validation
|
||||
var receivedUpdates []*tailcfg.MapResponse
|
||||
@@ -1131,6 +1124,8 @@ func TestBatcherWorkQueueBatching(t *testing.T) {
|
||||
// even when real node updates are being processed, ensuring no race conditions
|
||||
// occur during channel replacement with actual workload.
|
||||
func XTestBatcherChannelClosingRace(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
for _, batcherFunc := range allBatcherFunctions {
|
||||
t.Run(batcherFunc.name, func(t *testing.T) {
|
||||
// Create test environment with real database and nodes
|
||||
@@ -1138,7 +1133,7 @@ func XTestBatcherChannelClosingRace(t *testing.T) {
|
||||
defer cleanup()
|
||||
|
||||
batcher := testData.Batcher
|
||||
testNode := testData.Nodes[0]
|
||||
testNode := &testData.Nodes[0]
|
||||
|
||||
var (
|
||||
channelIssues int
|
||||
@@ -1154,7 +1149,7 @@ func XTestBatcherChannelClosingRace(t *testing.T) {
|
||||
ch1 := make(chan *tailcfg.MapResponse, 1)
|
||||
|
||||
wg.Go(func() {
|
||||
batcher.AddNode(testNode.n.ID, ch1, tailcfg.CapabilityVersion(100))
|
||||
_ = batcher.AddNode(testNode.n.ID, ch1, tailcfg.CapabilityVersion(100))
|
||||
})
|
||||
|
||||
// Add real work during connection chaos
|
||||
@@ -1167,7 +1162,8 @@ func XTestBatcherChannelClosingRace(t *testing.T) {
|
||||
|
||||
wg.Go(func() {
|
||||
runtime.Gosched() // Yield to introduce timing variability
|
||||
batcher.AddNode(testNode.n.ID, ch2, tailcfg.CapabilityVersion(100))
|
||||
|
||||
_ = batcher.AddNode(testNode.n.ID, ch2, tailcfg.CapabilityVersion(100))
|
||||
})
|
||||
|
||||
// Remove second connection
|
||||
@@ -1231,7 +1227,7 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
|
||||
defer cleanup()
|
||||
|
||||
batcher := testData.Batcher
|
||||
testNode := testData.Nodes[0]
|
||||
testNode := &testData.Nodes[0]
|
||||
|
||||
var (
|
||||
panics int
|
||||
@@ -1258,7 +1254,7 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
|
||||
ch := make(chan *tailcfg.MapResponse, 5)
|
||||
|
||||
// Add node and immediately queue real work
|
||||
batcher.AddNode(testNode.n.ID, ch, tailcfg.CapabilityVersion(100))
|
||||
_ = batcher.AddNode(testNode.n.ID, ch, tailcfg.CapabilityVersion(100))
|
||||
batcher.AddWork(change.DERPMap())
|
||||
|
||||
// Consumer goroutine to validate data and detect channel issues
|
||||
@@ -1308,6 +1304,7 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
|
||||
for range i % 3 {
|
||||
runtime.Gosched() // Introduce timing variability
|
||||
}
|
||||
|
||||
batcher.RemoveNode(testNode.n.ID, ch)
|
||||
|
||||
// Yield to allow workers to process and close channels
|
||||
@@ -1350,6 +1347,8 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
|
||||
// real node data. The test validates that stable clients continue to function
|
||||
// normally and receive proper updates despite the connection churn from other clients,
|
||||
// ensuring system stability under concurrent load.
|
||||
//
|
||||
//nolint:gocyclo // complex concurrent test scenario
|
||||
func TestBatcherConcurrentClients(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping concurrent client test in short mode")
|
||||
@@ -1377,10 +1376,11 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
||||
stableNodes := allNodes[:len(allNodes)/2] // Use first half as stable
|
||||
stableChannels := make(map[types.NodeID]chan *tailcfg.MapResponse)
|
||||
|
||||
for _, node := range stableNodes {
|
||||
for i := range stableNodes {
|
||||
node := &stableNodes[i]
|
||||
ch := make(chan *tailcfg.MapResponse, NORMAL_BUFFER_SIZE)
|
||||
stableChannels[node.n.ID] = ch
|
||||
batcher.AddNode(node.n.ID, ch, tailcfg.CapabilityVersion(100))
|
||||
_ = batcher.AddNode(node.n.ID, ch, tailcfg.CapabilityVersion(100))
|
||||
|
||||
// Monitor updates for each stable client
|
||||
go func(nodeID types.NodeID, channel chan *tailcfg.MapResponse) {
|
||||
@@ -1391,6 +1391,7 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
||||
// Channel was closed, exit gracefully
|
||||
return
|
||||
}
|
||||
|
||||
if valid, reason := validateUpdateContent(data); valid {
|
||||
tracker.recordUpdate(
|
||||
nodeID,
|
||||
@@ -1427,7 +1428,9 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
||||
|
||||
// Connection churn cycles - rapidly connect/disconnect to test concurrency safety
|
||||
for i := range numCycles {
|
||||
for _, node := range churningNodes {
|
||||
for j := range churningNodes {
|
||||
node := &churningNodes[j]
|
||||
|
||||
wg.Add(2)
|
||||
|
||||
// Connect churning node
|
||||
@@ -1448,10 +1451,12 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
||||
ch := make(chan *tailcfg.MapResponse, SMALL_BUFFER_SIZE)
|
||||
|
||||
churningChannelsMutex.Lock()
|
||||
|
||||
churningChannels[nodeID] = ch
|
||||
|
||||
churningChannelsMutex.Unlock()
|
||||
|
||||
batcher.AddNode(nodeID, ch, tailcfg.CapabilityVersion(100))
|
||||
_ = batcher.AddNode(nodeID, ch, tailcfg.CapabilityVersion(100))
|
||||
|
||||
// Consume updates to prevent blocking
|
||||
go func() {
|
||||
@@ -1462,6 +1467,7 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
||||
// Channel was closed, exit gracefully
|
||||
return
|
||||
}
|
||||
|
||||
if valid, _ := validateUpdateContent(data); valid {
|
||||
tracker.recordUpdate(
|
||||
nodeID,
|
||||
@@ -1494,6 +1500,7 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
||||
for range i % 5 {
|
||||
runtime.Gosched() // Introduce timing variability
|
||||
}
|
||||
|
||||
churningChannelsMutex.Lock()
|
||||
|
||||
ch, exists := churningChannels[nodeID]
|
||||
@@ -1519,7 +1526,7 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
||||
|
||||
if i%7 == 0 && len(allNodes) > 0 {
|
||||
// Node-specific changes using real nodes
|
||||
node := allNodes[i%len(allNodes)]
|
||||
node := &allNodes[i%len(allNodes)]
|
||||
// Use a valid expiry time for testing since test nodes don't have expiry set
|
||||
testExpiry := time.Now().Add(24 * time.Hour)
|
||||
batcher.AddWork(change.KeyExpiryFor(node.n.ID, testExpiry))
|
||||
@@ -1567,7 +1574,8 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
||||
t.Logf("Work generated: %d DERP + %d Full + %d KeyExpiry = %d total AddWork calls",
|
||||
expectedDerpUpdates, expectedFullUpdates, expectedKeyUpdates, totalGeneratedWork)
|
||||
|
||||
for _, node := range stableNodes {
|
||||
for i := range stableNodes {
|
||||
node := &stableNodes[i]
|
||||
if stats, exists := allStats[node.n.ID]; exists {
|
||||
stableUpdateCount += stats.TotalUpdates
|
||||
t.Logf("Stable node %d: %d updates",
|
||||
@@ -1580,7 +1588,8 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
for _, node := range churningNodes {
|
||||
for i := range churningNodes {
|
||||
node := &churningNodes[i]
|
||||
if stats, exists := allStats[node.n.ID]; exists {
|
||||
churningUpdateCount += stats.TotalUpdates
|
||||
}
|
||||
@@ -1605,7 +1614,8 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
||||
}
|
||||
|
||||
// Verify all stable clients are still functional
|
||||
for _, node := range stableNodes {
|
||||
for i := range stableNodes {
|
||||
node := &stableNodes[i]
|
||||
if !batcher.IsConnected(node.n.ID) {
|
||||
t.Errorf("Stable node %d lost connection during racing", node.n.ID)
|
||||
}
|
||||
@@ -1623,6 +1633,8 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
||||
// It validates that the system remains stable with no deadlocks, panics, or
|
||||
// missed updates under sustained high load. The test uses real node data to
|
||||
// generate authentic update scenarios and tracks comprehensive statistics.
|
||||
//
|
||||
//nolint:gocyclo,thelper // complex scalability test scenario
|
||||
func XTestBatcherScalability(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping scalability test in short mode")
|
||||
@@ -1651,7 +1663,7 @@ func XTestBatcherScalability(t *testing.T) {
|
||||
description string
|
||||
}
|
||||
|
||||
var testCases []testCase
|
||||
testCases := make([]testCase, 0, len(chaosTypes)*len(bufferSizes)*len(cycles)*len(nodes))
|
||||
|
||||
// Generate all combinations of the test matrix
|
||||
for _, nodeCount := range nodes {
|
||||
@@ -1762,7 +1774,8 @@ func XTestBatcherScalability(t *testing.T) {
|
||||
|
||||
for i := range testNodes {
|
||||
node := &testNodes[i]
|
||||
batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
|
||||
_ = batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
|
||||
|
||||
connectedNodesMutex.Lock()
|
||||
|
||||
connectedNodes[node.n.ID] = true
|
||||
@@ -1824,7 +1837,8 @@ func XTestBatcherScalability(t *testing.T) {
|
||||
}
|
||||
|
||||
// Connection/disconnection cycles for subset of nodes
|
||||
for i, node := range chaosNodes {
|
||||
for i := range chaosNodes {
|
||||
node := &chaosNodes[i]
|
||||
// Only add work if this is connection chaos or mixed
|
||||
if tc.chaosType == "connection" || tc.chaosType == "mixed" {
|
||||
wg.Add(2)
|
||||
@@ -1878,6 +1892,7 @@ func XTestBatcherScalability(t *testing.T) {
|
||||
channel,
|
||||
tailcfg.CapabilityVersion(100),
|
||||
)
|
||||
|
||||
connectedNodesMutex.Lock()
|
||||
|
||||
connectedNodes[nodeID] = true
|
||||
@@ -2138,8 +2153,9 @@ func TestBatcherFullPeerUpdates(t *testing.T) {
|
||||
t.Logf("Created %d nodes in database", len(allNodes))
|
||||
|
||||
// Connect nodes one at a time and wait for each to be connected
|
||||
for i, node := range allNodes {
|
||||
batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
|
||||
for i := range allNodes {
|
||||
node := &allNodes[i]
|
||||
_ = batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
|
||||
t.Logf("Connected node %d (ID: %d)", i, node.n.ID)
|
||||
|
||||
// Wait for node to be connected
|
||||
@@ -2157,7 +2173,8 @@ func TestBatcherFullPeerUpdates(t *testing.T) {
|
||||
}, 5*time.Second, 50*time.Millisecond, "waiting for all nodes to connect")
|
||||
|
||||
// Check how many peers each node should see
|
||||
for i, node := range allNodes {
|
||||
for i := range allNodes {
|
||||
node := &allNodes[i]
|
||||
peers := testData.State.ListPeers(node.n.ID)
|
||||
t.Logf("Node %d should see %d peers from state", i, peers.Len())
|
||||
}
|
||||
@@ -2286,7 +2303,10 @@ func TestBatcherRapidReconnection(t *testing.T) {
|
||||
|
||||
// Phase 1: Connect all nodes initially
|
||||
t.Logf("Phase 1: Connecting all nodes...")
|
||||
for i, node := range allNodes {
|
||||
|
||||
for i := range allNodes {
|
||||
node := &allNodes[i]
|
||||
|
||||
err := batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add node %d: %v", i, err)
|
||||
@@ -2302,16 +2322,21 @@ func TestBatcherRapidReconnection(t *testing.T) {
|
||||
|
||||
// Phase 2: Rapid disconnect ALL nodes (simulating nodes going down)
|
||||
t.Logf("Phase 2: Rapid disconnect all nodes...")
|
||||
for i, node := range allNodes {
|
||||
|
||||
for i := range allNodes {
|
||||
node := &allNodes[i]
|
||||
removed := batcher.RemoveNode(node.n.ID, node.ch)
|
||||
t.Logf("Node %d RemoveNode result: %t", i, removed)
|
||||
}
|
||||
|
||||
// Phase 3: Rapid reconnect with NEW channels (simulating nodes coming back up)
|
||||
t.Logf("Phase 3: Rapid reconnect with new channels...")
|
||||
|
||||
newChannels := make([]chan *tailcfg.MapResponse, len(allNodes))
|
||||
for i, node := range allNodes {
|
||||
for i := range allNodes {
|
||||
node := &allNodes[i]
|
||||
newChannels[i] = make(chan *tailcfg.MapResponse, 10)
|
||||
|
||||
err := batcher.AddNode(node.n.ID, newChannels[i], tailcfg.CapabilityVersion(100))
|
||||
if err != nil {
|
||||
t.Errorf("Failed to reconnect node %d: %v", i, err)
|
||||
@@ -2334,7 +2359,8 @@ func TestBatcherRapidReconnection(t *testing.T) {
|
||||
debugInfo := debugBatcher.Debug()
|
||||
disconnectedCount := 0
|
||||
|
||||
for i, node := range allNodes {
|
||||
for i := range allNodes {
|
||||
node := &allNodes[i]
|
||||
if info, exists := debugInfo[node.n.ID]; exists {
|
||||
t.Logf("Node %d (ID %d): debug info = %+v", i, node.n.ID, info)
|
||||
|
||||
@@ -2342,11 +2368,13 @@ func TestBatcherRapidReconnection(t *testing.T) {
|
||||
if infoMap, ok := info.(map[string]any); ok {
|
||||
if connected, ok := infoMap["connected"].(bool); ok && !connected {
|
||||
disconnectedCount++
|
||||
|
||||
t.Logf("BUG REPRODUCED: Node %d shows as disconnected in debug but should be connected", i)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
disconnectedCount++
|
||||
|
||||
t.Logf("Node %d missing from debug info entirely", i)
|
||||
}
|
||||
|
||||
@@ -2381,6 +2409,7 @@ func TestBatcherRapidReconnection(t *testing.T) {
|
||||
case update := <-newChannels[i]:
|
||||
if update != nil {
|
||||
receivedCount++
|
||||
|
||||
t.Logf("Node %d received update successfully", i)
|
||||
}
|
||||
case <-timeout:
|
||||
@@ -2399,6 +2428,7 @@ func TestBatcherRapidReconnection(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
//nolint:gocyclo // complex multi-connection test scenario
|
||||
func TestBatcherMultiConnection(t *testing.T) {
|
||||
for _, batcherFunc := range allBatcherFunctions {
|
||||
t.Run(batcherFunc.name, func(t *testing.T) {
|
||||
@@ -2406,13 +2436,14 @@ func TestBatcherMultiConnection(t *testing.T) {
|
||||
defer cleanup()
|
||||
|
||||
batcher := testData.Batcher
|
||||
node1 := testData.Nodes[0]
|
||||
node2 := testData.Nodes[1]
|
||||
node1 := &testData.Nodes[0]
|
||||
node2 := &testData.Nodes[1]
|
||||
|
||||
t.Logf("=== MULTI-CONNECTION TEST ===")
|
||||
|
||||
// Phase 1: Connect first node with initial connection
|
||||
t.Logf("Phase 1: Connecting node 1 with first connection...")
|
||||
|
||||
err := batcher.AddNode(node1.n.ID, node1.ch, tailcfg.CapabilityVersion(100))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add node1: %v", err)
|
||||
@@ -2432,7 +2463,9 @@ func TestBatcherMultiConnection(t *testing.T) {
|
||||
|
||||
// Phase 2: Add second connection for node1 (multi-connection scenario)
|
||||
t.Logf("Phase 2: Adding second connection for node 1...")
|
||||
|
||||
secondChannel := make(chan *tailcfg.MapResponse, 10)
|
||||
|
||||
err = batcher.AddNode(node1.n.ID, secondChannel, tailcfg.CapabilityVersion(100))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add second connection for node1: %v", err)
|
||||
@@ -2443,7 +2476,9 @@ func TestBatcherMultiConnection(t *testing.T) {
|
||||
|
||||
// Phase 3: Add third connection for node1
|
||||
t.Logf("Phase 3: Adding third connection for node 1...")
|
||||
|
||||
thirdChannel := make(chan *tailcfg.MapResponse, 10)
|
||||
|
||||
err = batcher.AddNode(node1.n.ID, thirdChannel, tailcfg.CapabilityVersion(100))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add third connection for node1: %v", err)
|
||||
@@ -2454,6 +2489,7 @@ func TestBatcherMultiConnection(t *testing.T) {
|
||||
|
||||
// Phase 4: Verify debug status shows correct connection count
|
||||
t.Logf("Phase 4: Verifying debug status shows multiple connections...")
|
||||
|
||||
if debugBatcher, ok := batcher.(interface {
|
||||
Debug() map[types.NodeID]any
|
||||
}); ok {
|
||||
@@ -2461,6 +2497,7 @@ func TestBatcherMultiConnection(t *testing.T) {
|
||||
|
||||
if info, exists := debugInfo[node1.n.ID]; exists {
|
||||
t.Logf("Node1 debug info: %+v", info)
|
||||
|
||||
if infoMap, ok := info.(map[string]any); ok {
|
||||
if activeConnections, ok := infoMap["active_connections"].(int); ok {
|
||||
if activeConnections != 3 {
|
||||
@@ -2469,6 +2506,7 @@ func TestBatcherMultiConnection(t *testing.T) {
|
||||
t.Logf("SUCCESS: Node1 correctly shows 3 active connections")
|
||||
}
|
||||
}
|
||||
|
||||
if connected, ok := infoMap["connected"].(bool); ok && !connected {
|
||||
t.Errorf("Node1 should show as connected with 3 active connections")
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package mapper
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/netip"
|
||||
"sort"
|
||||
"time"
|
||||
@@ -36,6 +35,7 @@ const (
|
||||
// NewMapResponseBuilder creates a new builder with basic fields set.
|
||||
func (m *mapper) NewMapResponseBuilder(nodeID types.NodeID) *MapResponseBuilder {
|
||||
now := time.Now()
|
||||
|
||||
return &MapResponseBuilder{
|
||||
resp: &tailcfg.MapResponse{
|
||||
KeepAlive: false,
|
||||
@@ -69,7 +69,7 @@ func (b *MapResponseBuilder) WithCapabilityVersion(capVer tailcfg.CapabilityVers
|
||||
func (b *MapResponseBuilder) WithSelfNode() *MapResponseBuilder {
|
||||
nv, ok := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if !ok {
|
||||
b.addError(errors.New("node not found"))
|
||||
b.addError(ErrNodeNotFoundMapper)
|
||||
return b
|
||||
}
|
||||
|
||||
@@ -123,6 +123,7 @@ func (b *MapResponseBuilder) WithDebugConfig() *MapResponseBuilder {
|
||||
b.resp.Debug = &tailcfg.Debug{
|
||||
DisableLogTail: !b.mapper.cfg.LogTail.Enabled,
|
||||
}
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
@@ -130,7 +131,7 @@ func (b *MapResponseBuilder) WithDebugConfig() *MapResponseBuilder {
|
||||
func (b *MapResponseBuilder) WithSSHPolicy() *MapResponseBuilder {
|
||||
node, ok := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if !ok {
|
||||
b.addError(errors.New("node not found"))
|
||||
b.addError(ErrNodeNotFoundMapper)
|
||||
return b
|
||||
}
|
||||
|
||||
@@ -149,7 +150,7 @@ func (b *MapResponseBuilder) WithSSHPolicy() *MapResponseBuilder {
|
||||
func (b *MapResponseBuilder) WithDNSConfig() *MapResponseBuilder {
|
||||
node, ok := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if !ok {
|
||||
b.addError(errors.New("node not found"))
|
||||
b.addError(ErrNodeNotFoundMapper)
|
||||
return b
|
||||
}
|
||||
|
||||
@@ -162,7 +163,7 @@ func (b *MapResponseBuilder) WithDNSConfig() *MapResponseBuilder {
|
||||
func (b *MapResponseBuilder) WithUserProfiles(peers views.Slice[types.NodeView]) *MapResponseBuilder {
|
||||
node, ok := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if !ok {
|
||||
b.addError(errors.New("node not found"))
|
||||
b.addError(ErrNodeNotFoundMapper)
|
||||
return b
|
||||
}
|
||||
|
||||
@@ -175,7 +176,7 @@ func (b *MapResponseBuilder) WithUserProfiles(peers views.Slice[types.NodeView])
|
||||
func (b *MapResponseBuilder) WithPacketFilters() *MapResponseBuilder {
|
||||
node, ok := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if !ok {
|
||||
b.addError(errors.New("node not found"))
|
||||
b.addError(ErrNodeNotFoundMapper)
|
||||
return b
|
||||
}
|
||||
|
||||
@@ -229,7 +230,7 @@ func (b *MapResponseBuilder) WithPeerChanges(peers views.Slice[types.NodeView])
|
||||
func (b *MapResponseBuilder) buildTailPeers(peers views.Slice[types.NodeView]) ([]*tailcfg.Node, error) {
|
||||
node, ok := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if !ok {
|
||||
return nil, errors.New("node not found")
|
||||
return nil, ErrNodeNotFoundMapper
|
||||
}
|
||||
|
||||
// Get unreduced matchers for peer relationship determination.
|
||||
@@ -276,20 +277,22 @@ func (b *MapResponseBuilder) WithPeerChangedPatch(changes []*tailcfg.PeerChange)

// WithPeersRemoved adds removed peer IDs.
func (b *MapResponseBuilder) WithPeersRemoved(removedIDs ...types.NodeID) *MapResponseBuilder {
var tailscaleIDs []tailcfg.NodeID
tailscaleIDs := make([]tailcfg.NodeID, 0, len(removedIDs))
for _, id := range removedIDs {
tailscaleIDs = append(tailscaleIDs, id.NodeID())
}

b.resp.PeersRemoved = tailscaleIDs

return b
}

// Build finalizes the response and returns marshaled bytes
// Build finalizes the response and returns marshaled bytes.
func (b *MapResponseBuilder) Build() (*tailcfg.MapResponse, error) {
if len(b.errs) > 0 {
return nil, multierr.New(b.errs...)
}

if debugDumpMapResponsePath != "" {
writeDebugMapResponse(b.resp, b.debugType, b.nodeID)
}

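The builder above records failures as it goes and only surfaces them, joined together, when `Build` is called. A minimal sketch of that error-accumulating builder shape, using the standard library's `errors.Join` in place of headscale's `multierr` helper; all names here are illustrative:

```go
package main

import (
	"errors"
	"fmt"
)

var errNodeNotFound = errors.New("node not found")

type response struct {
	Hostname string
	Peers    []string
}

// builder collects errors from each With* step and reports them all at Build.
type builder struct {
	resp response
	errs []error
}

func (b *builder) WithSelf(hostname string) *builder {
	if hostname == "" {
		b.errs = append(b.errs, errNodeNotFound)
		return b
	}

	b.resp.Hostname = hostname

	return b
}

func (b *builder) WithPeers(peers ...string) *builder {
	b.resp.Peers = append(b.resp.Peers, peers...)

	return b
}

func (b *builder) Build() (*response, error) {
	if len(b.errs) > 0 {
		return nil, errors.Join(b.errs...)
	}

	return &b.resp, nil
}

func main() {
	_, err := (&builder{}).WithSelf("").WithPeers("peer1").Build()
	fmt.Println(err) // reports the accumulated "node not found"
}
```

Deferring the error check to `Build` keeps the chained `With*` calls fluent while still reporting every failure, which is the same trade-off the MapResponseBuilder makes.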
@@ -339,8 +339,8 @@ func TestMapResponseBuilder_MultipleErrors(t *testing.T) {
|
||||
|
||||
// Build should return a multierr
|
||||
data, err := result.Build()
|
||||
assert.Nil(t, data)
|
||||
assert.Error(t, err)
|
||||
require.Nil(t, data)
|
||||
require.Error(t, err)
|
||||
|
||||
// The error should contain information about multiple errors
|
||||
assert.Contains(t, err.Error(), "multiple errors")
|
||||
|
||||
@@ -24,7 +24,6 @@ import (
|
||||
|
||||
const (
|
||||
nextDNSDoHPrefix = "https://dns.nextdns.io"
|
||||
mapperIDLength = 8
|
||||
debugMapResponsePerm = 0o755
|
||||
)
|
||||
|
||||
@@ -50,6 +49,7 @@ type mapper struct {
|
||||
created time.Time
|
||||
}
|
||||
|
||||
//nolint:unused
|
||||
type patch struct {
|
||||
timestamp time.Time
|
||||
change *tailcfg.PeerChange
|
||||
@@ -60,7 +60,6 @@ func newMapper(
|
||||
state *state.State,
|
||||
) *mapper {
|
||||
// uid, _ := util.GenerateRandomStringDNSSafe(mapperIDLength)
|
||||
|
||||
return &mapper{
|
||||
state: state,
|
||||
cfg: cfg,
|
||||
@@ -76,23 +75,26 @@ func generateUserProfiles(
|
||||
) []tailcfg.UserProfile {
|
||||
userMap := make(map[uint]*types.UserView)
|
||||
ids := make([]uint, 0, len(userMap))
|
||||
|
||||
user := node.Owner()
|
||||
if !user.Valid() {
|
||||
log.Error().
|
||||
Uint64("node.id", node.ID().Uint64()).
|
||||
Str("node.name", node.Hostname()).
|
||||
EmbedObject(node).
|
||||
Msg("node has no valid owner, skipping user profile generation")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
userID := user.Model().ID
|
||||
userMap[userID] = &user
|
||||
ids = append(ids, userID)
|
||||
|
||||
for _, peer := range peers.All() {
|
||||
peerUser := peer.Owner()
|
||||
if !peerUser.Valid() {
|
||||
continue
|
||||
}
|
||||
|
||||
peerUserID := peerUser.Model().ID
|
||||
userMap[peerUserID] = &peerUser
|
||||
ids = append(ids, peerUserID)
|
||||
@@ -100,7 +102,9 @@ func generateUserProfiles(
|
||||
|
||||
slices.Sort(ids)
|
||||
ids = slices.Compact(ids)
|
||||
|
||||
var profiles []tailcfg.UserProfile
|
||||
|
||||
for _, id := range ids {
|
||||
if userMap[id] != nil {
|
||||
profiles = append(profiles, userMap[id].TailscaleUserProfile())
|
||||
@@ -150,6 +154,8 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, node types.NodeView) {
|
||||
}
|
||||
|
||||
// fullMapResponse returns a MapResponse for the given node.
|
||||
//
|
||||
//nolint:unused
|
||||
func (m *mapper) fullMapResponse(
|
||||
nodeID types.NodeID,
|
||||
capVer tailcfg.CapabilityVersion,
|
||||
@@ -317,6 +323,7 @@ func writeDebugMapResponse(
|
||||
|
||||
perms := fs.FileMode(debugMapResponsePerm)
|
||||
mPath := path.Join(debugDumpMapResponsePath, fmt.Sprintf("%d", nodeID))
|
||||
|
||||
err = os.MkdirAll(mPath, perms)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
@@ -329,7 +336,8 @@ func writeDebugMapResponse(
|
||||
fmt.Sprintf("%s-%s.json", now, t),
|
||||
)
|
||||
|
||||
log.Trace().Msgf("Writing MapResponse to %s", mapResponsePath)
|
||||
log.Trace().Msgf("writing MapResponse to %s", mapResponsePath)
|
||||
|
||||
err = os.WriteFile(mapResponsePath, body, perms)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
@@ -338,7 +346,7 @@ func writeDebugMapResponse(
|
||||
|
||||
func (m *mapper) debugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, error) {
|
||||
if debugDumpMapResponsePath == "" {
|
||||
return nil, nil
|
||||
return nil, nil //nolint:nilnil // intentional: no data when debug path not set
|
||||
}
|
||||
|
||||
return ReadMapResponsesFromDirectory(debugDumpMapResponsePath)
|
||||
@@ -351,6 +359,7 @@ func ReadMapResponsesFromDirectory(dir string) (map[types.NodeID][]tailcfg.MapRe
|
||||
}
|
||||
|
||||
result := make(map[types.NodeID][]tailcfg.MapResponse)
|
||||
|
||||
for _, node := range nodes {
|
||||
if !node.IsDir() {
|
||||
continue
|
||||
@@ -358,7 +367,7 @@ func ReadMapResponsesFromDirectory(dir string) (map[types.NodeID][]tailcfg.MapRe
|
||||
|
||||
nodeIDu, err := strconv.ParseUint(node.Name(), 10, 64)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msgf("Parsing node ID from dir %s", node.Name())
|
||||
log.Error().Err(err).Msgf("parsing node ID from dir %s", node.Name())
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -366,7 +375,7 @@ func ReadMapResponsesFromDirectory(dir string) (map[types.NodeID][]tailcfg.MapRe
|
||||
|
||||
files, err := os.ReadDir(path.Join(dir, node.Name()))
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msgf("Reading dir %s", node.Name())
|
||||
log.Error().Err(err).Msgf("reading dir %s", node.Name())
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -381,14 +390,15 @@ func ReadMapResponsesFromDirectory(dir string) (map[types.NodeID][]tailcfg.MapRe
|
||||
|
||||
body, err := os.ReadFile(path.Join(dir, node.Name(), file.Name()))
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msgf("Reading file %s", file.Name())
|
||||
log.Error().Err(err).Msgf("reading file %s", file.Name())
|
||||
continue
|
||||
}
|
||||
|
||||
var resp tailcfg.MapResponse
|
||||
|
||||
err = json.Unmarshal(body, &resp)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msgf("Unmarshalling file %s", file.Name())
|
||||
log.Error().Err(err).Msgf("unmarshalling file %s", file.Name())
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
@@ -3,18 +3,13 @@ package mapper
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
"github.com/juanfont/headscale/hscontrol/policy"
|
||||
"github.com/juanfont/headscale/hscontrol/policy/matcher"
|
||||
"github.com/juanfont/headscale/hscontrol/routes"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/dnstype"
|
||||
"tailscale.com/types/ptr"
|
||||
)
|
||||
|
||||
var iap = func(ipStr string) *netip.Addr {
|
||||
@@ -51,7 +46,7 @@ func TestDNSConfigMapResponse(t *testing.T) {
|
||||
mach := func(hostname, username string, userid uint) *types.Node {
|
||||
return &types.Node{
|
||||
Hostname: hostname,
|
||||
UserID: ptr.To(userid),
|
||||
UserID: new(userid),
|
||||
User: &types.User{
|
||||
Name: username,
|
||||
},
|
||||
@@ -81,90 +76,3 @@ func TestDNSConfigMapResponse(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// mockState is a mock implementation that provides the required methods.
|
||||
type mockState struct {
|
||||
polMan policy.PolicyManager
|
||||
derpMap *tailcfg.DERPMap
|
||||
primary *routes.PrimaryRoutes
|
||||
nodes types.Nodes
|
||||
peers types.Nodes
|
||||
}
|
||||
|
||||
func (m *mockState) DERPMap() *tailcfg.DERPMap {
|
||||
return m.derpMap
|
||||
}
|
||||
|
||||
func (m *mockState) Filter() ([]tailcfg.FilterRule, []matcher.Match) {
|
||||
if m.polMan == nil {
|
||||
return tailcfg.FilterAllowAll, nil
|
||||
}
|
||||
return m.polMan.Filter()
|
||||
}
|
||||
|
||||
func (m *mockState) SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, error) {
|
||||
if m.polMan == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return m.polMan.SSHPolicy(node)
|
||||
}
|
||||
|
||||
func (m *mockState) NodeCanHaveTag(node types.NodeView, tag string) bool {
|
||||
if m.polMan == nil {
|
||||
return false
|
||||
}
|
||||
return m.polMan.NodeCanHaveTag(node, tag)
|
||||
}
|
||||
|
||||
func (m *mockState) GetNodePrimaryRoutes(nodeID types.NodeID) []netip.Prefix {
|
||||
if m.primary == nil {
|
||||
return nil
|
||||
}
|
||||
return m.primary.PrimaryRoutes(nodeID)
|
||||
}
|
||||
|
||||
func (m *mockState) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) {
|
||||
if len(peerIDs) > 0 {
|
||||
// Filter peers by the provided IDs
|
||||
var filtered types.Nodes
|
||||
for _, peer := range m.peers {
|
||||
if slices.Contains(peerIDs, peer.ID) {
|
||||
filtered = append(filtered, peer)
|
||||
}
|
||||
}
|
||||
|
||||
return filtered, nil
|
||||
}
|
||||
// Return all peers except the node itself
|
||||
var filtered types.Nodes
|
||||
for _, peer := range m.peers {
|
||||
if peer.ID != nodeID {
|
||||
filtered = append(filtered, peer)
|
||||
}
|
||||
}
|
||||
|
||||
return filtered, nil
|
||||
}
|
||||
|
||||
func (m *mockState) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) {
|
||||
if len(nodeIDs) > 0 {
|
||||
// Filter nodes by the provided IDs
|
||||
var filtered types.Nodes
|
||||
for _, node := range m.nodes {
|
||||
if slices.Contains(nodeIDs, node.ID) {
|
||||
filtered = append(filtered, node)
|
||||
}
|
||||
}
|
||||
|
||||
return filtered, nil
|
||||
}
|
||||
|
||||
return m.nodes, nil
|
||||
}
|
||||
|
||||
func Test_fullMapResponse(t *testing.T) {
|
||||
t.Skip("Test needs to be refactored for new state-based architecture")
|
||||
// TODO: Refactor this test to work with the new state-based mapper
|
||||
// The test architecture needs to be updated to work with the state interface
|
||||
// instead of the old direct dependency injection pattern
|
||||
}
|
||||
|
||||
@@ -13,12 +13,12 @@ import (
|
||||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/ptr"
|
||||
)
|
||||
|
||||
func TestTailNode(t *testing.T) {
|
||||
mustNK := func(str string) key.NodePublic {
|
||||
var k key.NodePublic
|
||||
|
||||
_ = k.UnmarshalText([]byte(str))
|
||||
|
||||
return k
|
||||
@@ -26,6 +26,7 @@ func TestTailNode(t *testing.T) {
|
||||
|
||||
mustDK := func(str string) key.DiscoPublic {
|
||||
var k key.DiscoPublic
|
||||
|
||||
_ = k.UnmarshalText([]byte(str))
|
||||
|
||||
return k
|
||||
@@ -33,6 +34,7 @@ func TestTailNode(t *testing.T) {
|
||||
|
||||
mustMK := func(str string) key.MachinePublic {
|
||||
var k key.MachinePublic
|
||||
|
||||
_ = k.UnmarshalText([]byte(str))
|
||||
|
||||
return k
|
||||
@@ -95,7 +97,7 @@ func TestTailNode(t *testing.T) {
|
||||
IPv4: iap("100.64.0.1"),
|
||||
Hostname: "mini",
|
||||
GivenName: "mini",
|
||||
UserID: ptr.To(uint(0)),
|
||||
UserID: new(uint(0)),
|
||||
User: &types.User{
|
||||
Name: "mini",
|
||||
},
|
||||
@@ -137,8 +139,8 @@ func TestTailNode(t *testing.T) {
|
||||
Addresses: []netip.Prefix{netip.MustParsePrefix("100.64.0.1/32")},
|
||||
AllowedIPs: []netip.Prefix{
|
||||
tsaddr.AllIPv4(),
|
||||
netip.MustParsePrefix("192.168.0.0/24"),
|
||||
netip.MustParsePrefix("100.64.0.1/32"),
|
||||
netip.MustParsePrefix("192.168.0.0/24"),
|
||||
tsaddr.AllIPv6(),
|
||||
},
|
||||
PrimaryRoutes: []netip.Prefix{
|
||||
@@ -255,7 +257,7 @@ func TestNodeExpiry(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "localtime",
|
||||
exp: tp(time.Time{}.Local()),
|
||||
exp: tp(time.Time{}.Local()), //nolint:gosmopolitan
|
||||
wantTimeZero: true,
|
||||
},
|
||||
}
|
||||
@@ -284,7 +286,9 @@ func TestNodeExpiry(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("nodeExpiry() error = %v", err)
|
||||
}
|
||||
|
||||
var deseri tailcfg.Node
|
||||
|
||||
err = json.Unmarshal(seri, &deseri)
|
||||
if err != nil {
|
||||
t.Fatalf("nodeExpiry() error = %v", err)
|
||||
|
||||
@@ -71,6 +71,7 @@ func prometheusMiddleware(next http.Handler) http.Handler {
|
||||
rw := &respWriterProm{ResponseWriter: w}
|
||||
|
||||
timer := prometheus.NewTimer(httpDuration.WithLabelValues(path))
|
||||
|
||||
next.ServeHTTP(rw, r)
|
||||
timer.ObserveDuration()
|
||||
httpCounter.WithLabelValues(strconv.Itoa(rw.status), r.Method, path).Inc()
|
||||
@@ -79,6 +80,7 @@ func prometheusMiddleware(next http.Handler) http.Handler {
|
||||
|
||||
type respWriterProm struct {
|
||||
http.ResponseWriter
|
||||
|
||||
status int
|
||||
written int64
|
||||
wroteHeader bool
|
||||
@@ -94,6 +96,7 @@ func (r *respWriterProm) Write(b []byte) (int, error) {
|
||||
if !r.wroteHeader {
|
||||
r.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
n, err := r.ResponseWriter.Write(b)
|
||||
r.written += int64(n)
|
||||
|
||||
|
||||
@@ -19,6 +19,9 @@ import (
"tailscale.com/types/key"
)

// ErrUnsupportedClientVersion is returned when a client connects with an unsupported protocol version.
var ErrUnsupportedClientVersion = errors.New("unsupported client version")

const (
// ts2021UpgradePath is the path that the server listens on for the WebSockets upgrade.
ts2021UpgradePath = "/ts2021"

@@ -51,7 +54,7 @@ func (h *Headscale) NoiseUpgradeHandler(
writer http.ResponseWriter,
req *http.Request,
) {
log.Trace().Caller().Msgf("Noise upgrade handler for client %s", req.RemoteAddr)
log.Trace().Caller().Msgf("noise upgrade handler for client %s", req.RemoteAddr)

upgrade := req.Header.Get("Upgrade")
if upgrade == "" {

@@ -60,7 +63,7 @@ func (h *Headscale) NoiseUpgradeHandler(
// be passed to Headscale. Let's give them a hint.
log.Warn().
Caller().
Msg("No Upgrade header in TS2021 request. If headscale is behind a reverse proxy, make sure it is configured to pass WebSockets through.")
Msg("no upgrade header in TS2021 request. If headscale is behind a reverse proxy, make sure it is configured to pass WebSockets through.")
http.Error(writer, "Internal error", http.StatusInternalServerError)

return

@@ -79,7 +82,7 @@ func (h *Headscale) NoiseUpgradeHandler(
noiseServer.earlyNoise,
)
if err != nil {
httpError(writer, fmt.Errorf("noise upgrade failed: %w", err))
httpError(writer, fmt.Errorf("upgrading noise connection: %w", err))
return
}

@@ -117,7 +120,7 @@ func (h *Headscale) NoiseUpgradeHandler(
}

func unsupportedClientError(version tailcfg.CapabilityVersion) error {
return fmt.Errorf("unsupported client version: %s (%d)", capver.TailscaleVersion(version), version)
return fmt.Errorf("%w: %s (%d)", ErrUnsupportedClientVersion, capver.TailscaleVersion(version), version)
}
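The unsupportedClientError change above switches from a one-off formatted string to wrapping the exported ErrUnsupportedClientVersion sentinel with %w, which lets callers match it with errors.Is. A small sketch of that wrapping pattern, with a made-up capability version for the demo:

```go
package main

import (
	"errors"
	"fmt"
)

var ErrUnsupportedClientVersion = errors.New("unsupported client version")

// unsupportedClientError wraps the sentinel so errors.Is still matches
// after the version details are attached.
func unsupportedClientError(version int) error {
	return fmt.Errorf("%w: capability version %d", ErrUnsupportedClientVersion, version)
}

func main() {
	err := unsupportedClientError(42)
	fmt.Println(errors.Is(err, ErrUnsupportedClientVersion)) // true
	fmt.Println(err)
}
```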
func (ns *noiseServer) earlyNoise(protocolVersion int, writer io.Writer) error {

@@ -137,17 +140,20 @@ func (ns *noiseServer) earlyNoise(protocolVersion int, writer io.Writer) error {
// an HTTP/2 settings frame, which isn't of type 'T')
var notH2Frame [5]byte
copy(notH2Frame[:], earlyPayloadMagic)

var lenBuf [4]byte
binary.BigEndian.PutUint32(lenBuf[:], uint32(len(earlyJSON)))
binary.BigEndian.PutUint32(lenBuf[:], uint32(len(earlyJSON))) //nolint:gosec // JSON length is bounded
// These writes are all buffered by caller, so fine to do them
// separately:
if _, err := writer.Write(notH2Frame[:]); err != nil {
if _, err := writer.Write(notH2Frame[:]); err != nil { //nolint:noinlineerr
return err
}
if _, err := writer.Write(lenBuf[:]); err != nil {

if _, err := writer.Write(lenBuf[:]); err != nil { //nolint:noinlineerr
return err
}
if _, err := writer.Write(earlyJSON); err != nil {

if _, err := writer.Write(earlyJSON); err != nil { //nolint:noinlineerr
return err
}
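The earlyNoise hunk writes three pieces in order: a 5-byte magic header, a big-endian uint32 length, and the JSON payload itself. A self-contained sketch of that framing, assuming a placeholder magic value rather than the real earlyPayloadMagic:

```go
package main

import (
	"bytes"
	"encoding/binary"
	"encoding/json"
	"fmt"
	"io"
)

// writeEarlyFrame emits magic || big-endian length || JSON payload,
// the same shape as the early-noise writes above.
func writeEarlyFrame(w io.Writer, magic [5]byte, payload any) error {
	earlyJSON, err := json.Marshal(payload)
	if err != nil {
		return err
	}

	var lenBuf [4]byte
	binary.BigEndian.PutUint32(lenBuf[:], uint32(len(earlyJSON)))

	if _, err := w.Write(magic[:]); err != nil {
		return err
	}
	if _, err := w.Write(lenBuf[:]); err != nil {
		return err
	}
	_, err = w.Write(earlyJSON)

	return err
}

func main() {
	var buf bytes.Buffer
	magic := [5]byte{'\xff', 'E', 'A', 'R', 'L'} // placeholder, not the real magic
	_ = writeEarlyFrame(&buf, magic, map[string]int{"version": 1})
	fmt.Printf("% x\n", buf.Bytes()[:9]) // the magic followed by the length prefix
}
```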
@@ -199,7 +205,7 @@ func (ns *noiseServer) NoisePollNetMapHandler(
|
||||
body, _ := io.ReadAll(req.Body)
|
||||
|
||||
var mapRequest tailcfg.MapRequest
|
||||
if err := json.Unmarshal(body, &mapRequest); err != nil {
|
||||
if err := json.Unmarshal(body, &mapRequest); err != nil { //nolint:noinlineerr
|
||||
httpError(writer, err)
|
||||
return
|
||||
}
|
||||
@@ -218,7 +224,8 @@ func (ns *noiseServer) NoisePollNetMapHandler(
|
||||
ns.nodeKey = nv.NodeKey()
|
||||
|
||||
sess := ns.headscale.newMapSession(req.Context(), mapRequest, writer, nv.AsStruct())
|
||||
sess.tracef("a node sending a MapRequest with Noise protocol")
|
||||
sess.log.Trace().Caller().Msg("a node sending a MapRequest with Noise protocol")
|
||||
|
||||
if !sess.isStreaming() {
|
||||
sess.serve()
|
||||
} else {
|
||||
@@ -241,14 +248,16 @@ func (ns *noiseServer) NoiseRegistrationHandler(
|
||||
return
|
||||
}
|
||||
|
||||
registerRequest, registerResponse := func() (*tailcfg.RegisterRequest, *tailcfg.RegisterResponse) {
|
||||
registerRequest, registerResponse := func() (*tailcfg.RegisterRequest, *tailcfg.RegisterResponse) { //nolint:contextcheck
|
||||
var resp *tailcfg.RegisterResponse
|
||||
|
||||
body, err := io.ReadAll(req.Body)
|
||||
if err != nil {
|
||||
return &tailcfg.RegisterRequest{}, regErr(err)
|
||||
}
|
||||
|
||||
var regReq tailcfg.RegisterRequest
|
||||
if err := json.Unmarshal(body, ®Req); err != nil {
|
||||
if err := json.Unmarshal(body, ®Req); err != nil { //nolint:noinlineerr
|
||||
return ®Req, regErr(err)
|
||||
}
|
||||
|
||||
@@ -256,11 +265,11 @@ func (ns *noiseServer) NoiseRegistrationHandler(
|
||||
|
||||
resp, err = ns.headscale.handleRegister(req.Context(), regReq, ns.conn.Peer())
|
||||
if err != nil {
|
||||
var httpErr HTTPError
|
||||
if errors.As(err, &httpErr) {
|
||||
if httpErr, ok := errors.AsType[HTTPError](err); ok {
|
||||
resp = &tailcfg.RegisterResponse{
|
||||
Error: httpErr.Msg,
|
||||
}
|
||||
|
||||
return ®Req, resp
|
||||
}
|
||||
|
||||
@@ -278,8 +287,9 @@ func (ns *noiseServer) NoiseRegistrationHandler(
|
||||
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusOK)
|
||||
|
||||
if err := json.NewEncoder(writer).Encode(registerResponse); err != nil {
|
||||
log.Error().Caller().Err(err).Msg("NoiseRegistrationHandler: failed to encode RegisterResponse")
|
||||
err := json.NewEncoder(writer).Encode(registerResponse)
|
||||
if err != nil {
|
||||
log.Error().Caller().Err(err).Msg("noise registration handler: failed to encode RegisterResponse")
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -32,8 +32,8 @@ const (
|
||||
|
||||
var (
|
||||
errEmptyOIDCCallbackParams = errors.New("empty OIDC callback params")
|
||||
errNoOIDCIDToken = errors.New("could not extract ID Token for OIDC callback")
|
||||
errNoOIDCRegistrationInfo = errors.New("could not get registration info from cache")
|
||||
errNoOIDCIDToken = errors.New("extracting ID token")
|
||||
errNoOIDCRegistrationInfo = errors.New("registration info not in cache")
|
||||
errOIDCAllowedDomains = errors.New(
|
||||
"authenticated principal does not match any allowed domain",
|
||||
)
|
||||
@@ -68,7 +68,7 @@ func NewAuthProviderOIDC(
|
||||
) (*AuthProviderOIDC, error) {
|
||||
var err error
|
||||
// grab oidc config if it hasn't been already
|
||||
oidcProvider, err := oidc.NewProvider(context.Background(), cfg.Issuer)
|
||||
oidcProvider, err := oidc.NewProvider(context.Background(), cfg.Issuer) //nolint:contextcheck
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating OIDC provider from issuer config: %w", err)
|
||||
}
|
||||
@@ -163,13 +163,14 @@ func (a *AuthProviderOIDC) RegisterHandler(
|
||||
for k, v := range a.cfg.ExtraParams {
|
||||
extras = append(extras, oauth2.SetAuthURLParam(k, v))
|
||||
}
|
||||
|
||||
extras = append(extras, oidc.Nonce(nonce))
|
||||
|
||||
// Cache the registration info
|
||||
a.registrationCache.Set(state, registrationInfo)
|
||||
|
||||
authURL := a.oauth2Config.AuthCodeURL(state, extras...)
|
||||
log.Debug().Caller().Msgf("Redirecting to %s for authentication", authURL)
|
||||
log.Debug().Caller().Msgf("redirecting to %s for authentication", authURL)
|
||||
|
||||
http.Redirect(writer, req, authURL, http.StatusFound)
|
||||
}
|
||||
@@ -190,6 +191,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
||||
}
|
||||
|
||||
stateCookieName := getCookieName("state", state)
|
||||
|
||||
cookieState, err := req.Cookie(stateCookieName)
|
||||
if err != nil {
|
||||
httpError(writer, NewHTTPError(http.StatusBadRequest, "state not found", err))
|
||||
@@ -212,17 +214,20 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
||||
httpError(writer, err)
|
||||
return
|
||||
}
|
||||
|
||||
if idToken.Nonce == "" {
|
||||
httpError(writer, NewHTTPError(http.StatusBadRequest, "nonce not found in IDToken", err))
|
||||
return
|
||||
}
|
||||
|
||||
nonceCookieName := getCookieName("nonce", idToken.Nonce)
|
||||
|
||||
nonce, err := req.Cookie(nonceCookieName)
|
||||
if err != nil {
|
||||
httpError(writer, NewHTTPError(http.StatusBadRequest, "nonce not found", err))
|
||||
return
|
||||
}
|
||||
|
||||
if idToken.Nonce != nonce.Value {
|
||||
httpError(writer, NewHTTPError(http.StatusForbidden, "nonce did not match", nil))
|
||||
return
|
||||
@@ -231,7 +236,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
||||
nodeExpiry := a.determineNodeExpiry(idToken.Expiry)
|
||||
|
||||
var claims types.OIDCClaims
|
||||
if err := idToken.Claims(&claims); err != nil {
|
||||
if err := idToken.Claims(&claims); err != nil { //nolint:noinlineerr
|
||||
httpError(writer, fmt.Errorf("decoding ID token claims: %w", err))
|
||||
return
|
||||
}
|
||||
@@ -239,6 +244,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
||||
// Fetch user information (email, groups, name, etc) from the userinfo endpoint
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#UserInfo
|
||||
var userinfo *oidc.UserInfo
|
||||
|
||||
userinfo, err = a.oidcProvider.UserInfo(req.Context(), oauth2.StaticTokenSource(oauth2Token))
|
||||
if err != nil {
|
||||
util.LogErr(err, "could not get userinfo; only using claims from id token")
|
||||
@@ -255,6 +261,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
||||
claims.EmailVerified = cmp.Or(userinfo2.EmailVerified, claims.EmailVerified)
|
||||
claims.Username = cmp.Or(userinfo2.PreferredUsername, claims.Username)
|
||||
claims.Name = cmp.Or(userinfo2.Name, claims.Name)
|
||||
|
||||
claims.ProfilePictureURL = cmp.Or(userinfo2.Picture, claims.ProfilePictureURL)
|
||||
if userinfo2.Groups != nil {
|
||||
claims.Groups = userinfo2.Groups
|
||||
@@ -279,6 +286,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
||||
Msgf("could not create or update user")
|
||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusInternalServerError)
|
||||
|
||||
_, werr := writer.Write([]byte("Could not create or update user"))
|
||||
if werr != nil {
|
||||
log.Error().
|
||||
@@ -299,6 +307,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
||||
// Register the node if it does not exist.
|
||||
if registrationId != nil {
|
||||
verb := "Reauthenticated"
|
||||
|
||||
newNode, err := a.handleRegistration(user, *registrationId, nodeExpiry)
|
||||
if err != nil {
|
||||
if errors.Is(err, db.ErrNodeNotFoundRegistrationCache) {
|
||||
@@ -307,7 +316,9 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
httpError(writer, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -316,15 +327,12 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
||||
}
|
||||
|
||||
// TODO(kradalby): replace with go-elem
|
||||
content, err := renderOIDCCallbackTemplate(user, verb)
|
||||
if err != nil {
|
||||
httpError(writer, err)
|
||||
return
|
||||
}
|
||||
content := renderOIDCCallbackTemplate(user, verb)
|
||||
|
||||
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusOK)
|
||||
if _, err := writer.Write(content.Bytes()); err != nil {
|
||||
|
||||
if _, err := writer.Write(content.Bytes()); err != nil { //nolint:noinlineerr
|
||||
util.LogErr(err, "Failed to write HTTP response")
|
||||
}
|
||||
|
||||
@@ -370,6 +378,7 @@ func (a *AuthProviderOIDC) getOauth2Token(
|
||||
if !ok {
|
||||
return nil, NewHTTPError(http.StatusNotFound, "registration not found", errNoOIDCRegistrationInfo)
|
||||
}
|
||||
|
||||
if regInfo.Verifier != nil {
|
||||
exchangeOpts = []oauth2.AuthCodeOption{oauth2.VerifierOption(*regInfo.Verifier)}
|
||||
}
|
||||
@@ -377,7 +386,7 @@ func (a *AuthProviderOIDC) getOauth2Token(
|
||||
|
||||
oauth2Token, err := a.oauth2Config.Exchange(ctx, code, exchangeOpts...)
|
||||
if err != nil {
|
||||
return nil, NewHTTPError(http.StatusForbidden, "invalid code", fmt.Errorf("could not exchange code for token: %w", err))
|
||||
return nil, NewHTTPError(http.StatusForbidden, "invalid code", fmt.Errorf("exchanging code for token: %w", err))
|
||||
}
|
||||
|
||||
return oauth2Token, err
|
||||
@@ -394,9 +403,10 @@ func (a *AuthProviderOIDC) extractIDToken(
|
||||
}
|
||||
|
||||
verifier := a.oidcProvider.Verifier(&oidc.Config{ClientID: a.cfg.ClientID})
|
||||
|
||||
idToken, err := verifier.Verify(ctx, rawIDToken)
|
||||
if err != nil {
|
||||
return nil, NewHTTPError(http.StatusForbidden, "failed to verify id_token", fmt.Errorf("failed to verify ID token: %w", err))
|
||||
return nil, NewHTTPError(http.StatusForbidden, "failed to verify id_token", fmt.Errorf("verifying ID token: %w", err))
|
||||
}
|
||||
|
||||
return idToken, nil
|
||||
@@ -516,6 +526,7 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
|
||||
newUser bool
|
||||
c change.Change
|
||||
)
|
||||
|
||||
user, err = a.h.state.GetUserByOIDCIdentifier(claims.Identifier())
|
||||
if err != nil && !errors.Is(err, db.ErrUserNotFound) {
|
||||
return nil, change.Change{}, fmt.Errorf("creating or updating user: %w", err)
|
||||
@@ -561,7 +572,7 @@ func (a *AuthProviderOIDC) handleRegistration(
|
||||
util.RegisterMethodOIDC,
|
||||
)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("could not register node: %w", err)
|
||||
return false, fmt.Errorf("registering node: %w", err)
|
||||
}
|
||||
|
||||
// This is a bit of a back and forth, but we have a bit of a chicken and egg
|
||||
@@ -589,9 +600,9 @@ func (a *AuthProviderOIDC) handleRegistration(
|
||||
func renderOIDCCallbackTemplate(
|
||||
user *types.User,
|
||||
verb string,
|
||||
) (*bytes.Buffer, error) {
|
||||
) *bytes.Buffer {
|
||||
html := templates.OIDCCallback(user.Display(), verb).Render()
|
||||
return bytes.NewBufferString(html), nil
|
||||
return bytes.NewBufferString(html)
|
||||
}
|
||||
|
||||
// getCookieName generates a unique cookie name based on a cookie value.
|
||||
|
||||
@@ -19,7 +19,7 @@ func (h *Headscale) WindowsConfigMessage(
|
||||
) {
|
||||
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusOK)
|
||||
writer.Write([]byte(templates.Windows(h.cfg.ServerURL).Render()))
|
||||
_, _ = writer.Write([]byte(templates.Windows(h.cfg.ServerURL).Render()))
|
||||
}
|
||||
|
||||
// AppleConfigMessage shows a simple message in the browser to point the user to the iOS/MacOS profile and instructions for how to install it.
|
||||
@@ -29,7 +29,7 @@ func (h *Headscale) AppleConfigMessage(
|
||||
) {
|
||||
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusOK)
|
||||
writer.Write([]byte(templates.Apple(h.cfg.ServerURL).Render()))
|
||||
_, _ = writer.Write([]byte(templates.Apple(h.cfg.ServerURL).Render()))
|
||||
}
|
||||
|
||||
func (h *Headscale) ApplePlatformConfig(
|
||||
@@ -37,6 +37,7 @@ func (h *Headscale) ApplePlatformConfig(
|
||||
req *http.Request,
|
||||
) {
|
||||
vars := mux.Vars(req)
|
||||
|
||||
platform, ok := vars["platform"]
|
||||
if !ok {
|
||||
httpError(writer, NewHTTPError(http.StatusBadRequest, "no platform specified", nil))
|
||||
@@ -64,17 +65,20 @@ func (h *Headscale) ApplePlatformConfig(
|
||||
|
||||
switch platform {
|
||||
case "macos-standalone":
|
||||
if err := macosStandaloneTemplate.Execute(&payload, platformConfig); err != nil {
|
||||
err := macosStandaloneTemplate.Execute(&payload, platformConfig)
|
||||
if err != nil {
|
||||
httpError(writer, err)
|
||||
return
|
||||
}
|
||||
case "macos-app-store":
|
||||
if err := macosAppStoreTemplate.Execute(&payload, platformConfig); err != nil {
|
||||
err := macosAppStoreTemplate.Execute(&payload, platformConfig)
|
||||
if err != nil {
|
||||
httpError(writer, err)
|
||||
return
|
||||
}
|
||||
case "ios":
|
||||
if err := iosTemplate.Execute(&payload, platformConfig); err != nil {
|
||||
err := iosTemplate.Execute(&payload, platformConfig)
|
||||
if err != nil {
|
||||
httpError(writer, err)
|
||||
return
|
||||
}
|
||||
@@ -90,7 +94,7 @@ func (h *Headscale) ApplePlatformConfig(
|
||||
}
|
||||
|
||||
var content bytes.Buffer
|
||||
if err := commonTemplate.Execute(&content, config); err != nil {
|
||||
if err := commonTemplate.Execute(&content, config); err != nil { //nolint:noinlineerr
|
||||
httpError(writer, err)
|
||||
return
|
||||
}
|
||||
@@ -98,7 +102,7 @@ func (h *Headscale) ApplePlatformConfig(
|
||||
writer.Header().
|
||||
Set("Content-Type", "application/x-apple-aspen-config; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusOK)
|
||||
writer.Write(content.Bytes())
|
||||
_, _ = writer.Write(content.Bytes())
|
||||
}
|
||||
|
||||
type AppleMobileConfig struct {
|
||||
|
||||
@@ -16,15 +16,18 @@ type Match struct {
dests *netipx.IPSet
}

func (m Match) DebugString() string {
func (m *Match) DebugString() string {
var sb strings.Builder

sb.WriteString("Match:\n")
sb.WriteString(" Sources:\n")

for _, prefix := range m.srcs.Prefixes() {
sb.WriteString(" " + prefix.String() + "\n")
}

sb.WriteString(" Destinations:\n")

for _, prefix := range m.dests.Prefixes() {
sb.WriteString(" " + prefix.String() + "\n")
}

@@ -42,7 +45,7 @@ func MatchesFromFilterRules(rules []tailcfg.FilterRule) []Match {
}

func MatchFromFilterRule(rule tailcfg.FilterRule) Match {
dests := []string{}
dests := make([]string, 0, len(rule.DstPorts))
for _, dest := range rule.DstPorts {
dests = append(dests, dest.IP)
}

@@ -93,11 +96,24 @@ func (m *Match) DestsOverlapsPrefixes(prefixes ...netip.Prefix) bool {
return slices.ContainsFunc(prefixes, m.dests.OverlapsPrefix)
}

// DestsIsTheInternet reports if the destination is equal to "the internet"
// DestsIsTheInternet reports if the destination contains "the internet"
// which is a IPSet that represents "autogroup:internet" and is special
// cased for exit nodes.
func (m Match) DestsIsTheInternet() bool {
return m.dests.Equal(util.TheInternet()) ||
m.dests.ContainsPrefix(tsaddr.AllIPv4()) ||
m.dests.ContainsPrefix(tsaddr.AllIPv6())
// This checks if dests is a superset of TheInternet(), which handles
// merged filter rules where TheInternet is combined with other destinations.
func (m *Match) DestsIsTheInternet() bool {
if m.dests.ContainsPrefix(tsaddr.AllIPv4()) ||
m.dests.ContainsPrefix(tsaddr.AllIPv6()) {
return true
}

// Check if dests contains all prefixes of TheInternet (superset check)
theInternet := util.TheInternet()
for _, prefix := range theInternet.Prefixes() {
if !m.dests.ContainsPrefix(prefix) {
return false
}
}

return true
}
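The rewritten DestsIsTheInternet no longer requires the destination set to equal TheInternet exactly; a merged rule whose destinations form a superset of it also counts. A small sketch of that superset check with go4.org/netipx, using stand-in prefixes rather than the real autogroup:internet set:

```go
package main

import (
	"fmt"
	"net/netip"

	"go4.org/netipx"
)

func mustIPSet(prefixes ...string) *netipx.IPSet {
	var b netipx.IPSetBuilder
	for _, p := range prefixes {
		b.AddPrefix(netip.MustParsePrefix(p))
	}
	set, err := b.IPSet()
	if err != nil {
		panic(err)
	}
	return set
}

// containsAll reports whether super contains every prefix of sub,
// the same superset test the new DestsIsTheInternet performs.
func containsAll(super, sub *netipx.IPSet) bool {
	for _, prefix := range sub.Prefixes() {
		if !super.ContainsPrefix(prefix) {
			return false
		}
	}
	return true
}

func main() {
	internet := mustIPSet("8.0.0.0/7", "11.0.0.0/8") // stand-in prefixes only
	merged := mustIPSet("8.0.0.0/7", "11.0.0.0/8", "100.64.0.1/32")
	fmt.Println(containsAll(merged, internet)) // true: merged is a superset
}
```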
@@ -19,18 +19,18 @@ type PolicyManager interface {
|
||||
MatchersForNode(node types.NodeView) ([]matcher.Match, error)
|
||||
// BuildPeerMap constructs peer relationship maps for the given nodes
|
||||
BuildPeerMap(nodes views.Slice[types.NodeView]) map[types.NodeID][]types.NodeView
|
||||
SSHPolicy(types.NodeView) (*tailcfg.SSHPolicy, error)
|
||||
SetPolicy([]byte) (bool, error)
|
||||
SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, error)
|
||||
SetPolicy(pol []byte) (bool, error)
|
||||
SetUsers(users []types.User) (bool, error)
|
||||
SetNodes(nodes views.Slice[types.NodeView]) (bool, error)
|
||||
// NodeCanHaveTag reports whether the given node can have the given tag.
|
||||
NodeCanHaveTag(types.NodeView, string) bool
|
||||
NodeCanHaveTag(node types.NodeView, tag string) bool
|
||||
|
||||
// TagExists reports whether the given tag is defined in the policy.
|
||||
TagExists(tag string) bool
|
||||
|
||||
// NodeCanApproveRoute reports whether the given node can approve the given route.
|
||||
NodeCanApproveRoute(types.NodeView, netip.Prefix) bool
|
||||
NodeCanApproveRoute(node types.NodeView, route netip.Prefix) bool
|
||||
|
||||
Version() int
|
||||
DebugString() string
|
||||
@@ -38,8 +38,11 @@ type PolicyManager interface {
|
||||
|
||||
// NewPolicyManager returns a new policy manager.
|
||||
func NewPolicyManager(pol []byte, users []types.User, nodes views.Slice[types.NodeView]) (PolicyManager, error) {
|
||||
var polMan PolicyManager
|
||||
var err error
|
||||
var (
|
||||
polMan PolicyManager
|
||||
err error
|
||||
)
|
||||
|
||||
polMan, err = policyv2.NewPolicyManager(pol, users, nodes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -59,6 +62,7 @@ func PolicyManagersForTest(pol []byte, users []types.User, nodes views.Slice[typ
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
polMans = append(polMans, pm)
|
||||
}
|
||||
|
||||
@@ -66,7 +70,7 @@ func PolicyManagersForTest(pol []byte, users []types.User, nodes views.Slice[typ
|
||||
}
|
||||
|
||||
func PolicyManagerFuncsForTest(pol []byte) []func([]types.User, views.Slice[types.NodeView]) (PolicyManager, error) {
|
||||
var polmanFuncs []func([]types.User, views.Slice[types.NodeView]) (PolicyManager, error)
|
||||
polmanFuncs := make([]func([]types.User, views.Slice[types.NodeView]) (PolicyManager, error), 0, 1)
|
||||
|
||||
polmanFuncs = append(polmanFuncs, func(u []types.User, n views.Slice[types.NodeView]) (PolicyManager, error) {
|
||||
return policyv2.NewPolicyManager(pol, u, n)
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
"github.com/samber/lo"
"tailscale.com/net/tsaddr"
"tailscale.com/types/views"
)

@@ -111,7 +110,7 @@ func ApproveRoutesWithPolicy(pm PolicyManager, nv types.NodeView, currentApprove
}

// Sort and deduplicate
tsaddr.SortPrefixes(newApproved)
slices.SortFunc(newApproved, netip.Prefix.Compare)
newApproved = slices.Compact(newApproved)
newApproved = lo.Filter(newApproved, func(route netip.Prefix, index int) bool {
return route.IsValid()

@@ -120,12 +119,13 @@ func ApproveRoutesWithPolicy(pm PolicyManager, nv types.NodeView, currentApprove
// Sort the current approved for comparison
sortedCurrent := make([]netip.Prefix, len(currentApproved))
copy(sortedCurrent, currentApproved)
tsaddr.SortPrefixes(sortedCurrent)
slices.SortFunc(sortedCurrent, netip.Prefix.Compare)

// Only update if the routes actually changed
if !slices.Equal(sortedCurrent, newApproved) {
// Log what changed
var added, kept []netip.Prefix

for _, route := range newApproved {
if !slices.Contains(sortedCurrent, route) {
added = append(added, route)

@@ -136,8 +136,7 @@ func ApproveRoutesWithPolicy(pm PolicyManager, nv types.NodeView, currentApprove

if len(added) > 0 {
log.Debug().
Uint64("node.id", nv.ID().Uint64()).
Str("node.name", nv.Hostname()).
EmbedObject(nv).
Strs("routes.added", util.PrefixesToString(added)).
Strs("routes.kept", util.PrefixesToString(kept)).
Int("routes.total", len(newApproved)).
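The hunks above swap tsaddr.SortPrefixes for the standard-library pair of slices.SortFunc with netip.Prefix.Compare, followed by slices.Compact for deduplication. A minimal sketch of that sort-and-dedupe step on its own:

```go
package main

import (
	"fmt"
	"net/netip"
	"slices"
)

func main() {
	routes := []netip.Prefix{
		netip.MustParsePrefix("192.168.0.0/24"),
		netip.MustParsePrefix("10.0.0.0/24"),
		netip.MustParsePrefix("10.0.0.0/24"), // duplicate
		netip.MustParsePrefix("172.16.0.0/16"),
	}

	// Order with the standard comparator, then drop adjacent duplicates.
	slices.SortFunc(routes, netip.Prefix.Compare)
	routes = slices.Compact(routes)

	fmt.Println(len(routes), routes) // 3 prefixes remain after dedupe
}
```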
@@ -3,16 +3,16 @@ package policy
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/ptr"
|
||||
"tailscale.com/types/views"
|
||||
)
|
||||
|
||||
@@ -32,10 +32,10 @@ func TestApproveRoutesWithPolicy_NeverRemovesApprovedRoutes(t *testing.T) {
|
||||
MachineKey: key.NewMachine().Public(),
|
||||
NodeKey: key.NewNode().Public(),
|
||||
Hostname: "test-node",
|
||||
UserID: ptr.To(user1.ID),
|
||||
User: ptr.To(user1),
|
||||
UserID: new(user1.ID),
|
||||
User: new(user1),
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")),
|
||||
IPv4: new(netip.MustParseAddr("100.64.0.1")),
|
||||
Tags: []string{"tag:test"},
|
||||
}
|
||||
|
||||
@@ -44,10 +44,10 @@ func TestApproveRoutesWithPolicy_NeverRemovesApprovedRoutes(t *testing.T) {
|
||||
MachineKey: key.NewMachine().Public(),
|
||||
NodeKey: key.NewNode().Public(),
|
||||
Hostname: "other-node",
|
||||
UserID: ptr.To(user2.ID),
|
||||
User: ptr.To(user2),
|
||||
UserID: new(user2.ID),
|
||||
User: new(user2),
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
IPv4: ptr.To(netip.MustParseAddr("100.64.0.2")),
|
||||
IPv4: new(netip.MustParseAddr("100.64.0.2")),
|
||||
}
|
||||
|
||||
// Create a policy that auto-approves specific routes
|
||||
@@ -76,7 +76,7 @@ func TestApproveRoutesWithPolicy_NeverRemovesApprovedRoutes(t *testing.T) {
|
||||
}`
|
||||
|
||||
pm, err := policyv2.NewPolicyManager([]byte(policyJSON), users, views.SliceOf([]types.NodeView{node1.View(), node2.View()}))
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -194,7 +194,7 @@ func TestApproveRoutesWithPolicy_NeverRemovesApprovedRoutes(t *testing.T) {
|
||||
assert.Equal(t, tt.wantChanged, gotChanged, "changed flag mismatch: %s", tt.description)
|
||||
|
||||
// Sort for comparison since ApproveRoutesWithPolicy sorts the results
|
||||
tsaddr.SortPrefixes(tt.wantApproved)
|
||||
slices.SortFunc(tt.wantApproved, netip.Prefix.Compare)
|
||||
assert.Equal(t, tt.wantApproved, gotApproved, "approved routes mismatch: %s", tt.description)
|
||||
|
||||
// Verify that all previously approved routes are still present
|
||||
@@ -304,20 +304,23 @@ func TestApproveRoutesWithPolicy_NilAndEmptyCases(t *testing.T) {
|
||||
MachineKey: key.NewMachine().Public(),
|
||||
NodeKey: key.NewNode().Public(),
|
||||
Hostname: "testnode",
|
||||
UserID: ptr.To(user.ID),
|
||||
User: ptr.To(user),
|
||||
UserID: new(user.ID),
|
||||
User: new(user),
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")),
|
||||
IPv4: new(netip.MustParseAddr("100.64.0.1")),
|
||||
ApprovedRoutes: tt.currentApproved,
|
||||
}
|
||||
nodes := types.Nodes{&node}
|
||||
|
||||
// Create policy manager or use nil if specified
|
||||
var pm PolicyManager
|
||||
var err error
|
||||
var (
|
||||
pm PolicyManager
|
||||
err error
|
||||
)
|
||||
|
||||
if tt.name != "nil_policy_manager" {
|
||||
pm, err = pmf(users, nodes.ViewSlice())
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
} else {
|
||||
pm = nil
|
||||
}
|
||||
@@ -330,7 +333,7 @@ func TestApproveRoutesWithPolicy_NilAndEmptyCases(t *testing.T) {
|
||||
if tt.wantApproved == nil {
|
||||
assert.Nil(t, gotApproved, "expected nil approved routes")
|
||||
} else {
|
||||
tsaddr.SortPrefixes(tt.wantApproved)
|
||||
slices.SortFunc(tt.wantApproved, netip.Prefix.Compare)
|
||||
assert.Equal(t, tt.wantApproved, gotApproved, "approved routes mismatch")
|
||||
}
|
||||
})
|
||||
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/ptr"
|
||||
)
|
||||
|
||||
func TestApproveRoutesWithPolicy_NeverRemovesRoutes(t *testing.T) {
|
||||
@@ -92,8 +91,8 @@ func TestApproveRoutesWithPolicy_NeverRemovesRoutes(t *testing.T) {
|
||||
announcedRoutes: []netip.Prefix{}, // No routes announced anymore
|
||||
nodeUser: "test",
|
||||
wantApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("172.16.0.0/16"),
|
||||
netip.MustParsePrefix("10.0.0.0/24"),
|
||||
netip.MustParsePrefix("172.16.0.0/16"),
|
||||
netip.MustParsePrefix("192.168.0.0/24"),
|
||||
},
|
||||
wantChanged: false,
|
||||
@@ -124,8 +123,8 @@ func TestApproveRoutesWithPolicy_NeverRemovesRoutes(t *testing.T) {
|
||||
nodeUser: "test",
|
||||
nodeTags: []string{"tag:approved"},
|
||||
wantApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("172.16.0.0/16"), // New tag-approved
|
||||
netip.MustParsePrefix("10.0.0.0/24"), // Previous approval preserved
|
||||
netip.MustParsePrefix("172.16.0.0/16"), // New tag-approved
|
||||
},
|
||||
wantChanged: true,
|
||||
},
|
||||
@@ -168,13 +167,13 @@ func TestApproveRoutesWithPolicy_NeverRemovesRoutes(t *testing.T) {
|
||||
MachineKey: key.NewMachine().Public(),
|
||||
NodeKey: key.NewNode().Public(),
|
||||
Hostname: tt.nodeHostname,
|
||||
UserID: ptr.To(user.ID),
|
||||
User: ptr.To(user),
|
||||
UserID: new(user.ID),
|
||||
User: new(user),
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
RoutableIPs: tt.announcedRoutes,
|
||||
},
|
||||
IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")),
|
||||
IPv4: new(netip.MustParseAddr("100.64.0.1")),
|
||||
ApprovedRoutes: tt.currentApproved,
|
||||
Tags: tt.nodeTags,
|
||||
}
|
||||
@@ -294,13 +293,13 @@ func TestApproveRoutesWithPolicy_EdgeCases(t *testing.T) {
|
||||
MachineKey: key.NewMachine().Public(),
|
||||
NodeKey: key.NewNode().Public(),
|
||||
Hostname: "testnode",
|
||||
UserID: ptr.To(user.ID),
|
||||
User: ptr.To(user),
|
||||
UserID: new(user.ID),
|
||||
User: new(user),
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
RoutableIPs: tt.announcedRoutes,
|
||||
},
|
||||
IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")),
|
||||
IPv4: new(netip.MustParseAddr("100.64.0.1")),
|
||||
ApprovedRoutes: tt.currentApproved,
|
||||
}
|
||||
nodes := types.Nodes{&node}
|
||||
@@ -343,13 +342,13 @@ func TestApproveRoutesWithPolicy_NilPolicyManagerCase(t *testing.T) {
|
||||
MachineKey: key.NewMachine().Public(),
|
||||
NodeKey: key.NewNode().Public(),
|
||||
Hostname: "testnode",
|
||||
UserID: ptr.To(user.ID),
|
||||
User: ptr.To(user),
|
||||
UserID: new(user.ID),
|
||||
User: new(user),
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
RoutableIPs: announcedRoutes,
|
||||
},
|
||||
IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")),
|
||||
IPv4: new(netip.MustParseAddr("100.64.0.1")),
|
||||
ApprovedRoutes: currentApproved,
|
||||
}
|
||||
|
||||
|
||||
@@ -14,7 +14,6 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/ptr"
|
||||
)
|
||||
|
||||
var ap = func(ipStr string) *netip.Addr {
|
||||
@@ -33,6 +32,7 @@ func TestReduceNodes(t *testing.T) {
|
||||
rules []tailcfg.FilterRule
|
||||
node *types.Node
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
@@ -783,9 +783,11 @@ func TestReduceNodes(t *testing.T) {
|
||||
for _, v := range gotViews.All() {
|
||||
got = append(got, v.AsStruct())
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" {
|
||||
t.Errorf("ReduceNodes() unexpected result (-want +got):\n%s", diff)
|
||||
t.Log("Matchers: ")
|
||||
|
||||
for _, m := range matchers {
|
||||
t.Log("\t+", m.DebugString())
|
||||
}
|
||||
@@ -796,7 +798,7 @@ func TestReduceNodes(t *testing.T) {
|
||||
|
||||
func TestReduceNodesFromPolicy(t *testing.T) {
|
||||
n := func(id types.NodeID, ip, hostname, username string, routess ...string) *types.Node {
|
||||
var routes []netip.Prefix
|
||||
routes := make([]netip.Prefix, 0, len(routess))
|
||||
for _, route := range routess {
|
||||
routes = append(routes, netip.MustParsePrefix(route))
|
||||
}
|
||||
@@ -891,11 +893,13 @@ func TestReduceNodesFromPolicy(t *testing.T) {
|
||||
]
|
||||
}`,
|
||||
node: n(1, "100.64.0.1", "mobile", "mobile"),
|
||||
// autogroup:internet does not generate packet filters - it's handled
|
||||
// by exit node routing via AllowedIPs, not by packet filtering.
|
||||
// Only server is visible through the mobile -> server:80 rule.
|
||||
want: types.Nodes{
|
||||
n(2, "100.64.0.2", "server", "server"),
|
||||
n(3, "100.64.0.3", "exit", "server", "0.0.0.0/0", "::/0"),
|
||||
},
|
||||
wantMatchers: 2,
|
||||
wantMatchers: 1,
|
||||
},
|
||||
{
|
||||
name: "2788-exit-node-0000-route",
|
||||
@@ -938,7 +942,7 @@ func TestReduceNodesFromPolicy(t *testing.T) {
|
||||
n(2, "100.64.0.2", "server", "server"),
|
||||
n(3, "100.64.0.3", "exit", "server", "0.0.0.0/0", "::/0"),
|
||||
},
|
||||
wantMatchers: 2,
|
||||
wantMatchers: 1,
|
||||
},
|
||||
{
|
||||
name: "2788-exit-node-::0-route",
|
||||
@@ -981,7 +985,7 @@ func TestReduceNodesFromPolicy(t *testing.T) {
|
||||
n(2, "100.64.0.2", "server", "server"),
|
||||
n(3, "100.64.0.3", "exit", "server", "0.0.0.0/0", "::/0"),
|
||||
},
|
||||
wantMatchers: 2,
|
||||
wantMatchers: 1,
|
||||
},
|
||||
{
|
||||
name: "2784-split-exit-node-access",
|
||||
@@ -1032,8 +1036,11 @@ func TestReduceNodesFromPolicy(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
for idx, pmf := range PolicyManagerFuncsForTest([]byte(tt.policy)) {
|
||||
t.Run(fmt.Sprintf("%s-index%d", tt.name, idx), func(t *testing.T) {
|
||||
var pm PolicyManager
|
||||
var err error
|
||||
var (
|
||||
pm PolicyManager
|
||||
err error
|
||||
)
|
||||
|
||||
pm, err = pmf(nil, tt.nodes.ViewSlice())
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -1051,9 +1058,11 @@ func TestReduceNodesFromPolicy(t *testing.T) {
|
||||
for _, v := range gotViews.All() {
|
||||
got = append(got, v.AsStruct())
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" {
|
||||
t.Errorf("TestReduceNodesFromPolicy() unexpected result (-want +got):\n%s", diff)
|
||||
t.Log("Matchers: ")
|
||||
|
||||
for _, m := range matchers {
|
||||
t.Log("\t+", m.DebugString())
|
||||
}
|
||||
@@ -1074,21 +1083,21 @@ func TestSSHPolicyRules(t *testing.T) {
|
||||
nodeUser1 := types.Node{
|
||||
Hostname: "user1-device",
|
||||
IPv4: ap("100.64.0.1"),
|
||||
UserID: ptr.To(uint(1)),
|
||||
User: ptr.To(users[0]),
|
||||
UserID: new(uint(1)),
|
||||
User: new(users[0]),
|
||||
}
|
||||
nodeUser2 := types.Node{
|
||||
Hostname: "user2-device",
|
||||
IPv4: ap("100.64.0.2"),
|
||||
UserID: ptr.To(uint(2)),
|
||||
User: ptr.To(users[1]),
|
||||
UserID: new(uint(2)),
|
||||
User: new(users[1]),
|
||||
}
|
||||
|
||||
taggedClient := types.Node{
|
||||
Hostname: "tagged-client",
|
||||
IPv4: ap("100.64.0.4"),
|
||||
UserID: ptr.To(uint(2)),
|
||||
User: ptr.To(users[1]),
|
||||
UserID: new(uint(2)),
|
||||
User: new(users[1]),
|
||||
Tags: []string{"tag:client"},
|
||||
}
|
||||
|
||||
@@ -1096,8 +1105,8 @@ func TestSSHPolicyRules(t *testing.T) {
|
||||
nodeTaggedServer := types.Node{
|
||||
Hostname: "tagged-server",
|
||||
IPv4: ap("100.64.0.5"),
|
||||
UserID: ptr.To(uint(1)),
|
||||
User: ptr.To(users[0]),
|
||||
UserID: new(uint(1)),
|
||||
User: new(users[0]),
|
||||
Tags: []string{"tag:server"},
|
||||
}
|
||||
|
||||
@@ -1231,7 +1240,7 @@ func TestSSHPolicyRules(t *testing.T) {
|
||||
]
|
||||
}`,
|
||||
expectErr: true,
|
||||
errorMessage: `invalid SSH action "invalid", must be one of: accept, check`,
|
||||
errorMessage: `invalid SSH action: "invalid", must be one of: accept, check`,
|
||||
},
|
||||
{
|
||||
name: "invalid-check-period",
|
||||
@@ -1278,7 +1287,7 @@ func TestSSHPolicyRules(t *testing.T) {
|
||||
]
|
||||
}`,
|
||||
expectErr: true,
|
||||
errorMessage: "autogroup \"autogroup:invalid\" is not supported",
|
||||
errorMessage: "autogroup not supported for SSH user",
|
||||
},
|
||||
{
|
||||
name: "autogroup-nonroot-should-use-wildcard-with-root-excluded",
|
||||
@@ -1451,13 +1460,17 @@ func TestSSHPolicyRules(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
for idx, pmf := range PolicyManagerFuncsForTest([]byte(tt.policy)) {
|
||||
t.Run(fmt.Sprintf("%s-index%d", tt.name, idx), func(t *testing.T) {
|
||||
var pm PolicyManager
|
||||
var err error
|
||||
var (
|
||||
pm PolicyManager
|
||||
err error
|
||||
)
|
||||
|
||||
pm, err = pmf(users, append(tt.peers, &tt.targetNode).ViewSlice())
|
||||
|
||||
if tt.expectErr {
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), tt.errorMessage)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1480,6 +1493,7 @@ func TestReduceRoutes(t *testing.T) {
|
||||
routes []netip.Prefix
|
||||
rules []tailcfg.FilterRule
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
@@ -2101,6 +2115,7 @@ func TestReduceRoutes(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
matchers := matcher.MatchesFromFilterRules(tt.args.rules)
|
||||
|
||||
got := ReduceRoutes(
|
||||
tt.args.node.View(),
|
||||
tt.args.routes,
|
||||
|
||||
@@ -18,6 +18,7 @@ func ReduceFilterRules(node types.NodeView, rules []tailcfg.FilterRule) []tailcf
|
||||
for _, rule := range rules {
|
||||
// record if the rule is actually relevant for the given node.
|
||||
var dests []tailcfg.NetPortRange
|
||||
|
||||
DEST_LOOP:
|
||||
for _, dest := range rule.DstPorts {
|
||||
expanded, err := util.ParseIPSet(dest.IP, nil)
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/juanfont/headscale/hscontrol/policy"
|
||||
"github.com/juanfont/headscale/hscontrol/policy/policyutil"
|
||||
v2 "github.com/juanfont/headscale/hscontrol/policy/v2"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/rs/zerolog/log"
|
||||
@@ -16,7 +17,6 @@ import (
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/ptr"
|
||||
"tailscale.com/util/must"
|
||||
)
|
||||
|
||||
@@ -144,13 +144,13 @@ func TestReduceFilterRules(t *testing.T) {
|
||||
node: &types.Node{
|
||||
IPv4: ap("100.64.0.1"),
|
||||
IPv6: ap("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"),
|
||||
User: ptr.To(users[0]),
|
||||
User: new(users[0]),
|
||||
},
|
||||
peers: types.Nodes{
|
||||
&types.Node{
|
||||
IPv4: ap("100.64.0.2"),
|
||||
IPv6: ap("fd7a:115c:a1e0:ab12:4843:2222:6273:2222"),
|
||||
User: ptr.To(users[0]),
|
||||
User: new(users[0]),
|
||||
},
|
||||
},
|
||||
want: []tailcfg.FilterRule{},
|
||||
@@ -191,7 +191,7 @@ func TestReduceFilterRules(t *testing.T) {
|
||||
node: &types.Node{
|
||||
IPv4: ap("100.64.0.1"),
|
||||
IPv6: ap("fd7a:115c:a1e0::1"),
|
||||
User: ptr.To(users[1]),
|
||||
User: new(users[1]),
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
RoutableIPs: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.33.0.0/16"),
|
||||
@@ -202,10 +202,11 @@ func TestReduceFilterRules(t *testing.T) {
|
||||
&types.Node{
|
||||
IPv4: ap("100.64.0.2"),
|
||||
IPv6: ap("fd7a:115c:a1e0::2"),
|
||||
User: ptr.To(users[1]),
|
||||
User: new(users[1]),
|
||||
},
|
||||
},
|
||||
want: []tailcfg.FilterRule{
|
||||
// Merged: Both ACL rules combined (same SrcIPs and IPProto)
|
||||
{
|
||||
SrcIPs: []string{
|
||||
"100.64.0.1/32",
|
||||
@@ -222,23 +223,12 @@ func TestReduceFilterRules(t *testing.T) {
|
||||
IP: "fd7a:115c:a1e0::1/128",
|
||||
Ports: tailcfg.PortRangeAny,
|
||||
},
|
||||
},
|
||||
IPProto: []int{6, 17},
|
||||
},
|
||||
{
|
||||
SrcIPs: []string{
|
||||
"100.64.0.1/32",
|
||||
"100.64.0.2/32",
|
||||
"fd7a:115c:a1e0::1/128",
|
||||
"fd7a:115c:a1e0::2/128",
|
||||
},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{
|
||||
IP: "10.33.0.0/16",
|
||||
Ports: tailcfg.PortRangeAny,
|
||||
},
|
||||
},
|
||||
IPProto: []int{6, 17},
|
||||
IPProto: []int{v2.ProtocolTCP, v2.ProtocolUDP, v2.ProtocolICMP, v2.ProtocolIPv6ICMP},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -283,19 +273,19 @@ func TestReduceFilterRules(t *testing.T) {
|
||||
node: &types.Node{
|
||||
IPv4: ap("100.64.0.1"),
|
||||
IPv6: ap("fd7a:115c:a1e0::1"),
|
||||
User: ptr.To(users[1]),
|
||||
User: new(users[1]),
|
||||
},
|
||||
peers: types.Nodes{
|
||||
&types.Node{
|
||||
IPv4: ap("100.64.0.2"),
|
||||
IPv6: ap("fd7a:115c:a1e0::2"),
|
||||
User: ptr.To(users[2]),
|
||||
User: new(users[2]),
|
||||
},
|
||||
// "internal" exit node
|
||||
&types.Node{
|
||||
IPv4: ap("100.64.0.100"),
|
||||
IPv6: ap("fd7a:115c:a1e0::100"),
|
||||
User: ptr.To(users[3]),
|
||||
User: new(users[3]),
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
RoutableIPs: tsaddr.ExitRoutes(),
|
||||
},
|
||||
@@ -344,7 +334,7 @@ func TestReduceFilterRules(t *testing.T) {
|
||||
node: &types.Node{
|
||||
IPv4: ap("100.64.0.100"),
|
||||
IPv6: ap("fd7a:115c:a1e0::100"),
|
||||
User: ptr.To(users[3]),
|
||||
User: new(users[3]),
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
RoutableIPs: tsaddr.ExitRoutes(),
|
||||
},
|
||||
@@ -353,15 +343,18 @@ func TestReduceFilterRules(t *testing.T) {
|
||||
&types.Node{
|
||||
IPv4: ap("100.64.0.2"),
|
||||
IPv6: ap("fd7a:115c:a1e0::2"),
|
||||
User: ptr.To(users[2]),
|
||||
User: new(users[2]),
|
||||
},
|
||||
&types.Node{
|
||||
IPv4: ap("100.64.0.1"),
|
||||
IPv6: ap("fd7a:115c:a1e0::1"),
|
||||
User: ptr.To(users[1]),
|
||||
User: new(users[1]),
|
||||
},
|
||||
},
|
||||
want: []tailcfg.FilterRule{
|
||||
// Only the internal:* rule generates filters.
|
||||
// autogroup:internet does NOT generate packet filters - it's handled
|
||||
// by exit node routing via AllowedIPs, not by packet filtering.
|
||||
{
|
||||
SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
@@ -374,12 +367,7 @@ func TestReduceFilterRules(t *testing.T) {
|
||||
Ports: tailcfg.PortRangeAny,
|
||||
},
|
||||
},
|
||||
IPProto: []int{6, 17},
|
||||
},
|
||||
{
|
||||
SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"},
|
||||
DstPorts: hsExitNodeDestForTest,
|
||||
IPProto: []int{6, 17},
|
||||
IPProto: []int{v2.ProtocolTCP, v2.ProtocolUDP, v2.ProtocolICMP, v2.ProtocolIPv6ICMP},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -453,7 +441,7 @@ func TestReduceFilterRules(t *testing.T) {
|
||||
node: &types.Node{
|
||||
IPv4: ap("100.64.0.100"),
|
||||
IPv6: ap("fd7a:115c:a1e0::100"),
|
||||
User: ptr.To(users[3]),
|
||||
User: new(users[3]),
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
RoutableIPs: tsaddr.ExitRoutes(),
|
||||
},
|
||||
@@ -462,15 +450,16 @@ func TestReduceFilterRules(t *testing.T) {
|
||||
&types.Node{
|
||||
IPv4: ap("100.64.0.2"),
|
||||
IPv6: ap("fd7a:115c:a1e0::2"),
|
||||
User: ptr.To(users[2]),
|
||||
User: new(users[2]),
|
||||
},
|
||||
&types.Node{
|
||||
IPv4: ap("100.64.0.1"),
|
||||
IPv6: ap("fd7a:115c:a1e0::1"),
|
||||
User: ptr.To(users[1]),
|
||||
User: new(users[1]),
|
||||
},
|
||||
},
|
||||
want: []tailcfg.FilterRule{
|
||||
// Merged: Both ACL rules combined (same SrcIPs and IPProto)
|
||||
{
|
||||
SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
@@ -482,12 +471,6 @@ func TestReduceFilterRules(t *testing.T) {
|
||||
IP: "fd7a:115c:a1e0::100/128",
|
||||
Ports: tailcfg.PortRangeAny,
|
||||
},
|
||||
},
|
||||
IPProto: []int{6, 17},
|
||||
},
|
||||
{
|
||||
SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "0.0.0.0/5", Ports: tailcfg.PortRangeAny},
|
||||
{IP: "8.0.0.0/7", Ports: tailcfg.PortRangeAny},
|
||||
{IP: "11.0.0.0/8", Ports: tailcfg.PortRangeAny},
|
||||
@@ -519,7 +502,7 @@ func TestReduceFilterRules(t *testing.T) {
|
||||
{IP: "200.0.0.0/5", Ports: tailcfg.PortRangeAny},
|
||||
{IP: "208.0.0.0/4", Ports: tailcfg.PortRangeAny},
|
||||
},
|
||||
IPProto: []int{6, 17},
|
||||
IPProto: []int{v2.ProtocolTCP, v2.ProtocolUDP, v2.ProtocolICMP, v2.ProtocolIPv6ICMP},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -565,7 +548,7 @@ func TestReduceFilterRules(t *testing.T) {
|
||||
node: &types.Node{
|
||||
IPv4: ap("100.64.0.100"),
|
||||
IPv6: ap("fd7a:115c:a1e0::100"),
|
||||
User: ptr.To(users[3]),
|
||||
User: new(users[3]),
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/16"), netip.MustParsePrefix("16.0.0.0/16")},
|
||||
},
|
||||
@@ -574,15 +557,16 @@ func TestReduceFilterRules(t *testing.T) {
|
||||
&types.Node{
|
||||
IPv4: ap("100.64.0.2"),
|
||||
IPv6: ap("fd7a:115c:a1e0::2"),
|
||||
User: ptr.To(users[2]),
|
||||
User: new(users[2]),
|
||||
},
|
||||
&types.Node{
|
||||
IPv4: ap("100.64.0.1"),
|
||||
IPv6: ap("fd7a:115c:a1e0::1"),
|
||||
User: ptr.To(users[1]),
|
||||
User: new(users[1]),
|
||||
},
|
||||
},
|
||||
want: []tailcfg.FilterRule{
|
||||
// Merged: Both ACL rules combined (same SrcIPs and IPProto)
|
||||
{
|
||||
SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
@@ -594,12 +578,6 @@ func TestReduceFilterRules(t *testing.T) {
|
||||
IP: "fd7a:115c:a1e0::100/128",
|
||||
Ports: tailcfg.PortRangeAny,
|
||||
},
|
||||
},
|
||||
IPProto: []int{6, 17},
|
||||
},
|
||||
{
|
||||
SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{
|
||||
IP: "8.0.0.0/8",
|
||||
Ports: tailcfg.PortRangeAny,
|
||||
@@ -609,7 +587,7 @@ func TestReduceFilterRules(t *testing.T) {
|
||||
Ports: tailcfg.PortRangeAny,
|
||||
},
|
||||
},
|
||||
IPProto: []int{6, 17},
|
||||
IPProto: []int{v2.ProtocolTCP, v2.ProtocolUDP, v2.ProtocolICMP, v2.ProtocolIPv6ICMP},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -655,7 +633,7 @@ func TestReduceFilterRules(t *testing.T) {
|
||||
node: &types.Node{
|
||||
IPv4: ap("100.64.0.100"),
|
||||
IPv6: ap("fd7a:115c:a1e0::100"),
|
||||
User: ptr.To(users[3]),
|
||||
User: new(users[3]),
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/8"), netip.MustParsePrefix("16.0.0.0/8")},
|
||||
},
|
||||
@@ -664,15 +642,16 @@ func TestReduceFilterRules(t *testing.T) {
|
||||
&types.Node{
|
||||
IPv4: ap("100.64.0.2"),
|
||||
IPv6: ap("fd7a:115c:a1e0::2"),
|
||||
User: ptr.To(users[2]),
|
||||
User: new(users[2]),
|
||||
},
|
||||
&types.Node{
|
||||
IPv4: ap("100.64.0.1"),
|
||||
IPv6: ap("fd7a:115c:a1e0::1"),
|
||||
User: ptr.To(users[1]),
|
||||
User: new(users[1]),
|
||||
},
|
||||
},
|
||||
want: []tailcfg.FilterRule{
|
||||
// Merged: Both ACL rules combined (same SrcIPs and IPProto)
|
||||
{
|
||||
SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
@@ -684,12 +663,6 @@ func TestReduceFilterRules(t *testing.T) {
|
||||
IP: "fd7a:115c:a1e0::100/128",
|
||||
Ports: tailcfg.PortRangeAny,
|
||||
},
|
||||
},
|
||||
IPProto: []int{6, 17},
|
||||
},
|
||||
{
|
||||
SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{
|
||||
IP: "8.0.0.0/16",
|
||||
Ports: tailcfg.PortRangeAny,
|
||||
@@ -699,7 +672,7 @@ func TestReduceFilterRules(t *testing.T) {
|
||||
Ports: tailcfg.PortRangeAny,
|
||||
},
|
||||
},
|
||||
IPProto: []int{6, 17},
|
||||
IPProto: []int{v2.ProtocolTCP, v2.ProtocolUDP, v2.ProtocolICMP, v2.ProtocolIPv6ICMP},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -737,7 +710,7 @@ func TestReduceFilterRules(t *testing.T) {
|
||||
node: &types.Node{
|
||||
IPv4: ap("100.64.0.100"),
|
||||
IPv6: ap("fd7a:115c:a1e0::100"),
|
||||
User: ptr.To(users[3]),
|
||||
User: new(users[3]),
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/24")},
|
||||
},
|
||||
@@ -747,7 +720,7 @@ func TestReduceFilterRules(t *testing.T) {
|
||||
&types.Node{
|
||||
IPv4: ap("100.64.0.1"),
|
||||
IPv6: ap("fd7a:115c:a1e0::1"),
|
||||
User: ptr.To(users[1]),
|
||||
User: new(users[1]),
|
||||
},
|
||||
},
|
||||
want: []tailcfg.FilterRule{
|
||||
@@ -767,7 +740,7 @@ func TestReduceFilterRules(t *testing.T) {
|
||||
Ports: tailcfg.PortRangeAny,
|
||||
},
|
||||
},
|
||||
IPProto: []int{6, 17},
|
||||
IPProto: []int{v2.ProtocolTCP, v2.ProtocolUDP, v2.ProtocolICMP, v2.ProtocolIPv6ICMP},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -804,13 +777,13 @@ func TestReduceFilterRules(t *testing.T) {
|
||||
node: &types.Node{
|
||||
IPv4: ap("100.64.0.2"),
|
||||
IPv6: ap("fd7a:115c:a1e0::2"),
|
||||
User: ptr.To(users[3]),
|
||||
User: new(users[3]),
|
||||
},
|
||||
peers: types.Nodes{
|
||||
&types.Node{
|
||||
IPv4: ap("100.64.0.1"),
|
||||
IPv6: ap("fd7a:115c:a1e0::1"),
|
||||
User: ptr.To(users[1]),
|
||||
User: new(users[1]),
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
RoutableIPs: []netip.Prefix{p("172.16.0.0/24"), p("10.10.11.0/24"), p("10.10.12.0/24")},
|
||||
},
|
||||
@@ -824,10 +797,14 @@ func TestReduceFilterRules(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
for idx, pmf := range policy.PolicyManagerFuncsForTest([]byte(tt.pol)) {
|
||||
t.Run(fmt.Sprintf("%s-index%d", tt.name, idx), func(t *testing.T) {
|
||||
var pm policy.PolicyManager
|
||||
var err error
|
||||
var (
|
||||
pm policy.PolicyManager
|
||||
err error
|
||||
)
|
||||
|
||||
pm, err = pmf(users, append(tt.peers, tt.node).ViewSlice())
|
||||
require.NoError(t, err)
|
||||
|
||||
got, _ := pm.Filter()
|
||||
t.Logf("full filter:\n%s", must.Get(json.MarshalIndent(got, "", " ")))
|
||||
got = policyutil.ReduceFilterRules(tt.node.View(), got)
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/types/ptr"
|
||||
)
|
||||
|
||||
func TestNodeCanApproveRoute(t *testing.T) {
|
||||
@@ -25,24 +24,24 @@ func TestNodeCanApproveRoute(t *testing.T) {
|
||||
ID: 1,
|
||||
Hostname: "user1-device",
|
||||
IPv4: ap("100.64.0.1"),
|
||||
UserID: ptr.To(uint(1)),
|
||||
User: ptr.To(users[0]),
|
||||
UserID: new(uint(1)),
|
||||
User: new(users[0]),
|
||||
}
|
||||
|
||||
exitNode := types.Node{
|
||||
ID: 2,
|
||||
Hostname: "user2-device",
|
||||
IPv4: ap("100.64.0.2"),
|
||||
UserID: ptr.To(uint(2)),
|
||||
User: ptr.To(users[1]),
|
||||
UserID: new(uint(2)),
|
||||
User: new(users[1]),
|
||||
}
|
||||
|
||||
taggedNode := types.Node{
|
||||
ID: 3,
|
||||
Hostname: "tagged-server",
|
||||
IPv4: ap("100.64.0.3"),
|
||||
UserID: ptr.To(uint(3)),
|
||||
User: ptr.To(users[2]),
|
||||
UserID: new(uint(3)),
|
||||
User: new(users[2]),
|
||||
Tags: []string{"tag:router"},
|
||||
}
|
||||
|
||||
@@ -50,8 +49,8 @@ func TestNodeCanApproveRoute(t *testing.T) {
|
||||
ID: 4,
|
||||
Hostname: "multi-tag-node",
|
||||
IPv4: ap("100.64.0.4"),
|
||||
UserID: ptr.To(uint(2)),
|
||||
User: ptr.To(users[1]),
|
||||
UserID: new(uint(2)),
|
||||
User: new(users[1]),
|
||||
Tags: []string{"tag:router", "tag:server"},
|
||||
}
|
||||
|
||||
@@ -830,6 +829,7 @@ func TestNodeCanApproveRoute(t *testing.T) {
|
||||
if tt.name == "empty policy" {
|
||||
// We expect this one to have a valid but empty policy
|
||||
require.NoError(t, err)
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -844,6 +844,7 @@ func TestNodeCanApproveRoute(t *testing.T) {
|
||||
if diff := cmp.Diff(tt.canApprove, result); diff != "" {
|
||||
t.Errorf("NodeCanApproveRoute() mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
assert.Equal(t, tt.canApprove, result, "Unexpected route approval result")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -3,7 +3,10 @@ package v2
import (
"errors"
"fmt"
"net/netip"
"slices"
"strconv"
"strings"
"time"

"github.com/juanfont/headscale/hscontrol/types"

@@ -14,7 +17,10 @@ import (
"tailscale.com/types/views"
)

var ErrInvalidAction = errors.New("invalid action")
var (
ErrInvalidAction = errors.New("invalid action")
errSelfInSources = errors.New("autogroup:self cannot be used in sources")
)

// compileFilterRules takes a set of nodes and an ACLPolicy and generates a
// set of Tailscale compatible FilterRules used to allow traffic on clients.

@@ -42,10 +48,29 @@ func (pol *Policy) compileFilterRules(
continue
}

protocols, _ := acl.Protocol.parseProtocol()
protocols := acl.Protocol.parseProtocol()

var destPorts []tailcfg.NetPortRange

for _, dest := range acl.Destinations {
// Check if destination is a wildcard - use "*" directly instead of expanding
if _, isWildcard := dest.Alias.(Asterix); isWildcard {
for _, port := range dest.Ports {
destPorts = append(destPorts, tailcfg.NetPortRange{
IP: "*",
Ports: port,
})
}

continue
}

// autogroup:internet does not generate packet filters - it's handled
// by exit node routing via AllowedIPs, not by packet filtering.
if ag, isAutoGroup := dest.Alias.(*AutoGroup); isAutoGroup && ag.Is(AutoGroupInternet) {
continue
}

ips, err := dest.Resolve(pol, users, nodes)
if err != nil {
log.Trace().Caller().Err(err).Msgf("resolving destination ips")

@@ -80,7 +105,7 @@ func (pol *Policy) compileFilterRules(
})
}

return rules, nil
return mergeFilterRules(rules), nil
}

// compileFilterRulesForNode compiles filter rules for a specific node.

@@ -113,7 +138,7 @@ func (pol *Policy) compileFilterRulesForNode(
}
}

return rules, nil
return mergeFilterRules(rules), nil
}
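In the wildcard branch above, a "*" destination is emitted literally instead of being expanded into every known prefix, and autogroup:internet destinations are skipped because exit-node traffic is granted via AllowedIPs rather than packet filters. A short sketch of the kind of FilterRule the wildcard branch produces; the source address and port values are made up for the example:

```go
package main

import (
	"fmt"

	"tailscale.com/tailcfg"
)

func main() {
	// Ports an ACL destination like "*:80,443" might carry.
	ports := []tailcfg.PortRange{
		{First: 80, Last: 80},
		{First: 443, Last: 443},
	}

	// A wildcard destination keeps IP as "*" rather than expanding it
	// into every node and route prefix.
	var destPorts []tailcfg.NetPortRange
	for _, port := range ports {
		destPorts = append(destPorts, tailcfg.NetPortRange{IP: "*", Ports: port})
	}

	rule := tailcfg.FilterRule{
		SrcIPs:   []string{"100.64.0.1/32"},
		DstPorts: destPorts,
		IPProto:  []int{6, 17}, // TCP and UDP
	}
	fmt.Printf("%+v\n", rule)
}
```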
// compileACLWithAutogroupSelf compiles a single ACL rule, handling
|
||||
@@ -121,14 +146,18 @@ func (pol *Policy) compileFilterRulesForNode(
|
||||
// It returns a slice of filter rules because when an ACL has both autogroup:self
|
||||
// and other destinations, they need to be split into separate rules with different
|
||||
// source filtering logic.
|
||||
//
|
||||
//nolint:gocyclo // complex ACL compilation logic
|
||||
func (pol *Policy) compileACLWithAutogroupSelf(
|
||||
acl ACL,
|
||||
users types.Users,
|
||||
node types.NodeView,
|
||||
nodes views.Slice[types.NodeView],
|
||||
) ([]*tailcfg.FilterRule, error) {
|
||||
var autogroupSelfDests []AliasWithPorts
|
||||
var otherDests []AliasWithPorts
|
||||
var (
|
||||
autogroupSelfDests []AliasWithPorts
|
||||
otherDests []AliasWithPorts
|
||||
)
|
||||
|
||||
for _, dest := range acl.Destinations {
|
||||
if ag, ok := dest.Alias.(*AutoGroup); ok && ag.Is(AutoGroupSelf) {
|
||||
@@ -138,14 +167,15 @@ func (pol *Policy) compileACLWithAutogroupSelf(
|
||||
}
|
||||
}
|
||||
|
||||
protocols, _ := acl.Protocol.parseProtocol()
|
||||
protocols := acl.Protocol.parseProtocol()
|
||||
|
||||
var rules []*tailcfg.FilterRule
|
||||
|
||||
var resolvedSrcIPs []*netipx.IPSet
|
||||
|
||||
for _, src := range acl.Sources {
|
||||
if ag, ok := src.(*AutoGroup); ok && ag.Is(AutoGroupSelf) {
|
||||
return nil, fmt.Errorf("autogroup:self cannot be used in sources")
|
||||
return nil, errSelfInSources
|
||||
}
|
||||
|
||||
ips, err := src.Resolve(pol, users, nodes)
|
||||
@@ -167,6 +197,7 @@ func (pol *Policy) compileACLWithAutogroupSelf(
|
||||
if len(autogroupSelfDests) > 0 && !node.IsTagged() {
|
||||
// Pre-filter to same-user untagged devices once - reuse for both sources and destinations
|
||||
sameUserNodes := make([]types.NodeView, 0)
|
||||
|
||||
for _, n := range nodes.All() {
|
||||
if !n.IsTagged() && n.User().ID() == node.User().ID() {
|
||||
sameUserNodes = append(sameUserNodes, n)
|
||||
@@ -176,6 +207,7 @@ func (pol *Policy) compileACLWithAutogroupSelf(
|
||||
if len(sameUserNodes) > 0 {
|
||||
// Filter sources to only same-user untagged devices
|
||||
var srcIPs netipx.IPSetBuilder
|
||||
|
||||
for _, ips := range resolvedSrcIPs {
|
||||
for _, n := range sameUserNodes {
|
||||
// Check if any of this node's IPs are in the source set
|
||||
@@ -192,12 +224,13 @@ func (pol *Policy) compileACLWithAutogroupSelf(
|
||||
|
||||
if srcSet != nil && len(srcSet.Prefixes()) > 0 {
|
||||
var destPorts []tailcfg.NetPortRange
|
||||
|
||||
for _, dest := range autogroupSelfDests {
|
||||
for _, n := range sameUserNodes {
|
||||
for _, port := range dest.Ports {
|
||||
for _, ip := range n.IPs() {
|
||||
destPorts = append(destPorts, tailcfg.NetPortRange{
|
||||
IP: ip.String(),
|
||||
IP: netip.PrefixFrom(ip, ip.BitLen()).String(),
|
||||
Ports: port,
|
||||
})
|
||||
}
|
||||
@@ -232,6 +265,24 @@ func (pol *Policy) compileACLWithAutogroupSelf(
|
||||
var destPorts []tailcfg.NetPortRange
|
||||
|
||||
for _, dest := range otherDests {
|
||||
// Check if destination is a wildcard - use "*" directly instead of expanding
|
||||
if _, isWildcard := dest.Alias.(Asterix); isWildcard {
|
||||
for _, port := range dest.Ports {
|
||||
destPorts = append(destPorts, tailcfg.NetPortRange{
|
||||
IP: "*",
|
||||
Ports: port,
|
||||
})
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
// autogroup:internet does not generate packet filters - it's handled
|
||||
// by exit node routing via AllowedIPs, not by packet filtering.
|
||||
if ag, isAutoGroup := dest.Alias.(*AutoGroup); isAutoGroup && ag.Is(AutoGroupInternet) {
|
||||
continue
|
||||
}
|
||||
|
||||
ips, err := dest.Resolve(pol, users, nodes)
|
||||
if err != nil {
|
||||
log.Trace().Caller().Err(err).Msgf("resolving destination ips")
|
||||
@@ -279,13 +330,14 @@ func sshAction(accept bool, duration time.Duration) tailcfg.SSHAction {
|
||||
}
|
||||
}
|
||||
|
||||
//nolint:gocyclo // complex SSH policy compilation logic
|
||||
func (pol *Policy) compileSSHPolicy(
|
||||
users types.Users,
|
||||
node types.NodeView,
|
||||
nodes views.Slice[types.NodeView],
|
||||
) (*tailcfg.SSHPolicy, error) {
|
||||
if pol == nil || pol.SSHs == nil || len(pol.SSHs) == 0 {
|
||||
return nil, nil
|
||||
return nil, nil //nolint:nilnil // intentional: no SSH policy when none configured
|
||||
}
|
||||
|
||||
log.Trace().Caller().Msgf("compiling SSH policy for node %q", node.Hostname())
|
||||
@@ -296,8 +348,10 @@ func (pol *Policy) compileSSHPolicy(
|
||||
// Separate destinations into autogroup:self and others
|
||||
// This is needed because autogroup:self requires filtering sources to same-user only,
|
||||
// while other destinations should use all resolved sources
|
||||
var autogroupSelfDests []Alias
|
||||
var otherDests []Alias
|
||||
var (
|
||||
autogroupSelfDests []Alias
|
||||
otherDests []Alias
|
||||
)
|
||||
|
||||
for _, dst := range rule.Destinations {
|
||||
if ag, ok := dst.(*AutoGroup); ok && ag.Is(AutoGroupSelf) {
|
||||
@@ -312,7 +366,7 @@ func (pol *Policy) compileSSHPolicy(
|
||||
// Resolve sources once - we'll use them differently for each destination type
|
||||
srcIPs, err := rule.Sources.Resolve(pol, users, nodes)
|
||||
if err != nil {
|
||||
log.Trace().Caller().Err(err).Msgf("SSH policy compilation failed resolving source ips for rule %+v", rule)
|
||||
log.Trace().Caller().Err(err).Msgf("ssh policy compilation failed resolving source ips for rule %+v", rule)
|
||||
}
|
||||
|
||||
if srcIPs == nil || len(srcIPs.Prefixes()) == 0 {
|
||||
@@ -320,6 +374,7 @@ func (pol *Policy) compileSSHPolicy(
|
||||
}
|
||||
|
||||
var action tailcfg.SSHAction
|
||||
|
||||
switch rule.Action {
|
||||
case SSHActionAccept:
|
||||
action = sshAction(true, 0)
|
||||
@@ -335,9 +390,11 @@ func (pol *Policy) compileSSHPolicy(
|
||||
// by default, we do not allow root unless explicitly stated
|
||||
userMap["root"] = ""
|
||||
}
|
||||
|
||||
if rule.Users.ContainsRoot() {
|
||||
userMap["root"] = "root"
|
||||
}
|
||||
|
||||
for _, u := range rule.Users.NormalUsers() {
|
||||
userMap[u.String()] = u.String()
|
||||
}
|
||||
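// Illustrative sketch (not part of the diff above): how the SSH user map is
// assembled. Root is denied by default (mapped to the empty string) and only
// enabled when the rule lists it explicitly; every other listed user maps to
// itself. buildSSHUserMap is a hypothetical helper; the []string parameter
// stands in for rule.Users, which in the real code has ContainsRoot and
// NormalUsers helpers.
func buildSSHUserMap(ruleUsers []string) map[string]string {
	// By default, root is present but disabled.
	userMap := map[string]string{"root": ""}

	for _, u := range ruleUsers {
		if u == "root" {
			userMap["root"] = "root"
			continue
		}
		userMap[u] = u
	}

	return userMap
}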
@@ -347,6 +404,7 @@ func (pol *Policy) compileSSHPolicy(
|
||||
if len(autogroupSelfDests) > 0 && !node.IsTagged() {
|
||||
// Build destination set for autogroup:self (same-user untagged devices only)
|
||||
var dest netipx.IPSetBuilder
|
||||
|
||||
for _, n := range nodes.All() {
|
||||
if !n.IsTagged() && n.User().ID() == node.User().ID() {
|
||||
n.AppendToIPSet(&dest)
|
||||
@@ -363,6 +421,7 @@ func (pol *Policy) compileSSHPolicy(
|
||||
// Filter sources to only same-user untagged devices
|
||||
// Pre-filter to same-user untagged devices for efficiency
|
||||
sameUserNodes := make([]types.NodeView, 0)
|
||||
|
||||
for _, n := range nodes.All() {
|
||||
if !n.IsTagged() && n.User().ID() == node.User().ID() {
|
||||
sameUserNodes = append(sameUserNodes, n)
|
||||
@@ -370,6 +429,7 @@ func (pol *Policy) compileSSHPolicy(
|
||||
}
|
||||
|
||||
var filteredSrcIPs netipx.IPSetBuilder
|
||||
|
||||
for _, n := range sameUserNodes {
|
||||
// Check if any of this node's IPs are in the source set
|
||||
if slices.ContainsFunc(n.IPs(), srcIPs.Contains) {
|
||||
@@ -405,11 +465,13 @@ func (pol *Policy) compileSSHPolicy(
|
||||
if len(otherDests) > 0 {
|
||||
// Build destination set for other destinations
|
||||
var dest netipx.IPSetBuilder
|
||||
|
||||
for _, dst := range otherDests {
|
||||
ips, err := dst.Resolve(pol, users, nodes)
|
||||
if err != nil {
|
||||
log.Trace().Caller().Err(err).Msgf("resolving destination ips")
|
||||
}
|
||||
|
||||
if ips != nil {
|
||||
dest.AddSet(ips)
|
||||
}
|
||||
@@ -459,3 +521,45 @@ func ipSetToPrefixStringList(ips *netipx.IPSet) []string {
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
// filterRuleKey generates a unique key for merging based on SrcIPs and IPProto.
func filterRuleKey(rule tailcfg.FilterRule) string {
	srcKey := strings.Join(rule.SrcIPs, ",")

	protoStrs := make([]string, len(rule.IPProto))
	for i, p := range rule.IPProto {
		protoStrs[i] = strconv.Itoa(p)
	}

	return srcKey + "|" + strings.Join(protoStrs, ",")
}

// mergeFilterRules merges rules with identical SrcIPs and IPProto by combining
// their DstPorts. DstPorts are NOT deduplicated to match Tailscale behavior.
func mergeFilterRules(rules []tailcfg.FilterRule) []tailcfg.FilterRule {
	if len(rules) <= 1 {
		return rules
	}

	keyToIdx := make(map[string]int)
	result := make([]tailcfg.FilterRule, 0, len(rules))

	for _, rule := range rules {
		key := filterRuleKey(rule)

		if idx, exists := keyToIdx[key]; exists {
			// Merge: append DstPorts to existing rule
			result[idx].DstPorts = append(result[idx].DstPorts, rule.DstPorts...)
		} else {
			// New unique combination
			keyToIdx[key] = len(result)
			result = append(result, tailcfg.FilterRule{
				SrcIPs:   rule.SrcIPs,
				DstPorts: slices.Clone(rule.DstPorts),
				IPProto:  rule.IPProto,
			})
		}
	}

	return result
}
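// Illustrative sketch (not part of the diff above): how mergeFilterRules
// collapses rules that share SrcIPs and IPProto while leaving other rules
// untouched. The values mirror the unit test later in this change; it assumes
// this lives in the same package as mergeFilterRules and the Protocol*
// constants, with fmt imported.
func exampleMergeFilterRules() {
	in := []tailcfg.FilterRule{
		{
			SrcIPs:   []string{"100.64.0.1/32"},
			DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.2/32", Ports: tailcfg.PortRange{First: 22, Last: 22}}},
			IPProto:  []int{ProtocolTCP},
		},
		{
			SrcIPs:   []string{"100.64.0.1/32"},
			DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.2/32", Ports: tailcfg.PortRange{First: 80, Last: 80}}},
			IPProto:  []int{ProtocolTCP},
		},
	}

	merged := mergeFilterRules(in)

	// Both rules share the same key, so they merge into one rule whose
	// DstPorts lists port 22 and port 80 (without deduplication).
	fmt.Println(len(merged), len(merged[0].DstPorts)) // prints: 1 2
}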
@@ -15,7 +15,6 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/ptr"
|
||||
)
|
||||
|
||||
// aliasWithPorts creates an AliasWithPorts structure from an alias and ports.
|
||||
@@ -97,13 +96,11 @@ func TestParsing(t *testing.T) {
|
||||
{
|
||||
SrcIPs: []string{"100.100.101.0/24", "192.168.1.0/24"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "0.0.0.0/0", Ports: tailcfg.PortRange{First: 22, Last: 22}},
|
||||
{IP: "0.0.0.0/0", Ports: tailcfg.PortRange{First: 3389, Last: 3389}},
|
||||
{IP: "::/0", Ports: tailcfg.PortRange{First: 22, Last: 22}},
|
||||
{IP: "::/0", Ports: tailcfg.PortRange{First: 3389, Last: 3389}},
|
||||
{IP: "*", Ports: tailcfg.PortRange{First: 22, Last: 22}},
|
||||
{IP: "*", Ports: tailcfg.PortRange{First: 3389, Last: 3389}},
|
||||
{IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny},
|
||||
},
|
||||
IPProto: []int{protocolTCP, protocolUDP},
|
||||
IPProto: []int{ProtocolTCP, ProtocolUDP, ProtocolICMP, ProtocolIPv6ICMP},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
@@ -153,25 +150,26 @@ func TestParsing(t *testing.T) {
|
||||
}`,
|
||||
want: []tailcfg.FilterRule{
|
||||
{
|
||||
SrcIPs: []string{"0.0.0.0/0", "::/0"},
|
||||
SrcIPs: []string{"100.64.0.0/10", "fd7a:115c:a1e0::/48"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny},
|
||||
},
|
||||
IPProto: []int{protocolTCP},
|
||||
IPProto: []int{ProtocolTCP},
|
||||
},
|
||||
{
|
||||
SrcIPs: []string{"0.0.0.0/0", "::/0"},
|
||||
SrcIPs: []string{"100.64.0.0/10", "fd7a:115c:a1e0::/48"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.100.100.100/32", Ports: tailcfg.PortRange{First: 53, Last: 53}},
|
||||
},
|
||||
IPProto: []int{protocolUDP},
|
||||
IPProto: []int{ProtocolUDP},
|
||||
},
|
||||
{
|
||||
SrcIPs: []string{"0.0.0.0/0", "::/0"},
|
||||
SrcIPs: []string{"100.64.0.0/10", "fd7a:115c:a1e0::/48"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny},
|
||||
},
|
||||
IPProto: []int{protocolICMP, protocolIPv6ICMP},
|
||||
// proto:icmp only includes ICMP (1), not ICMPv6 (58)
|
||||
IPProto: []int{ProtocolICMP},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
@@ -201,11 +199,11 @@ func TestParsing(t *testing.T) {
|
||||
`,
|
||||
want: []tailcfg.FilterRule{
|
||||
{
|
||||
SrcIPs: []string{"0.0.0.0/0", "::/0"},
|
||||
SrcIPs: []string{"100.64.0.0/10", "fd7a:115c:a1e0::/48"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny},
|
||||
},
|
||||
IPProto: []int{protocolTCP, protocolUDP},
|
||||
IPProto: []int{ProtocolTCP, ProtocolUDP, ProtocolICMP, ProtocolIPv6ICMP},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
@@ -242,7 +240,7 @@ func TestParsing(t *testing.T) {
|
||||
Ports: tailcfg.PortRange{First: 5400, Last: 5500},
|
||||
},
|
||||
},
|
||||
IPProto: []int{protocolTCP, protocolUDP},
|
||||
IPProto: []int{ProtocolTCP, ProtocolUDP, ProtocolICMP, ProtocolIPv6ICMP},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
@@ -282,7 +280,7 @@ func TestParsing(t *testing.T) {
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny},
|
||||
},
|
||||
IPProto: []int{protocolTCP, protocolUDP},
|
||||
IPProto: []int{ProtocolTCP, ProtocolUDP, ProtocolICMP, ProtocolIPv6ICMP},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
@@ -316,7 +314,7 @@ func TestParsing(t *testing.T) {
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny},
|
||||
},
|
||||
IPProto: []int{protocolTCP, protocolUDP},
|
||||
IPProto: []int{ProtocolTCP, ProtocolUDP, ProtocolICMP, ProtocolIPv6ICMP},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
@@ -346,11 +344,11 @@ func TestParsing(t *testing.T) {
|
||||
`,
|
||||
want: []tailcfg.FilterRule{
|
||||
{
|
||||
SrcIPs: []string{"0.0.0.0/0", "::/0"},
|
||||
SrcIPs: []string{"100.64.0.0/10", "fd7a:115c:a1e0::/48"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny},
|
||||
},
|
||||
IPProto: []int{protocolTCP, protocolUDP},
|
||||
IPProto: []int{ProtocolTCP, ProtocolUDP, ProtocolICMP, ProtocolIPv6ICMP},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
@@ -412,15 +410,15 @@ func TestCompileSSHPolicy_UserMapping(t *testing.T) {
|
||||
nodeTaggedServer := types.Node{
|
||||
Hostname: "tagged-server",
|
||||
IPv4: createAddr("100.64.0.1"),
|
||||
UserID: ptr.To(users[0].ID),
|
||||
User: ptr.To(users[0]),
|
||||
UserID: new(users[0].ID),
|
||||
User: new(users[0]),
|
||||
Tags: []string{"tag:server"},
|
||||
}
|
||||
nodeTaggedDB := types.Node{
|
||||
Hostname: "tagged-db",
|
||||
IPv4: createAddr("100.64.0.2"),
|
||||
UserID: ptr.To(users[1].ID),
|
||||
User: ptr.To(users[1]),
|
||||
UserID: new(users[1].ID),
|
||||
User: new(users[1]),
|
||||
Tags: []string{"tag:database"},
|
||||
}
|
||||
// Add untagged node for user2 - this will be the SSH source
|
||||
@@ -428,8 +426,8 @@ func TestCompileSSHPolicy_UserMapping(t *testing.T) {
|
||||
nodeUser2Untagged := types.Node{
|
||||
Hostname: "user2-device",
|
||||
IPv4: createAddr("100.64.0.3"),
|
||||
UserID: ptr.To(users[1].ID),
|
||||
User: ptr.To(users[1]),
|
||||
UserID: new(users[1].ID),
|
||||
User: new(users[1]),
|
||||
}
|
||||
|
||||
nodes := types.Nodes{&nodeTaggedServer, &nodeTaggedDB, &nodeUser2Untagged}
|
||||
@@ -624,7 +622,9 @@ func TestCompileSSHPolicy_UserMapping(t *testing.T) {
|
||||
if sshPolicy == nil {
|
||||
return // Expected empty result
|
||||
}
|
||||
|
||||
assert.Empty(t, sshPolicy.Rules, "SSH policy should be empty when no rules match")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -657,15 +657,15 @@ func TestCompileSSHPolicy_CheckAction(t *testing.T) {
|
||||
nodeTaggedServer := types.Node{
|
||||
Hostname: "tagged-server",
|
||||
IPv4: createAddr("100.64.0.1"),
|
||||
UserID: ptr.To(users[0].ID),
|
||||
User: ptr.To(users[0]),
|
||||
UserID: new(users[0].ID),
|
||||
User: new(users[0]),
|
||||
Tags: []string{"tag:server"},
|
||||
}
|
||||
nodeUser2 := types.Node{
|
||||
Hostname: "user2-device",
|
||||
IPv4: createAddr("100.64.0.2"),
|
||||
UserID: ptr.To(users[1].ID),
|
||||
User: ptr.To(users[1]),
|
||||
UserID: new(users[1].ID),
|
||||
User: new(users[1]),
|
||||
}
|
||||
|
||||
nodes := types.Nodes{&nodeTaggedServer, &nodeUser2}
|
||||
@@ -710,7 +710,7 @@ func TestCompileSSHPolicy_CheckAction(t *testing.T) {
|
||||
}
|
||||
|
||||
// TestSSHIntegrationReproduction reproduces the exact scenario from the integration test
|
||||
// TestSSHOneUserToAll that was failing with empty sshUsers
|
||||
// TestSSHOneUserToAll that was failing with empty sshUsers.
|
||||
func TestSSHIntegrationReproduction(t *testing.T) {
|
||||
// Create users matching the integration test
|
||||
users := types.Users{
|
||||
@@ -722,15 +722,15 @@ func TestSSHIntegrationReproduction(t *testing.T) {
|
||||
node1 := &types.Node{
|
||||
Hostname: "user1-node",
|
||||
IPv4: createAddr("100.64.0.1"),
|
||||
UserID: ptr.To(users[0].ID),
|
||||
User: ptr.To(users[0]),
|
||||
UserID: new(users[0].ID),
|
||||
User: new(users[0]),
|
||||
}
|
||||
|
||||
node2 := &types.Node{
|
||||
Hostname: "user2-node",
|
||||
IPv4: createAddr("100.64.0.2"),
|
||||
UserID: ptr.To(users[1].ID),
|
||||
User: ptr.To(users[1]),
|
||||
UserID: new(users[1].ID),
|
||||
User: new(users[1]),
|
||||
}
|
||||
|
||||
nodes := types.Nodes{node1, node2}
|
||||
@@ -776,7 +776,7 @@ func TestSSHIntegrationReproduction(t *testing.T) {
|
||||
}
|
||||
|
||||
// TestSSHJSONSerialization verifies that the SSH policy can be properly serialized
|
||||
// to JSON and that the sshUsers field is not empty
|
||||
// to JSON and that the sshUsers field is not empty.
|
||||
func TestSSHJSONSerialization(t *testing.T) {
|
||||
users := types.Users{
|
||||
{Name: "user1", Model: gorm.Model{ID: 1}},
|
||||
@@ -816,6 +816,7 @@ func TestSSHJSONSerialization(t *testing.T) {
|
||||
|
||||
// Parse back to verify structure
|
||||
var parsed tailcfg.SSHPolicy
|
||||
|
||||
err = json.Unmarshal(jsonData, &parsed)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -846,19 +847,19 @@ func TestCompileFilterRulesForNodeWithAutogroupSelf(t *testing.T) {
|
||||
|
||||
nodes := types.Nodes{
|
||||
{
|
||||
User: ptr.To(users[0]),
|
||||
User: new(users[0]),
|
||||
IPv4: ap("100.64.0.1"),
|
||||
},
|
||||
{
|
||||
User: ptr.To(users[0]),
|
||||
User: new(users[0]),
|
||||
IPv4: ap("100.64.0.2"),
|
||||
},
|
||||
{
|
||||
User: ptr.To(users[1]),
|
||||
User: new(users[1]),
|
||||
IPv4: ap("100.64.0.3"),
|
||||
},
|
||||
{
|
||||
User: ptr.To(users[1]),
|
||||
User: new(users[1]),
|
||||
IPv4: ap("100.64.0.4"),
|
||||
},
|
||||
// Tagged device for user1
|
||||
@@ -900,6 +901,7 @@ func TestCompileFilterRulesForNodeWithAutogroupSelf(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(rules) != 1 {
|
||||
t.Fatalf("expected 1 rule, got %d", len(rules))
|
||||
}
|
||||
@@ -916,6 +918,7 @@ func TestCompileFilterRulesForNodeWithAutogroupSelf(t *testing.T) {
|
||||
found := false
|
||||
|
||||
addr := netip.MustParseAddr(expectedIP)
|
||||
|
||||
for _, prefix := range rule.SrcIPs {
|
||||
pref := netip.MustParsePrefix(prefix)
|
||||
if pref.Contains(addr) {
|
||||
@@ -933,6 +936,7 @@ func TestCompileFilterRulesForNodeWithAutogroupSelf(t *testing.T) {
|
||||
excludedSourceIPs := []string{"100.64.0.3", "100.64.0.4", "100.64.0.5", "100.64.0.6"}
|
||||
for _, excludedIP := range excludedSourceIPs {
|
||||
addr := netip.MustParseAddr(excludedIP)
|
||||
|
||||
for _, prefix := range rule.SrcIPs {
|
||||
pref := netip.MustParsePrefix(prefix)
|
||||
if pref.Contains(addr) {
|
||||
@@ -941,7 +945,7 @@ func TestCompileFilterRulesForNodeWithAutogroupSelf(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
expectedDestIPs := []string{"100.64.0.1", "100.64.0.2"}
|
||||
expectedDestIPs := []string{"100.64.0.1/32", "100.64.0.2/32"}
|
||||
|
||||
actualDestIPs := make([]string, 0, len(rule.DstPorts))
|
||||
for _, dst := range rule.DstPorts {
|
||||
@@ -957,7 +961,7 @@ func TestCompileFilterRulesForNodeWithAutogroupSelf(t *testing.T) {
|
||||
}
|
||||
|
||||
// Verify that other users' devices and tagged devices are not in destinations
|
||||
excludedDestIPs := []string{"100.64.0.3", "100.64.0.4", "100.64.0.5", "100.64.0.6"}
|
||||
excludedDestIPs := []string{"100.64.0.3/32", "100.64.0.4/32", "100.64.0.5/32", "100.64.0.6/32"}
|
||||
for _, excludedIP := range excludedDestIPs {
|
||||
for _, actualIP := range actualDestIPs {
|
||||
if actualIP == excludedIP {
|
||||
@@ -978,11 +982,11 @@ func TestTagUserMutualExclusivity(t *testing.T) {
|
||||
nodes := types.Nodes{
|
||||
// User-owned nodes
|
||||
{
|
||||
User: ptr.To(users[0]),
|
||||
User: new(users[0]),
|
||||
IPv4: ap("100.64.0.1"),
|
||||
},
|
||||
{
|
||||
User: ptr.To(users[1]),
|
||||
User: new(users[1]),
|
||||
IPv4: ap("100.64.0.2"),
|
||||
},
|
||||
// Tagged nodes
|
||||
@@ -1000,8 +1004,8 @@ func TestTagUserMutualExclusivity(t *testing.T) {
|
||||
|
||||
policy := &Policy{
|
||||
TagOwners: TagOwners{
|
||||
Tag("tag:server"): Owners{ptr.To(Username("user1@"))},
|
||||
Tag("tag:database"): Owners{ptr.To(Username("user2@"))},
|
||||
Tag("tag:server"): Owners{new(Username("user1@"))},
|
||||
Tag("tag:database"): Owners{new(Username("user2@"))},
|
||||
},
|
||||
ACLs: []ACL{
|
||||
// Rule 1: user1 (user-owned) should NOT be able to reach tagged nodes
|
||||
@@ -1096,11 +1100,11 @@ func TestAutogroupTagged(t *testing.T) {
|
||||
nodes := types.Nodes{
|
||||
// User-owned nodes (not tagged)
|
||||
{
|
||||
User: ptr.To(users[0]),
|
||||
User: new(users[0]),
|
||||
IPv4: ap("100.64.0.1"),
|
||||
},
|
||||
{
|
||||
User: ptr.To(users[1]),
|
||||
User: new(users[1]),
|
||||
IPv4: ap("100.64.0.2"),
|
||||
},
|
||||
// Tagged nodes
|
||||
@@ -1123,10 +1127,10 @@ func TestAutogroupTagged(t *testing.T) {
|
||||
|
||||
policy := &Policy{
|
||||
TagOwners: TagOwners{
|
||||
Tag("tag:server"): Owners{ptr.To(Username("user1@"))},
|
||||
Tag("tag:database"): Owners{ptr.To(Username("user2@"))},
|
||||
Tag("tag:web"): Owners{ptr.To(Username("user1@"))},
|
||||
Tag("tag:prod"): Owners{ptr.To(Username("user1@"))},
|
||||
Tag("tag:server"): Owners{new(Username("user1@"))},
|
||||
Tag("tag:database"): Owners{new(Username("user2@"))},
|
||||
Tag("tag:web"): Owners{new(Username("user1@"))},
|
||||
Tag("tag:prod"): Owners{new(Username("user1@"))},
|
||||
},
|
||||
ACLs: []ACL{
|
||||
// Rule: autogroup:tagged can reach user-owned nodes
|
||||
@@ -1145,7 +1149,8 @@ func TestAutogroupTagged(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify autogroup:tagged includes all tagged nodes
|
||||
taggedIPs, err := AutoGroupTagged.Resolve(policy, users, nodes.ViewSlice())
|
||||
ag := AutoGroupTagged
|
||||
taggedIPs, err := ag.Resolve(policy, users, nodes.ViewSlice())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, taggedIPs)
|
||||
|
||||
@@ -1246,10 +1251,10 @@ func TestAutogroupSelfWithSpecificUserSource(t *testing.T) {
|
||||
}
|
||||
|
||||
nodes := types.Nodes{
|
||||
{User: ptr.To(users[0]), IPv4: ap("100.64.0.1")},
|
||||
{User: ptr.To(users[0]), IPv4: ap("100.64.0.2")},
|
||||
{User: ptr.To(users[1]), IPv4: ap("100.64.0.3")},
|
||||
{User: ptr.To(users[1]), IPv4: ap("100.64.0.4")},
|
||||
{User: new(users[0]), IPv4: ap("100.64.0.1")},
|
||||
{User: new(users[0]), IPv4: ap("100.64.0.2")},
|
||||
{User: new(users[1]), IPv4: ap("100.64.0.3")},
|
||||
{User: new(users[1]), IPv4: ap("100.64.0.4")},
|
||||
}
|
||||
|
||||
policy := &Policy{
|
||||
@@ -1294,7 +1299,8 @@ func TestAutogroupSelfWithSpecificUserSource(t *testing.T) {
|
||||
actualDestIPs = append(actualDestIPs, dst.IP)
|
||||
}
|
||||
|
||||
assert.ElementsMatch(t, expectedSourceIPs, actualDestIPs)
|
||||
expectedDestIPs := []string{"100.64.0.1/32", "100.64.0.2/32"}
|
||||
assert.ElementsMatch(t, expectedDestIPs, actualDestIPs)
|
||||
|
||||
node2 := nodes[2].View()
|
||||
rules2, err := policy.compileFilterRulesForNode(users, node2, nodes.ViewSlice())
|
||||
@@ -1313,11 +1319,11 @@ func TestAutogroupSelfWithGroupSource(t *testing.T) {
|
||||
}
|
||||
|
||||
nodes := types.Nodes{
|
||||
{User: ptr.To(users[0]), IPv4: ap("100.64.0.1")},
|
||||
{User: ptr.To(users[0]), IPv4: ap("100.64.0.2")},
|
||||
{User: ptr.To(users[1]), IPv4: ap("100.64.0.3")},
|
||||
{User: ptr.To(users[1]), IPv4: ap("100.64.0.4")},
|
||||
{User: ptr.To(users[2]), IPv4: ap("100.64.0.5")},
|
||||
{User: new(users[0]), IPv4: ap("100.64.0.1")},
|
||||
{User: new(users[0]), IPv4: ap("100.64.0.2")},
|
||||
{User: new(users[1]), IPv4: ap("100.64.0.3")},
|
||||
{User: new(users[1]), IPv4: ap("100.64.0.4")},
|
||||
{User: new(users[2]), IPv4: ap("100.64.0.5")},
|
||||
}
|
||||
|
||||
policy := &Policy{
|
||||
@@ -1366,14 +1372,14 @@ func TestAutogroupSelfWithGroupSource(t *testing.T) {
|
||||
assert.Empty(t, rules3, "user3 should have no rules")
|
||||
}
|
||||
|
||||
// Helper function to create IP addresses for testing
|
||||
// Helper function to create IP addresses for testing.
|
||||
func createAddr(ip string) *netip.Addr {
|
||||
addr, _ := netip.ParseAddr(ip)
|
||||
return &addr
|
||||
}
|
||||
|
||||
// TestSSHWithAutogroupSelfInDestination verifies that SSH policies work correctly
|
||||
// with autogroup:self in destinations
|
||||
// with autogroup:self in destinations.
|
||||
func TestSSHWithAutogroupSelfInDestination(t *testing.T) {
|
||||
users := types.Users{
|
||||
{Model: gorm.Model{ID: 1}, Name: "user1"},
|
||||
@@ -1382,13 +1388,13 @@ func TestSSHWithAutogroupSelfInDestination(t *testing.T) {
|
||||
|
||||
nodes := types.Nodes{
|
||||
// User1's nodes
|
||||
{User: ptr.To(users[0]), IPv4: ap("100.64.0.1"), Hostname: "user1-node1"},
|
||||
{User: ptr.To(users[0]), IPv4: ap("100.64.0.2"), Hostname: "user1-node2"},
|
||||
{User: new(users[0]), IPv4: ap("100.64.0.1"), Hostname: "user1-node1"},
|
||||
{User: new(users[0]), IPv4: ap("100.64.0.2"), Hostname: "user1-node2"},
|
||||
// User2's nodes
|
||||
{User: ptr.To(users[1]), IPv4: ap("100.64.0.3"), Hostname: "user2-node1"},
|
||||
{User: ptr.To(users[1]), IPv4: ap("100.64.0.4"), Hostname: "user2-node2"},
|
||||
{User: new(users[1]), IPv4: ap("100.64.0.3"), Hostname: "user2-node1"},
|
||||
{User: new(users[1]), IPv4: ap("100.64.0.4"), Hostname: "user2-node2"},
|
||||
// Tagged node for user1 (should be excluded)
|
||||
{User: ptr.To(users[0]), IPv4: ap("100.64.0.5"), Hostname: "user1-tagged", Tags: []string{"tag:server"}},
|
||||
{User: new(users[0]), IPv4: ap("100.64.0.5"), Hostname: "user1-tagged", Tags: []string{"tag:server"}},
|
||||
}
|
||||
|
||||
policy := &Policy{
|
||||
@@ -1421,6 +1427,7 @@ func TestSSHWithAutogroupSelfInDestination(t *testing.T) {
|
||||
for i, p := range rule.Principals {
|
||||
principalIPs[i] = p.NodeIP
|
||||
}
|
||||
|
||||
assert.ElementsMatch(t, []string{"100.64.0.1", "100.64.0.2"}, principalIPs)
|
||||
|
||||
// Test for user2's first node
|
||||
@@ -1439,12 +1446,14 @@ func TestSSHWithAutogroupSelfInDestination(t *testing.T) {
|
||||
for i, p := range rule2.Principals {
|
||||
principalIPs2[i] = p.NodeIP
|
||||
}
|
||||
|
||||
assert.ElementsMatch(t, []string{"100.64.0.3", "100.64.0.4"}, principalIPs2)
|
||||
|
||||
// Test for tagged node (should have no SSH rules)
|
||||
node5 := nodes[4].View()
|
||||
sshPolicy3, err := policy.compileSSHPolicy(users, node5, nodes.ViewSlice())
|
||||
require.NoError(t, err)
|
||||
|
||||
if sshPolicy3 != nil {
|
||||
assert.Empty(t, sshPolicy3.Rules, "tagged nodes should not get SSH rules with autogroup:self")
|
||||
}
|
||||
@@ -1452,7 +1461,7 @@ func TestSSHWithAutogroupSelfInDestination(t *testing.T) {
|
||||
|
||||
// TestSSHWithAutogroupSelfAndSpecificUser verifies that when a specific user
|
||||
// is in the source and autogroup:self in destination, only that user's devices
|
||||
// can SSH (and only if they match the target user)
|
||||
// can SSH (and only if they match the target user).
|
||||
func TestSSHWithAutogroupSelfAndSpecificUser(t *testing.T) {
|
||||
users := types.Users{
|
||||
{Model: gorm.Model{ID: 1}, Name: "user1"},
|
||||
@@ -1460,10 +1469,10 @@ func TestSSHWithAutogroupSelfAndSpecificUser(t *testing.T) {
|
||||
}
|
||||
|
||||
nodes := types.Nodes{
|
||||
{User: ptr.To(users[0]), IPv4: ap("100.64.0.1")},
|
||||
{User: ptr.To(users[0]), IPv4: ap("100.64.0.2")},
|
||||
{User: ptr.To(users[1]), IPv4: ap("100.64.0.3")},
|
||||
{User: ptr.To(users[1]), IPv4: ap("100.64.0.4")},
|
||||
{User: new(users[0]), IPv4: ap("100.64.0.1")},
|
||||
{User: new(users[0]), IPv4: ap("100.64.0.2")},
|
||||
{User: new(users[1]), IPv4: ap("100.64.0.3")},
|
||||
{User: new(users[1]), IPv4: ap("100.64.0.4")},
|
||||
}
|
||||
|
||||
policy := &Policy{
|
||||
@@ -1494,18 +1503,20 @@ func TestSSHWithAutogroupSelfAndSpecificUser(t *testing.T) {
|
||||
for i, p := range rule.Principals {
|
||||
principalIPs[i] = p.NodeIP
|
||||
}
|
||||
|
||||
assert.ElementsMatch(t, []string{"100.64.0.1", "100.64.0.2"}, principalIPs)
|
||||
|
||||
// For user2's node: should have no rules (user1's devices can't match user2's self)
|
||||
node3 := nodes[2].View()
|
||||
sshPolicy2, err := policy.compileSSHPolicy(users, node3, nodes.ViewSlice())
|
||||
require.NoError(t, err)
|
||||
|
||||
if sshPolicy2 != nil {
|
||||
assert.Empty(t, sshPolicy2.Rules, "user2 should have no SSH rules since source is user1")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSSHWithAutogroupSelfAndGroup verifies SSH with group sources and autogroup:self destinations
|
||||
// TestSSHWithAutogroupSelfAndGroup verifies SSH with group sources and autogroup:self destinations.
|
||||
func TestSSHWithAutogroupSelfAndGroup(t *testing.T) {
|
||||
users := types.Users{
|
||||
{Model: gorm.Model{ID: 1}, Name: "user1"},
|
||||
@@ -1514,11 +1525,11 @@ func TestSSHWithAutogroupSelfAndGroup(t *testing.T) {
|
||||
}
|
||||
|
||||
nodes := types.Nodes{
|
||||
{User: ptr.To(users[0]), IPv4: ap("100.64.0.1")},
|
||||
{User: ptr.To(users[0]), IPv4: ap("100.64.0.2")},
|
||||
{User: ptr.To(users[1]), IPv4: ap("100.64.0.3")},
|
||||
{User: ptr.To(users[1]), IPv4: ap("100.64.0.4")},
|
||||
{User: ptr.To(users[2]), IPv4: ap("100.64.0.5")},
|
||||
{User: new(users[0]), IPv4: ap("100.64.0.1")},
|
||||
{User: new(users[0]), IPv4: ap("100.64.0.2")},
|
||||
{User: new(users[1]), IPv4: ap("100.64.0.3")},
|
||||
{User: new(users[1]), IPv4: ap("100.64.0.4")},
|
||||
{User: new(users[2]), IPv4: ap("100.64.0.5")},
|
||||
}
|
||||
|
||||
policy := &Policy{
|
||||
@@ -1552,29 +1563,31 @@ func TestSSHWithAutogroupSelfAndGroup(t *testing.T) {
|
||||
for i, p := range rule.Principals {
|
||||
principalIPs[i] = p.NodeIP
|
||||
}
|
||||
|
||||
assert.ElementsMatch(t, []string{"100.64.0.1", "100.64.0.2"}, principalIPs)
|
||||
|
||||
// For user3's node: should have no rules (not in group:admins)
|
||||
node5 := nodes[4].View()
|
||||
sshPolicy2, err := policy.compileSSHPolicy(users, node5, nodes.ViewSlice())
|
||||
require.NoError(t, err)
|
||||
|
||||
if sshPolicy2 != nil {
|
||||
assert.Empty(t, sshPolicy2.Rules, "user3 should have no SSH rules (not in group)")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSSHWithAutogroupSelfExcludesTaggedDevices verifies that tagged devices
|
||||
// are excluded from both sources and destinations when autogroup:self is used
|
||||
// are excluded from both sources and destinations when autogroup:self is used.
|
||||
func TestSSHWithAutogroupSelfExcludesTaggedDevices(t *testing.T) {
|
||||
users := types.Users{
|
||||
{Model: gorm.Model{ID: 1}, Name: "user1"},
|
||||
}
|
||||
|
||||
nodes := types.Nodes{
|
||||
{User: ptr.To(users[0]), IPv4: ap("100.64.0.1"), Hostname: "untagged1"},
|
||||
{User: ptr.To(users[0]), IPv4: ap("100.64.0.2"), Hostname: "untagged2"},
|
||||
{User: ptr.To(users[0]), IPv4: ap("100.64.0.3"), Hostname: "tagged1", Tags: []string{"tag:server"}},
|
||||
{User: ptr.To(users[0]), IPv4: ap("100.64.0.4"), Hostname: "tagged2", Tags: []string{"tag:web"}},
|
||||
{User: new(users[0]), IPv4: ap("100.64.0.1"), Hostname: "untagged1"},
|
||||
{User: new(users[0]), IPv4: ap("100.64.0.2"), Hostname: "untagged2"},
|
||||
{User: new(users[0]), IPv4: ap("100.64.0.3"), Hostname: "tagged1", Tags: []string{"tag:server"}},
|
||||
{User: new(users[0]), IPv4: ap("100.64.0.4"), Hostname: "tagged2", Tags: []string{"tag:web"}},
|
||||
}
|
||||
|
||||
policy := &Policy{
|
||||
@@ -1609,6 +1622,7 @@ func TestSSHWithAutogroupSelfExcludesTaggedDevices(t *testing.T) {
|
||||
for i, p := range rule.Principals {
|
||||
principalIPs[i] = p.NodeIP
|
||||
}
|
||||
|
||||
assert.ElementsMatch(t, []string{"100.64.0.1", "100.64.0.2"}, principalIPs,
|
||||
"should only include untagged devices")
|
||||
|
||||
@@ -1616,6 +1630,7 @@ func TestSSHWithAutogroupSelfExcludesTaggedDevices(t *testing.T) {
|
||||
node3 := nodes[2].View()
|
||||
sshPolicy2, err := policy.compileSSHPolicy(users, node3, nodes.ViewSlice())
|
||||
require.NoError(t, err)
|
||||
|
||||
if sshPolicy2 != nil {
|
||||
assert.Empty(t, sshPolicy2.Rules, "tagged node should get no SSH rules with autogroup:self")
|
||||
}
|
||||
@@ -1631,10 +1646,10 @@ func TestSSHWithAutogroupSelfAndMixedDestinations(t *testing.T) {
|
||||
}
|
||||
|
||||
nodes := types.Nodes{
|
||||
{User: ptr.To(users[0]), IPv4: ap("100.64.0.1"), Hostname: "user1-device"},
|
||||
{User: ptr.To(users[0]), IPv4: ap("100.64.0.2"), Hostname: "user1-device2"},
|
||||
{User: ptr.To(users[1]), IPv4: ap("100.64.0.3"), Hostname: "user2-device"},
|
||||
{User: ptr.To(users[1]), IPv4: ap("100.64.0.4"), Hostname: "user2-router", Tags: []string{"tag:router"}},
|
||||
{User: new(users[0]), IPv4: ap("100.64.0.1"), Hostname: "user1-device"},
|
||||
{User: new(users[0]), IPv4: ap("100.64.0.2"), Hostname: "user1-device2"},
|
||||
{User: new(users[1]), IPv4: ap("100.64.0.3"), Hostname: "user2-device"},
|
||||
{User: new(users[1]), IPv4: ap("100.64.0.4"), Hostname: "user2-router", Tags: []string{"tag:router"}},
|
||||
}
|
||||
|
||||
policy := &Policy{
|
||||
@@ -1664,10 +1679,12 @@ func TestSSHWithAutogroupSelfAndMixedDestinations(t *testing.T) {
|
||||
// Verify autogroup:self rule has filtered sources (only same-user devices)
|
||||
selfRule := sshPolicy1.Rules[0]
|
||||
require.Len(t, selfRule.Principals, 2, "autogroup:self rule should only have user1's devices")
|
||||
|
||||
selfPrincipals := make([]string, len(selfRule.Principals))
|
||||
for i, p := range selfRule.Principals {
|
||||
selfPrincipals[i] = p.NodeIP
|
||||
}
|
||||
|
||||
require.ElementsMatch(t, []string{"100.64.0.1", "100.64.0.2"}, selfPrincipals,
|
||||
"autogroup:self rule should only include same-user untagged devices")
|
||||
|
||||
@@ -1679,10 +1696,12 @@ func TestSSHWithAutogroupSelfAndMixedDestinations(t *testing.T) {
|
||||
require.Len(t, sshPolicyRouter.Rules, 1, "router should have 1 SSH rule (tag:router)")
|
||||
|
||||
routerRule := sshPolicyRouter.Rules[0]
|
||||
|
||||
routerPrincipals := make([]string, len(routerRule.Principals))
|
||||
for i, p := range routerRule.Principals {
|
||||
routerPrincipals[i] = p.NodeIP
|
||||
}
|
||||
|
||||
require.Contains(t, routerPrincipals, "100.64.0.1", "router rule should include user1's device (unfiltered sources)")
|
||||
require.Contains(t, routerPrincipals, "100.64.0.2", "router rule should include user1's other device (unfiltered sources)")
|
||||
require.Contains(t, routerPrincipals, "100.64.0.3", "router rule should include user2's device (unfiltered sources)")
|
||||
@@ -1702,11 +1721,11 @@ func TestAutogroupSelfWithNonExistentUserInGroup(t *testing.T) {
|
||||
|
||||
nodes := types.Nodes{
|
||||
// superadmin's device
|
||||
{ID: 1, User: ptr.To(users[0]), IPv4: ap("100.64.0.1"), Hostname: "superadmin-device"},
|
||||
{ID: 1, User: new(users[0]), IPv4: ap("100.64.0.1"), Hostname: "superadmin-device"},
|
||||
// admin's device
|
||||
{ID: 2, User: ptr.To(users[1]), IPv4: ap("100.64.0.2"), Hostname: "admin-device"},
|
||||
{ID: 2, User: new(users[1]), IPv4: ap("100.64.0.2"), Hostname: "admin-device"},
|
||||
// direction's device
|
||||
{ID: 3, User: ptr.To(users[2]), IPv4: ap("100.64.0.3"), Hostname: "direction-device"},
|
||||
{ID: 3, User: new(users[2]), IPv4: ap("100.64.0.3"), Hostname: "direction-device"},
|
||||
// tagged servers
|
||||
{ID: 4, IPv4: ap("100.64.0.10"), Hostname: "common-server", Tags: []string{"tag:common"}},
|
||||
{ID: 5, IPv4: ap("100.64.0.11"), Hostname: "tech-server", Tags: []string{"tag:tech"}},
|
||||
@@ -1860,3 +1879,212 @@ func TestAutogroupSelfWithNonExistentUserInGroup(t *testing.T) {
|
||||
assert.True(t, containsSrcIP(directionRules, "100.64.0.1"),
|
||||
"superadmin's IP should be in sources for rule 1 (partial resolution preserved)")
|
||||
}
|
||||
|
||||
func TestMergeFilterRules(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []tailcfg.FilterRule
|
||||
want []tailcfg.FilterRule
|
||||
}{
|
||||
{
|
||||
name: "empty input",
|
||||
input: []tailcfg.FilterRule{},
|
||||
want: []tailcfg.FilterRule{},
|
||||
},
|
||||
{
|
||||
name: "single rule unchanged",
|
||||
input: []tailcfg.FilterRule{
|
||||
{
|
||||
SrcIPs: []string{"100.64.0.1/32"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.64.0.2/32", Ports: tailcfg.PortRange{First: 22, Last: 22}},
|
||||
},
|
||||
IPProto: []int{ProtocolTCP},
|
||||
},
|
||||
},
|
||||
want: []tailcfg.FilterRule{
|
||||
{
|
||||
SrcIPs: []string{"100.64.0.1/32"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.64.0.2/32", Ports: tailcfg.PortRange{First: 22, Last: 22}},
|
||||
},
|
||||
IPProto: []int{ProtocolTCP},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "merge two rules with same key",
|
||||
input: []tailcfg.FilterRule{
|
||||
{
|
||||
SrcIPs: []string{"100.64.0.1/32"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.64.0.2/32", Ports: tailcfg.PortRange{First: 22, Last: 22}},
|
||||
},
|
||||
IPProto: []int{ProtocolTCP, ProtocolUDP},
|
||||
},
|
||||
{
|
||||
SrcIPs: []string{"100.64.0.1/32"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.64.0.2/32", Ports: tailcfg.PortRange{First: 80, Last: 80}},
|
||||
},
|
||||
IPProto: []int{ProtocolTCP, ProtocolUDP},
|
||||
},
|
||||
},
|
||||
want: []tailcfg.FilterRule{
|
||||
{
|
||||
SrcIPs: []string{"100.64.0.1/32"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.64.0.2/32", Ports: tailcfg.PortRange{First: 22, Last: 22}},
|
||||
{IP: "100.64.0.2/32", Ports: tailcfg.PortRange{First: 80, Last: 80}},
|
||||
},
|
||||
IPProto: []int{ProtocolTCP, ProtocolUDP},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "different SrcIPs not merged",
|
||||
input: []tailcfg.FilterRule{
|
||||
{
|
||||
SrcIPs: []string{"100.64.0.1/32"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.64.0.3/32", Ports: tailcfg.PortRange{First: 22, Last: 22}},
|
||||
},
|
||||
IPProto: []int{ProtocolTCP},
|
||||
},
|
||||
{
|
||||
SrcIPs: []string{"100.64.0.2/32"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.64.0.3/32", Ports: tailcfg.PortRange{First: 22, Last: 22}},
|
||||
},
|
||||
IPProto: []int{ProtocolTCP},
|
||||
},
|
||||
},
|
||||
want: []tailcfg.FilterRule{
|
||||
{
|
||||
SrcIPs: []string{"100.64.0.1/32"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.64.0.3/32", Ports: tailcfg.PortRange{First: 22, Last: 22}},
|
||||
},
|
||||
IPProto: []int{ProtocolTCP},
|
||||
},
|
||||
{
|
||||
SrcIPs: []string{"100.64.0.2/32"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.64.0.3/32", Ports: tailcfg.PortRange{First: 22, Last: 22}},
|
||||
},
|
||||
IPProto: []int{ProtocolTCP},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "different IPProto not merged",
|
||||
input: []tailcfg.FilterRule{
|
||||
{
|
||||
SrcIPs: []string{"100.64.0.1/32"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.64.0.2/32", Ports: tailcfg.PortRange{First: 22, Last: 22}},
|
||||
},
|
||||
IPProto: []int{ProtocolTCP},
|
||||
},
|
||||
{
|
||||
SrcIPs: []string{"100.64.0.1/32"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.64.0.2/32", Ports: tailcfg.PortRange{First: 53, Last: 53}},
|
||||
},
|
||||
IPProto: []int{ProtocolUDP},
|
||||
},
|
||||
},
|
||||
want: []tailcfg.FilterRule{
|
||||
{
|
||||
SrcIPs: []string{"100.64.0.1/32"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.64.0.2/32", Ports: tailcfg.PortRange{First: 22, Last: 22}},
|
||||
},
|
||||
IPProto: []int{ProtocolTCP},
|
||||
},
|
||||
{
|
||||
SrcIPs: []string{"100.64.0.1/32"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.64.0.2/32", Ports: tailcfg.PortRange{First: 53, Last: 53}},
|
||||
},
|
||||
IPProto: []int{ProtocolUDP},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "DstPorts combined without deduplication",
|
||||
input: []tailcfg.FilterRule{
|
||||
{
|
||||
SrcIPs: []string{"100.64.0.1/32"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.64.0.2/32", Ports: tailcfg.PortRange{First: 22, Last: 22}},
|
||||
},
|
||||
IPProto: []int{ProtocolTCP},
|
||||
},
|
||||
{
|
||||
SrcIPs: []string{"100.64.0.1/32"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.64.0.2/32", Ports: tailcfg.PortRange{First: 22, Last: 22}},
|
||||
},
|
||||
IPProto: []int{ProtocolTCP},
|
||||
},
|
||||
},
|
||||
want: []tailcfg.FilterRule{
|
||||
{
|
||||
SrcIPs: []string{"100.64.0.1/32"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.64.0.2/32", Ports: tailcfg.PortRange{First: 22, Last: 22}},
|
||||
{IP: "100.64.0.2/32", Ports: tailcfg.PortRange{First: 22, Last: 22}},
|
||||
},
|
||||
IPProto: []int{ProtocolTCP},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "merge three rules with same key",
|
||||
input: []tailcfg.FilterRule{
|
||||
{
|
||||
SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.64.0.3/32", Ports: tailcfg.PortRange{First: 22, Last: 22}},
|
||||
},
|
||||
IPProto: []int{ProtocolTCP, ProtocolUDP, ProtocolICMP, ProtocolIPv6ICMP},
|
||||
},
|
||||
{
|
||||
SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.64.0.3/32", Ports: tailcfg.PortRange{First: 80, Last: 80}},
|
||||
},
|
||||
IPProto: []int{ProtocolTCP, ProtocolUDP, ProtocolICMP, ProtocolIPv6ICMP},
|
||||
},
|
||||
{
|
||||
SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.64.0.4/32", Ports: tailcfg.PortRange{First: 443, Last: 443}},
|
||||
},
|
||||
IPProto: []int{ProtocolTCP, ProtocolUDP, ProtocolICMP, ProtocolIPv6ICMP},
|
||||
},
|
||||
},
|
||||
want: []tailcfg.FilterRule{
|
||||
{
|
||||
SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.64.0.3/32", Ports: tailcfg.PortRange{First: 22, Last: 22}},
|
||||
{IP: "100.64.0.3/32", Ports: tailcfg.PortRange{First: 80, Last: 80}},
|
||||
{IP: "100.64.0.4/32", Ports: tailcfg.PortRange{First: 443, Last: 443}},
|
||||
},
|
||||
IPProto: []int{ProtocolTCP, ProtocolUDP, ProtocolICMP, ProtocolIPv6ICMP},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := mergeFilterRules(tt.input)
|
||||
if diff := cmp.Diff(tt.want, got); diff != "" {
|
||||
t.Errorf("mergeFilterRules() mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -111,6 +111,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
|
||||
Filter: filter,
|
||||
Policy: pm.pol,
|
||||
})
|
||||
|
||||
filterChanged := filterHash != pm.filterHash
|
||||
if filterChanged {
|
||||
log.Debug().
|
||||
@@ -120,7 +121,9 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
|
||||
Int("filter.rules.new", len(filter)).
|
||||
Msg("Policy filter hash changed")
|
||||
}
|
||||
|
||||
pm.filter = filter
|
||||
|
||||
pm.filterHash = filterHash
|
||||
if filterChanged {
|
||||
pm.matchers = matcher.MatchesFromFilterRules(pm.filter)
|
||||
@@ -135,6 +138,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
|
||||
}
|
||||
|
||||
tagOwnerMapHash := deephash.Hash(&tagMap)
|
||||
|
||||
tagOwnerChanged := tagOwnerMapHash != pm.tagOwnerMapHash
|
||||
if tagOwnerChanged {
|
||||
log.Debug().
|
||||
@@ -144,6 +148,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
|
||||
Int("tagOwners.new", len(tagMap)).
|
||||
Msg("Tag owner hash changed")
|
||||
}
|
||||
|
||||
pm.tagOwnerMap = tagMap
|
||||
pm.tagOwnerMapHash = tagOwnerMapHash
|
||||
|
||||
@@ -153,6 +158,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
|
||||
}
|
||||
|
||||
autoApproveMapHash := deephash.Hash(&autoMap)
|
||||
|
||||
autoApproveChanged := autoApproveMapHash != pm.autoApproveMapHash
|
||||
if autoApproveChanged {
|
||||
log.Debug().
|
||||
@@ -162,10 +168,12 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
|
||||
Int("autoApprovers.new", len(autoMap)).
|
||||
Msg("Auto-approvers hash changed")
|
||||
}
|
||||
|
||||
pm.autoApproveMap = autoMap
|
||||
pm.autoApproveMapHash = autoApproveMapHash
|
||||
|
||||
exitSetHash := deephash.Hash(&exitSet)
|
||||
|
||||
exitSetChanged := exitSetHash != pm.exitSetHash
|
||||
if exitSetChanged {
|
||||
log.Debug().
|
||||
@@ -173,6 +181,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
|
||||
Str("exitSet.hash.new", exitSetHash.String()[:8]).
|
||||
Msg("Exit node set hash changed")
|
||||
}
|
||||
|
||||
pm.exitSet = exitSet
|
||||
pm.exitSetHash = exitSetHash
|
||||
|
||||
@@ -199,6 +208,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
|
||||
if !needsUpdate {
|
||||
log.Trace().
|
||||
Msg("Policy evaluation detected no changes - all hashes match")
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
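// Illustrative sketch (not part of the diff above): the change-detection
// pattern updateLocked applies to each derived structure (filter, tag owners,
// auto-approvers, exit set). A hash of the freshly computed value is compared
// against the previously stored hash; the value and hash are always stored,
// and the comparison result feeds into needsUpdate. trackChange is a
// hypothetical helper, not a function in this codebase; it assumes
// tailscale.com/util/deephash is imported.
func trackChange[T any](newValue T, prevHash deephash.Sum) (deephash.Sum, bool) {
	newHash := deephash.Hash(&newValue)

	return newHash, newHash != prevHash
}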
@@ -224,6 +234,7 @@ func (pm *PolicyManager) SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, err
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("compiling SSH policy: %w", err)
|
||||
}
|
||||
|
||||
pm.sshPolicyMap[node.ID()] = sshPol
|
||||
|
||||
return sshPol, nil
|
||||
@@ -403,6 +414,7 @@ func (pm *PolicyManager) filterForNodeLocked(node types.NodeView) ([]tailcfg.Fil
|
||||
reducedFilter := policyutil.ReduceFilterRules(node, pm.filter)
|
||||
|
||||
pm.filterRulesMap[node.ID()] = reducedFilter
|
||||
|
||||
return reducedFilter, nil
|
||||
}
|
||||
|
||||
@@ -447,7 +459,7 @@ func (pm *PolicyManager) FilterForNode(node types.NodeView) ([]tailcfg.FilterRul
|
||||
// This is different from FilterForNode which returns REDUCED rules for packet filtering.
|
||||
//
|
||||
// For global policies: returns the global matchers (same for all nodes)
|
||||
// For autogroup:self: returns node-specific matchers from unreduced compiled rules
|
||||
// For autogroup:self: returns node-specific matchers from unreduced compiled rules.
|
||||
func (pm *PolicyManager) MatchersForNode(node types.NodeView) ([]matcher.Match, error) {
|
||||
if pm == nil {
|
||||
return nil, nil
|
||||
@@ -479,6 +491,7 @@ func (pm *PolicyManager) SetUsers(users []types.User) (bool, error) {
|
||||
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
pm.users = users
|
||||
|
||||
// Clear SSH policy map when users change to force SSH policy recomputation
|
||||
@@ -690,6 +703,7 @@ func (pm *PolicyManager) NodeCanApproveRoute(node types.NodeView, route netip.Pr
|
||||
if pm.exitSet == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if slices.ContainsFunc(node.IPs(), pm.exitSet.Contains) {
|
||||
return true
|
||||
}
|
||||
@@ -753,8 +767,10 @@ func (pm *PolicyManager) DebugString() string {
|
||||
}
|
||||
|
||||
fmt.Fprintf(&sb, "AutoApprover (%d):\n", len(pm.autoApproveMap))
|
||||
|
||||
for prefix, approveAddrs := range pm.autoApproveMap {
|
||||
fmt.Fprintf(&sb, "\t%s:\n", prefix)
|
||||
|
||||
for _, iprange := range approveAddrs.Ranges() {
|
||||
fmt.Fprintf(&sb, "\t\t%s\n", iprange)
|
||||
}
|
||||
@@ -763,14 +779,17 @@ func (pm *PolicyManager) DebugString() string {
|
||||
sb.WriteString("\n\n")
|
||||
|
||||
fmt.Fprintf(&sb, "TagOwner (%d):\n", len(pm.tagOwnerMap))
|
||||
|
||||
for prefix, tagOwners := range pm.tagOwnerMap {
|
||||
fmt.Fprintf(&sb, "\t%s:\n", prefix)
|
||||
|
||||
for _, iprange := range tagOwners.Ranges() {
|
||||
fmt.Fprintf(&sb, "\t\t%s\n", iprange)
|
||||
}
|
||||
}
|
||||
|
||||
sb.WriteString("\n\n")
|
||||
|
||||
if pm.filter != nil {
|
||||
filter, err := json.MarshalIndent(pm.filter, "", " ")
|
||||
if err == nil {
|
||||
@@ -783,6 +802,7 @@ func (pm *PolicyManager) DebugString() string {
|
||||
sb.WriteString("\n\n")
|
||||
sb.WriteString("Matchers:\n")
|
||||
sb.WriteString("an internal structure used to filter nodes and routes\n")
|
||||
|
||||
for _, match := range pm.matchers {
|
||||
sb.WriteString(match.DebugString())
|
||||
sb.WriteString("\n")
|
||||
@@ -790,6 +810,7 @@ func (pm *PolicyManager) DebugString() string {
|
||||
|
||||
sb.WriteString("\n\n")
|
||||
sb.WriteString("Nodes:\n")
|
||||
|
||||
for _, node := range pm.nodes.All() {
|
||||
sb.WriteString(node.String())
|
||||
sb.WriteString("\n")
|
||||
@@ -867,6 +888,7 @@ func (pm *PolicyManager) invalidateAutogroupSelfCache(oldNodes, newNodes views.S
|
||||
|
||||
// Check if IPs changed (simple check - could be more sophisticated)
|
||||
oldIPs := oldNode.IPs()
|
||||
|
||||
newIPs := newNode.IPs()
|
||||
if len(oldIPs) != len(newIPs) {
|
||||
affectedUsers[newNode.User().ID()] = struct{}{}
|
||||
@@ -888,6 +910,7 @@ func (pm *PolicyManager) invalidateAutogroupSelfCache(oldNodes, newNodes views.S
|
||||
for nodeID := range pm.filterRulesMap {
|
||||
// Find the user for this cached node
|
||||
var nodeUserID uint
|
||||
|
||||
found := false
|
||||
|
||||
// Check in new nodes first
|
||||
@@ -899,8 +922,10 @@ func (pm *PolicyManager) invalidateAutogroupSelfCache(oldNodes, newNodes views.S
|
||||
found = true
|
||||
break
|
||||
}
|
||||
|
||||
nodeUserID = node.User().ID()
|
||||
found = true
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -913,8 +938,10 @@ func (pm *PolicyManager) invalidateAutogroupSelfCache(oldNodes, newNodes views.S
|
||||
found = true
|
||||
break
|
||||
}
|
||||
|
||||
nodeUserID = node.User().ID()
|
||||
found = true
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
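// Illustrative sketch (not part of the diff above): the shape of the
// autogroup:self cache invalidation. Affected user IDs are collected from
// added, removed, or modified nodes, then every cached per-node filter whose
// owner is in that set is dropped. clearCacheForUsers is a hypothetical
// helper; the real code walks pm.filterRulesMap inline and looks the owner up
// in the old and new node slices.
func clearCacheForUsers(cache map[types.NodeID][]tailcfg.FilterRule, owners map[types.NodeID]uint, affected map[uint]struct{}) {
	for nodeID, userID := range owners {
		if _, ok := affected[userID]; ok {
			delete(cache, nodeID)
		}
	}
}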
@@ -11,18 +11,16 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/ptr"
|
||||
)
|
||||
|
||||
func node(name, ipv4, ipv6 string, user types.User, hostinfo *tailcfg.Hostinfo) *types.Node {
|
||||
func node(name, ipv4, ipv6 string, user types.User) *types.Node {
|
||||
return &types.Node{
|
||||
ID: 0,
|
||||
Hostname: name,
|
||||
IPv4: ap(ipv4),
|
||||
IPv6: ap(ipv6),
|
||||
User: ptr.To(user),
|
||||
UserID: ptr.To(user.ID),
|
||||
Hostinfo: hostinfo,
|
||||
User: new(user),
|
||||
UserID: new(user.ID),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -57,6 +55,7 @@ func TestPolicyManager(t *testing.T) {
|
||||
if diff := cmp.Diff(tt.wantFilter, filter); diff != "" {
|
||||
t.Errorf("Filter() filter mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(
|
||||
tt.wantMatchers,
|
||||
matchers,
|
||||
@@ -77,6 +76,7 @@ func TestInvalidateAutogroupSelfCache(t *testing.T) {
|
||||
{Model: gorm.Model{ID: 3}, Name: "user3", Email: "user3@headscale.net"},
|
||||
}
|
||||
|
||||
//nolint:goconst // test-specific inline policy for clarity
|
||||
policy := `{
|
||||
"acls": [
|
||||
{
|
||||
@@ -88,14 +88,14 @@ func TestInvalidateAutogroupSelfCache(t *testing.T) {
|
||||
}`
|
||||
|
||||
initialNodes := types.Nodes{
|
||||
node("user1-node1", "100.64.0.1", "fd7a:115c:a1e0::1", users[0], nil),
|
||||
node("user1-node2", "100.64.0.2", "fd7a:115c:a1e0::2", users[0], nil),
|
||||
node("user2-node1", "100.64.0.3", "fd7a:115c:a1e0::3", users[1], nil),
|
||||
node("user3-node1", "100.64.0.4", "fd7a:115c:a1e0::4", users[2], nil),
|
||||
node("user1-node1", "100.64.0.1", "fd7a:115c:a1e0::1", users[0]),
|
||||
node("user1-node2", "100.64.0.2", "fd7a:115c:a1e0::2", users[0]),
|
||||
node("user2-node1", "100.64.0.3", "fd7a:115c:a1e0::3", users[1]),
|
||||
node("user3-node1", "100.64.0.4", "fd7a:115c:a1e0::4", users[2]),
|
||||
}
|
||||
|
||||
for i, n := range initialNodes {
|
||||
n.ID = types.NodeID(i + 1)
|
||||
n.ID = types.NodeID(i + 1) //nolint:gosec // safe conversion in test
|
||||
}
|
||||
|
||||
pm, err := NewPolicyManager([]byte(policy), users, initialNodes.ViewSlice())
|
||||
@@ -107,7 +107,7 @@ func TestInvalidateAutogroupSelfCache(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
require.Equal(t, len(initialNodes), len(pm.filterRulesMap))
|
||||
require.Len(t, pm.filterRulesMap, len(initialNodes))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -118,10 +118,10 @@ func TestInvalidateAutogroupSelfCache(t *testing.T) {
|
||||
{
|
||||
name: "no_changes",
|
||||
newNodes: types.Nodes{
|
||||
node("user1-node1", "100.64.0.1", "fd7a:115c:a1e0::1", users[0], nil),
|
||||
node("user1-node2", "100.64.0.2", "fd7a:115c:a1e0::2", users[0], nil),
|
||||
node("user2-node1", "100.64.0.3", "fd7a:115c:a1e0::3", users[1], nil),
|
||||
node("user3-node1", "100.64.0.4", "fd7a:115c:a1e0::4", users[2], nil),
|
||||
node("user1-node1", "100.64.0.1", "fd7a:115c:a1e0::1", users[0]),
|
||||
node("user1-node2", "100.64.0.2", "fd7a:115c:a1e0::2", users[0]),
|
||||
node("user2-node1", "100.64.0.3", "fd7a:115c:a1e0::3", users[1]),
|
||||
node("user3-node1", "100.64.0.4", "fd7a:115c:a1e0::4", users[2]),
|
||||
},
|
||||
expectedCleared: 0,
|
||||
description: "No changes should clear no cache entries",
|
||||
@@ -129,11 +129,11 @@ func TestInvalidateAutogroupSelfCache(t *testing.T) {
|
||||
{
|
||||
name: "node_added",
|
||||
newNodes: types.Nodes{
|
||||
node("user1-node1", "100.64.0.1", "fd7a:115c:a1e0::1", users[0], nil),
|
||||
node("user1-node2", "100.64.0.2", "fd7a:115c:a1e0::2", users[0], nil),
|
||||
node("user1-node3", "100.64.0.5", "fd7a:115c:a1e0::5", users[0], nil), // New node
|
||||
node("user2-node1", "100.64.0.3", "fd7a:115c:a1e0::3", users[1], nil),
|
||||
node("user3-node1", "100.64.0.4", "fd7a:115c:a1e0::4", users[2], nil),
|
||||
node("user1-node1", "100.64.0.1", "fd7a:115c:a1e0::1", users[0]),
|
||||
node("user1-node2", "100.64.0.2", "fd7a:115c:a1e0::2", users[0]),
|
||||
node("user1-node3", "100.64.0.5", "fd7a:115c:a1e0::5", users[0]), // New node
|
||||
node("user2-node1", "100.64.0.3", "fd7a:115c:a1e0::3", users[1]),
|
||||
node("user3-node1", "100.64.0.4", "fd7a:115c:a1e0::4", users[2]),
|
||||
},
|
||||
expectedCleared: 2, // user1's existing nodes should be cleared
|
||||
description: "Adding a node should clear cache for that user's existing nodes",
|
||||
@@ -141,10 +141,10 @@ func TestInvalidateAutogroupSelfCache(t *testing.T) {
|
||||
{
|
||||
name: "node_removed",
|
||||
newNodes: types.Nodes{
|
||||
node("user1-node1", "100.64.0.1", "fd7a:115c:a1e0::1", users[0], nil),
|
||||
node("user1-node1", "100.64.0.1", "fd7a:115c:a1e0::1", users[0]),
|
||||
// user1-node2 removed
|
||||
node("user2-node1", "100.64.0.3", "fd7a:115c:a1e0::3", users[1], nil),
|
||||
node("user3-node1", "100.64.0.4", "fd7a:115c:a1e0::4", users[2], nil),
|
||||
node("user2-node1", "100.64.0.3", "fd7a:115c:a1e0::3", users[1]),
|
||||
node("user3-node1", "100.64.0.4", "fd7a:115c:a1e0::4", users[2]),
|
||||
},
|
||||
expectedCleared: 2, // user1's remaining node + removed node should be cleared
|
||||
description: "Removing a node should clear cache for that user's remaining nodes",
|
||||
@@ -152,10 +152,10 @@ func TestInvalidateAutogroupSelfCache(t *testing.T) {
|
||||
{
|
||||
name: "user_changed",
|
||||
newNodes: types.Nodes{
|
||||
node("user1-node1", "100.64.0.1", "fd7a:115c:a1e0::1", users[0], nil),
|
||||
node("user1-node2", "100.64.0.2", "fd7a:115c:a1e0::2", users[2], nil), // Changed to user3
|
||||
node("user2-node1", "100.64.0.3", "fd7a:115c:a1e0::3", users[1], nil),
|
||||
node("user3-node1", "100.64.0.4", "fd7a:115c:a1e0::4", users[2], nil),
|
||||
node("user1-node1", "100.64.0.1", "fd7a:115c:a1e0::1", users[0]),
|
||||
node("user1-node2", "100.64.0.2", "fd7a:115c:a1e0::2", users[2]), // Changed to user3
|
||||
node("user2-node1", "100.64.0.3", "fd7a:115c:a1e0::3", users[1]),
|
||||
node("user3-node1", "100.64.0.4", "fd7a:115c:a1e0::4", users[2]),
|
||||
},
|
||||
expectedCleared: 3, // user1's node + user2's node + user3's nodes should be cleared
|
||||
description: "Changing a node's user should clear cache for both old and new users",
|
||||
@@ -163,10 +163,10 @@ func TestInvalidateAutogroupSelfCache(t *testing.T) {
|
||||
{
|
||||
name: "ip_changed",
|
||||
newNodes: types.Nodes{
|
||||
node("user1-node1", "100.64.0.10", "fd7a:115c:a1e0::10", users[0], nil), // IP changed
|
||||
node("user1-node2", "100.64.0.2", "fd7a:115c:a1e0::2", users[0], nil),
|
||||
node("user2-node1", "100.64.0.3", "fd7a:115c:a1e0::3", users[1], nil),
|
||||
node("user3-node1", "100.64.0.4", "fd7a:115c:a1e0::4", users[2], nil),
|
||||
node("user1-node1", "100.64.0.10", "fd7a:115c:a1e0::10", users[0]), // IP changed
|
||||
node("user1-node2", "100.64.0.2", "fd7a:115c:a1e0::2", users[0]),
|
||||
node("user2-node1", "100.64.0.3", "fd7a:115c:a1e0::3", users[1]),
|
||||
node("user3-node1", "100.64.0.4", "fd7a:115c:a1e0::4", users[2]),
|
||||
},
|
||||
expectedCleared: 2, // user1's nodes should be cleared
|
||||
description: "Changing a node's IP should clear cache for that user's nodes",
|
||||
@@ -177,15 +177,18 @@ func TestInvalidateAutogroupSelfCache(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
for i, n := range tt.newNodes {
|
||||
found := false
|
||||
|
||||
for _, origNode := range initialNodes {
|
||||
if n.Hostname == origNode.Hostname {
|
||||
n.ID = origNode.ID
|
||||
found = true
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
n.ID = types.NodeID(len(initialNodes) + i + 1)
|
||||
n.ID = types.NodeID(len(initialNodes) + i + 1) //nolint:gosec // safe conversion in test
|
||||
}
|
||||
}
|
||||
|
||||
@@ -370,16 +373,16 @@ func TestInvalidateGlobalPolicyCache(t *testing.T) {
|
||||
|
||||
// TestAutogroupSelfReducedVsUnreducedRules verifies that:
|
||||
// 1. BuildPeerMap uses unreduced compiled rules for determining peer relationships
|
||||
// 2. FilterForNode returns reduced compiled rules for packet filters
|
||||
// 2. FilterForNode returns reduced compiled rules for packet filters.
|
||||
func TestAutogroupSelfReducedVsUnreducedRules(t *testing.T) {
|
||||
user1 := types.User{Model: gorm.Model{ID: 1}, Name: "user1", Email: "user1@headscale.net"}
user2 := types.User{Model: gorm.Model{ID: 2}, Name: "user2", Email: "user2@headscale.net"}
users := types.Users{user1, user2}

// Create two nodes
node1 := node("node1", "100.64.0.1", "fd7a:115c:a1e0::1", user1, nil)
node1 := node("node1", "100.64.0.1", "fd7a:115c:a1e0::1", user1)
node1.ID = 1
node2 := node("node2", "100.64.0.2", "fd7a:115c:a1e0::2", user2, nil)
node2 := node("node2", "100.64.0.2", "fd7a:115c:a1e0::2", user2)
node2.ID = 2
nodes := types.Nodes{node1, node2}

@@ -410,6 +413,7 @@ func TestAutogroupSelfReducedVsUnreducedRules(t *testing.T) {
// FilterForNode should return reduced rules - verify they only contain the node's own IPs as destinations
// For node1, destinations should only be node1's IPs
node1IPs := []string{"100.64.0.1/32", "100.64.0.1", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::1"}

for _, rule := range filterNode1 {
for _, dst := range rule.DstPorts {
require.Contains(t, node1IPs, dst.IP,
@@ -419,6 +423,7 @@ func TestAutogroupSelfReducedVsUnreducedRules(t *testing.T) {

// For node2, destinations should only be node2's IPs
node2IPs := []string{"100.64.0.2/32", "100.64.0.2", "fd7a:115c:a1e0::2/128", "fd7a:115c:a1e0::2"}

for _, rule := range filterNode2 {
for _, dst := range rule.DstPorts {
require.Contains(t, node2IPs, dst.IP,
@@ -457,8 +462,8 @@ func TestAutogroupSelfWithOtherRules(t *testing.T) {
Hostname: "test-1-device",
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: ptr.To(users[0]),
UserID: ptr.To(users[0].ID),
User: new(users[0]),
UserID: new(users[0].ID),
Hostinfo: &tailcfg.Hostinfo{},
}

@@ -468,8 +473,8 @@ func TestAutogroupSelfWithOtherRules(t *testing.T) {
Hostname: "test-2-router",
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: ptr.To(users[1]),
UserID: ptr.To(users[1].ID),
User: new(users[1]),
UserID: new(users[1].ID),
Tags: []string{"tag:node-router"},
Hostinfo: &tailcfg.Hostinfo{},
}
@@ -537,8 +542,8 @@ func TestAutogroupSelfPolicyUpdateTriggersMapResponse(t *testing.T) {
Hostname: "test-1-device",
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: ptr.To(users[0]),
UserID: ptr.To(users[0].ID),
User: new(users[0]),
UserID: new(users[0].ID),
Hostinfo: &tailcfg.Hostinfo{},
}

@@ -547,8 +552,8 @@ func TestAutogroupSelfPolicyUpdateTriggersMapResponse(t *testing.T) {
Hostname: "test-2-device",
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: ptr.To(users[1]),
UserID: ptr.To(users[1].ID),
User: new(users[1]),
UserID: new(users[1].ID),
Hostinfo: &tailcfg.Hostinfo{},
}

@@ -647,8 +652,8 @@ func TestTagPropagationToPeerMap(t *testing.T) {
Hostname: "user1-node",
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: ptr.To(users[0]),
UserID: ptr.To(users[0].ID),
User: new(users[0]),
UserID: new(users[0].ID),
Tags: []string{"tag:web", "tag:internal"},
}

@@ -658,8 +663,8 @@ func TestTagPropagationToPeerMap(t *testing.T) {
Hostname: "user2-node",
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: ptr.To(users[1]),
UserID: ptr.To(users[1].ID),
User: new(users[1]),
UserID: new(users[1].ID),
}

initialNodes := types.Nodes{user1Node, user2Node}
@@ -686,8 +691,8 @@ func TestTagPropagationToPeerMap(t *testing.T) {
Hostname: "user1-node",
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: ptr.To(users[0]),
UserID: ptr.To(users[0].ID),
User: new(users[0]),
UserID: new(users[0].ID),
Tags: []string{"tag:internal"}, // tag:web removed!
}

@@ -749,8 +754,8 @@ func TestAutogroupSelfWithAdminOverride(t *testing.T) {
Hostname: "admin-device",
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: ptr.To(users[0]),
UserID: ptr.To(users[0].ID),
User: new(users[0]),
UserID: new(users[0].ID),
Hostinfo: &tailcfg.Hostinfo{},
}

@@ -760,8 +765,8 @@ func TestAutogroupSelfWithAdminOverride(t *testing.T) {
Hostname: "user1-server",
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: ptr.To(users[1]),
UserID: ptr.To(users[1].ID),
User: new(users[1]),
UserID: new(users[1].ID),
Tags: []string{"tag:server"},
Hostinfo: &tailcfg.Hostinfo{},
}
@@ -832,8 +837,8 @@ func TestAutogroupSelfSymmetricVisibility(t *testing.T) {
Hostname: "device-a",
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: ptr.To(users[0]),
UserID: ptr.To(users[0].ID),
User: new(users[0]),
UserID: new(users[0].ID),
Hostinfo: &tailcfg.Hostinfo{},
}

@@ -843,8 +848,8 @@ func TestAutogroupSelfSymmetricVisibility(t *testing.T) {
Hostname: "device-b",
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: ptr.To(users[1]),
UserID: ptr.To(users[1].ID),
User: new(users[1]),
UserID: new(users[1].ID),
Tags: []string{"tag:web"},
Hostinfo: &tailcfg.Hostinfo{},
}
@@ -922,8 +927,8 @@ func TestAutogroupSelfDoesNotBreakOtherUsersAccess(t *testing.T) {
superadminDevice := &types.Node{
ID: 1,
Hostname: "superadmin-laptop",
User: ptr.To(users[0]),
UserID: ptr.To(users[0].ID),
User: new(users[0]),
UserID: new(users[0].ID),
IPv4: ap("100.64.0.1"),
Hostinfo: &tailcfg.Hostinfo{},
}
@@ -931,8 +936,8 @@ func TestAutogroupSelfDoesNotBreakOtherUsersAccess(t *testing.T) {
adminDevice := &types.Node{
ID: 2,
Hostname: "admin-laptop",
User: ptr.To(users[1]),
UserID: ptr.To(users[1].ID),
User: new(users[1]),
UserID: new(users[1].ID),
IPv4: ap("100.64.0.2"),
Hostinfo: &tailcfg.Hostinfo{},
}
@@ -940,8 +945,8 @@ func TestAutogroupSelfDoesNotBreakOtherUsersAccess(t *testing.T) {
directionDevice := &types.Node{
ID: 3,
Hostname: "direction-laptop",
User: ptr.To(users[2]),
UserID: ptr.To(users[2].ID),
User: new(users[2]),
UserID: new(users[2].ID),
IPv4: ap("100.64.0.3"),
Hostinfo: &tailcfg.Hostinfo{},
}
@@ -949,8 +954,8 @@ func TestAutogroupSelfDoesNotBreakOtherUsersAccess(t *testing.T) {
commonServer := &types.Node{
ID: 4,
Hostname: "common-server",
User: ptr.To(users[3]),
UserID: ptr.To(users[3].ID),
User: new(users[3]),
UserID: new(users[3].ID),
IPv4: ap("100.64.0.4"),
Tags: []string{"tag:common"},
Hostinfo: &tailcfg.Hostinfo{},
@@ -959,8 +964,8 @@ func TestAutogroupSelfDoesNotBreakOtherUsersAccess(t *testing.T) {
techServer := &types.Node{
ID: 5,
Hostname: "tech-server",
User: ptr.To(users[3]),
UserID: ptr.To(users[3].ID),
User: new(users[3]),
UserID: new(users[3].ID),
IPv4: ap("100.64.0.5"),
Tags: []string{"tag:tech"},
Hostinfo: &tailcfg.Hostinfo{},
@@ -969,8 +974,8 @@ func TestAutogroupSelfDoesNotBreakOtherUsersAccess(t *testing.T) {
privilegedServer := &types.Node{
ID: 6,
Hostname: "privileged-server",
User: ptr.To(users[3]),
UserID: ptr.To(users[3].ID),
User: new(users[3]),
UserID: new(users[3].ID),
IPv4: ap("100.64.0.6"),
Tags: []string{"tag:privileged"},
Hostinfo: &tailcfg.Hostinfo{},
@@ -1082,8 +1087,8 @@ func TestEmptyFilterNodesStillVisible(t *testing.T) {
adminDevice := &types.Node{
ID: 1,
Hostname: "admin-laptop",
User: ptr.To(users[0]),
UserID: ptr.To(users[0].ID),
User: new(users[0]),
UserID: new(users[0].ID),
IPv4: ap("100.64.0.1"),
Hostinfo: &tailcfg.Hostinfo{},
}
@@ -1093,8 +1098,8 @@ func TestEmptyFilterNodesStillVisible(t *testing.T) {
taggedServer := &types.Node{
ID: 2,
Hostname: "server",
User: ptr.To(users[1]),
UserID: ptr.To(users[1].ID),
User: new(users[1]),
UserID: new(users[1].ID),
IPv4: ap("100.64.0.2"),
Tags: []string{"tag:server"},
Hostinfo: &tailcfg.Hostinfo{},
@@ -1151,8 +1156,8 @@ func TestAutogroupSelfCombinedWithTags(t *testing.T) {
adminLaptop := &types.Node{
ID: 1,
Hostname: "admin-laptop",
User: ptr.To(users[0]),
UserID: ptr.To(users[0].ID),
User: new(users[0]),
UserID: new(users[0].ID),
IPv4: ap("100.64.0.1"),
Hostinfo: &tailcfg.Hostinfo{},
}
@@ -1160,8 +1165,8 @@ func TestAutogroupSelfCombinedWithTags(t *testing.T) {
adminPhone := &types.Node{
ID: 2,
Hostname: "admin-phone",
User: ptr.To(users[0]),
UserID: ptr.To(users[0].ID),
User: new(users[0]),
UserID: new(users[0].ID),
IPv4: ap("100.64.0.2"),
Hostinfo: &tailcfg.Hostinfo{},
}
@@ -1170,8 +1175,8 @@ func TestAutogroupSelfCombinedWithTags(t *testing.T) {
webServer := &types.Node{
ID: 3,
Hostname: "web-server",
User: ptr.To(users[1]),
UserID: ptr.To(users[1].ID),
User: new(users[1]),
UserID: new(users[1].ID),
IPv4: ap("100.64.0.3"),
Tags: []string{"tag:web"},
Hostinfo: &tailcfg.Hostinfo{},
@@ -1246,8 +1251,8 @@ func TestIssue2990SameUserTaggedDevice(t *testing.T) {
node1 := &types.Node{
ID: 1,
Hostname: "node1",
User: ptr.To(users[0]),
UserID: ptr.To(users[0].ID),
User: new(users[0]),
UserID: new(users[0].ID),
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
Hostinfo: &tailcfg.Hostinfo{},
@@ -1257,8 +1262,8 @@ func TestIssue2990SameUserTaggedDevice(t *testing.T) {
node2 := &types.Node{
ID: 2,
Hostname: "node2",
User: ptr.To(users[0]),
UserID: ptr.To(users[0].ID),
User: new(users[0]),
UserID: new(users[0].ID),
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
Tags: []string{"tag:admin"},
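The User/UserID changes repeated throughout the hunks above swap the `ptr.To` helper for Go's `new(expr)` form when taking a pointer to a user value. For reference, a minimal sketch of such a pointer helper, assuming the usual generic definition (as in tailscale.com/types/ptr); the `User` type below is a stand-in, not headscale's:

```go
package main

import "fmt"

// To returns a pointer to a freshly allocated copy of v.
// Sketch only: assumes headscale's ptr.To has this conventional shape.
func To[T any](v T) *T {
	return &v
}

func main() {
	type User struct{ ID uint }

	u := User{ID: 1}
	p := To(u) // pointer to a copy of u, as the old ptr.To(users[0]) calls did

	fmt.Println(p.ID) // 1
}
```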
9936 hscontrol/policy/v2/tailscale_compat_test.go (Normal file)
File diff suppressed because it is too large
8286 hscontrol/policy/v2/tailscale_routes_compat_test.go (Normal file)
File diff suppressed because it is too large
@@ -9,6 +9,18 @@ import (
"tailscale.com/tailcfg"
)

// Port parsing errors.
var (
ErrInputMissingColon = errors.New("input must contain a colon character separating destination and port")
ErrInputStartsWithColon = errors.New("input cannot start with a colon character")
ErrInputEndsWithColon = errors.New("input cannot end with a colon character")
ErrInvalidPortRangeFormat = errors.New("invalid port range format")
ErrPortRangeInverted = errors.New("invalid port range: first port is greater than last port")
ErrPortMustBePositive = errors.New("first port must be >0, or use '*' for wildcard")
ErrInvalidPortNumber = errors.New("invalid port number")
ErrPortNumberOutOfRange = errors.New("port number out of range")
)

// splitDestinationAndPort takes an input string and returns the destination and port as a tuple, or an error if the input is invalid.
func splitDestinationAndPort(input string) (string, string, error) {
// Find the last occurrence of the colon character
@@ -16,13 +28,15 @@ func splitDestinationAndPort(input string) (string, string, error) {

// Check if the colon character is present and not at the beginning or end of the string
if lastColonIndex == -1 {
return "", "", errors.New("input must contain a colon character separating destination and port")
return "", "", ErrInputMissingColon
}

if lastColonIndex == 0 {
return "", "", errors.New("input cannot start with a colon character")
return "", "", ErrInputStartsWithColon
}

if lastColonIndex == len(input)-1 {
return "", "", errors.New("input cannot end with a colon character")
return "", "", ErrInputEndsWithColon
}

// Split the string into destination and port based on the last colon
@@ -45,11 +59,12 @@ func parsePortRange(portDef string) ([]tailcfg.PortRange, error) {
for part := range parts {
if strings.Contains(part, "-") {
rangeParts := strings.Split(part, "-")

rangeParts = slices.DeleteFunc(rangeParts, func(e string) bool {
return e == ""
})
if len(rangeParts) != 2 {
return nil, errors.New("invalid port range format")
return nil, ErrInvalidPortRangeFormat
}

first, err := parsePort(rangeParts[0])
@@ -63,7 +78,7 @@ func parsePortRange(portDef string) ([]tailcfg.PortRange, error) {
}

if first > last {
return nil, errors.New("invalid port range: first port is greater than last port")
return nil, ErrPortRangeInverted
}

portRanges = append(portRanges, tailcfg.PortRange{First: first, Last: last})
@@ -74,7 +89,7 @@ func parsePortRange(portDef string) ([]tailcfg.PortRange, error) {
}

if port < 1 {
return nil, errors.New("first port must be >0, or use '*' for wildcard")
return nil, ErrPortMustBePositive
}

portRanges = append(portRanges, tailcfg.PortRange{First: port, Last: port})
@@ -88,11 +103,11 @@ func parsePortRange(portDef string) ([]tailcfg.PortRange, error) {
func parsePort(portStr string) (uint16, error) {
port, err := strconv.Atoi(portStr)
if err != nil {
return 0, errors.New("invalid port number")
return 0, ErrInvalidPortNumber
}

if port < 0 || port > 65535 {
return 0, errors.New("port number out of range")
return 0, ErrPortNumberOutOfRange
}

return uint16(port), nil
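The hunks above replace inline `errors.New(...)` values with package-level sentinel errors, so callers and tests can match on error identity rather than on message strings. A hedged sketch of the resulting calling pattern follows; the real parser is unexported inside hscontrol/policy/v2, so the helper below is a trimmed stand-in, not headscale's implementation:

```go
package main

import (
	"errors"
	"fmt"
	"strings"
)

// Stand-in for the sentinel error introduced in the diff.
var ErrInputMissingColon = errors.New("input must contain a colon character separating destination and port")

// splitDestinationAndPort is a trimmed stand-in that only handles the
// "missing colon" case; the real function also rejects leading and trailing colons.
func splitDestinationAndPort(input string) (string, string, error) {
	idx := strings.LastIndex(input, ":")
	if idx == -1 {
		return "", "", ErrInputMissingColon
	}

	return input[:idx], input[idx+1:], nil
}

func main() {
	_, _, err := splitDestinationAndPort("invalidinput")

	// Sentinel errors allow identity checks instead of string comparison.
	if errors.Is(err, ErrInputMissingColon) {
		fmt.Println("destination is missing a port")
	}
}
```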
@@ -1,7 +1,6 @@
package v2

import (
"errors"
"testing"

"github.com/google/go-cmp/cmp"
@@ -24,9 +23,9 @@ func TestParseDestinationAndPort(t *testing.T) {
{"tag:api-server:443", "tag:api-server", "443", nil},
{"example-host-1:*", "example-host-1", "*", nil},
{"hostname:80-90", "hostname", "80-90", nil},
{"invalidinput", "", "", errors.New("input must contain a colon character separating destination and port")},
{":invalid", "", "", errors.New("input cannot start with a colon character")},
{"invalid:", "", "", errors.New("input cannot end with a colon character")},
{"invalidinput", "", "", ErrInputMissingColon},
{":invalid", "", "", ErrInputStartsWithColon},
{"invalid:", "", "", ErrInputEndsWithColon},
}

for _, testCase := range testCases {
@@ -58,9 +57,11 @@ func TestParsePort(t *testing.T) {
if err != nil && err.Error() != test.err {
t.Errorf("parsePort(%q) error = %v, expected error = %v", test.input, err, test.err)
}

if err == nil && test.err != "" {
t.Errorf("parsePort(%q) expected error = %v, got nil", test.input, test.err)
}

if result != test.expected {
t.Errorf("parsePort(%q) = %v, expected %v", test.input, result, test.expected)
}
@@ -92,9 +93,11 @@ func TestParsePortRange(t *testing.T) {
if err != nil && err.Error() != test.err {
t.Errorf("parsePortRange(%q) error = %v, expected error = %v", test.input, err, test.err)
}

if err == nil && test.err != "" {
t.Errorf("parsePortRange(%q) expected error = %v, got nil", test.input, test.err)
}

if diff := cmp.Diff(result, test.expected); diff != "" {
t.Errorf("parsePortRange(%q) mismatch (-want +got):\n%s", test.input, diff)
}
@@ -11,6 +11,7 @@ import (

"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/juanfont/headscale/hscontrol/util/zlog/zf"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/sasha-s/go-deadlock"
@@ -29,7 +30,7 @@ const nodeNameContextKey = contextKey("nodeName")
type mapSession struct {
h *Headscale
req tailcfg.MapRequest
ctx context.Context
ctx context.Context //nolint:containedctx
capVer tailcfg.CapabilityVersion

cancelChMu deadlock.Mutex
@@ -43,6 +44,8 @@ type mapSession struct {

node *types.Node
w http.ResponseWriter

log zerolog.Logger
}

func (h *Headscale) newMapSession(
@@ -51,7 +54,7 @@ func (h *Headscale) newMapSession(
w http.ResponseWriter,
node *types.Node,
) *mapSession {
ka := keepAliveInterval + (time.Duration(rand.IntN(9000)) * time.Millisecond)
ka := keepAliveInterval + (time.Duration(rand.IntN(9000)) * time.Millisecond) //nolint:gosec // weak random is fine for jitter

return &mapSession{
h: h,
@@ -67,6 +70,13 @@ func (h *Headscale) newMapSession(

keepAlive: ka,
keepAliveTicker: nil,

log: log.With().
Str(zf.Component, "poll").
EmbedObject(node).
Bool(zf.OmitPeers, req.OmitPeers).
Bool(zf.Stream, req.Stream).
Logger(),
}
}
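The mapSession changes above replace free-floating log calls (and the infof/tracef/errf wrappers removed further down) with a zerolog logger built once per session, so every event automatically carries the session's identifying fields. A minimal sketch of that pattern; the field names here are illustrative, not headscale's zf constants:

```go
package main

import (
	"os"

	"github.com/rs/zerolog"
)

type session struct {
	log zerolog.Logger
}

func newSession(base zerolog.Logger, nodeID uint64, hostname string) *session {
	return &session{
		// Attach the per-session context once; every later event
		// inherits component, node.id and node.name automatically.
		log: base.With().
			Str("component", "poll").
			Uint64("node.id", nodeID).
			Str("node.name", hostname).
			Logger(),
	}
}

func main() {
	base := zerolog.New(os.Stderr).With().Timestamp().Logger()

	s := newSession(base, 1, "node1")
	s.log.Info().Msg("node has connected")
	s.log.Trace().Bool("ok", true).Msg("received update from channel")
}
```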
@@ -132,7 +142,7 @@ func (m *mapSession) serve() {
func (m *mapSession) serveLongPoll() {
m.beforeServeLongPoll()

log.Trace().Caller().Uint64("node.id", m.node.ID.Uint64()).Str("node.name", m.node.Hostname).Msg("Long poll session started because client connected")
m.log.Trace().Caller().Msg("long poll session started")

// Clean up the session when the client disconnects
defer func() {
@@ -152,6 +162,7 @@ func (m *mapSession) serveLongPoll() {
// This is not my favourite solution, but it kind of works in our eventually consistent world.
ticker := time.NewTicker(time.Second)
defer ticker.Stop()

disconnected := true
// Wait up to 10 seconds for the node to reconnect.
// 10 seconds was arbitrary chosen as a reasonable time to reconnect.
@@ -160,18 +171,19 @@ func (m *mapSession) serveLongPoll() {
disconnected = false
break
}

<-ticker.C
}

if disconnected {
disconnectChanges, err := m.h.state.Disconnect(m.node.ID)
if err != nil {
m.errf(err, "Failed to disconnect node %s", m.node.Hostname)
m.log.Error().Caller().Err(err).Msg("failed to disconnect node")
}

m.h.Change(disconnectChanges...)
m.afterServeLongPoll()
m.infof("node has disconnected, mapSession: %p, chan: %p", m, m.ch)
m.log.Info().Caller().Str(zf.Chan, fmt.Sprintf("%p", m.ch)).Msg("node has disconnected")
}
}()

@@ -193,7 +205,7 @@ func (m *mapSession) serveLongPoll() {
// the node to be incorrectly removed from AvailableRoutes.
mapReqChange, err := m.h.state.UpdateNodeFromMapRequest(m.node.ID, m.req)
if err != nil {
m.errf(err, "failed to update node from initial MapRequest")
m.log.Error().Caller().Err(err).Msg("failed to update node from initial MapRequest")
return
}

@@ -205,19 +217,19 @@ func (m *mapSession) serveLongPoll() {
// primary route selection occurs, which is critical for proper HA subnet router failover.
connectChanges := m.h.state.Connect(m.node.ID)

m.infof("node has connected, mapSession: %p, chan: %p", m, m.ch)
m.log.Info().Caller().Str(zf.Chan, fmt.Sprintf("%p", m.ch)).Msg("node has connected")

// TODO(kradalby): Redo the comments here
// Add node to batcher so it can receive updates,
// adding this before connecting it to the state ensure that
// it does not miss any updates that might be sent in the split
// time between the node connecting and the batcher being ready.
if err := m.h.mapBatcher.AddNode(m.node.ID, m.ch, m.capVer); err != nil {
m.errf(err, "failed to add node to batcher")
log.Error().Uint64("node.id", m.node.ID.Uint64()).Str("node.name", m.node.Hostname).Err(err).Msg("AddNode failed in poll session")
if err := m.h.mapBatcher.AddNode(m.node.ID, m.ch, m.capVer); err != nil { //nolint:noinlineerr
m.log.Error().Caller().Err(err).Msg("failed to add node to batcher")
return
}
log.Debug().Caller().Uint64("node.id", m.node.ID.Uint64()).Str("node.name", m.node.Hostname).Msg("AddNode succeeded in poll session because node added to batcher")

m.log.Debug().Caller().Msg("node added to batcher")

m.h.Change(mapReqChange)
m.h.Change(connectChanges...)
@@ -228,40 +240,46 @@ func (m *mapSession) serveLongPoll() {
// consume channels with update, keep alives or "batch" blocking signals
select {
case <-m.cancelCh:
m.tracef("poll cancelled received")
m.log.Trace().Caller().Msg("poll cancelled received")
mapResponseEnded.WithLabelValues("cancelled").Inc()

return

case <-ctx.Done():
m.tracef("poll context done chan:%p", m.ch)
m.log.Trace().Caller().Str(zf.Chan, fmt.Sprintf("%p", m.ch)).Msg("poll context done")
mapResponseEnded.WithLabelValues("done").Inc()

return

// Consume updates sent to node
case update, ok := <-m.ch:
m.tracef("received update from channel, ok: %t", ok)
m.log.Trace().Caller().Bool(zf.OK, ok).Msg("received update from channel")

if !ok {
m.tracef("update channel closed, streaming session is likely being replaced")
m.log.Trace().Caller().Msg("update channel closed, streaming session is likely being replaced")
return
}

if err := m.writeMap(update); err != nil {
m.errf(err, "cannot write update to client")
err := m.writeMap(update)
if err != nil {
m.log.Error().Caller().Err(err).Msg("cannot write update to client")
return
}

m.tracef("update sent")
m.log.Trace().Caller().Msg("update sent")
m.resetKeepAlive()

case <-m.keepAliveTicker.C:
if err := m.writeMap(&keepAlive); err != nil {
m.errf(err, "cannot write keep alive")
err := m.writeMap(&keepAlive)
if err != nil {
m.log.Error().Caller().Err(err).Msg("cannot write keep alive")
return
}

if debugHighCardinalityMetrics {
mapResponseLastSentSeconds.WithLabelValues("keepalive", m.node.ID.String()).Set(float64(time.Now().Unix()))
}

mapResponseSent.WithLabelValues("ok", "keepalive").Inc()
m.resetKeepAlive()
}
@@ -282,7 +300,7 @@ func (m *mapSession) writeMap(msg *tailcfg.MapResponse) error {
jsonBody = zstdframe.AppendEncode(nil, jsonBody, zstdframe.FastestCompression)
}

data := make([]byte, reservedResponseHeaderSize)
data := make([]byte, reservedResponseHeaderSize, reservedResponseHeaderSize+len(jsonBody))
//nolint:gosec // G115: JSON response size will not exceed uint32 max
binary.LittleEndian.PutUint32(data, uint32(len(jsonBody)))
data = append(data, jsonBody...)
@@ -298,19 +316,17 @@ func (m *mapSession) writeMap(msg *tailcfg.MapResponse) error {
if f, ok := m.w.(http.Flusher); ok {
f.Flush()
} else {
m.errf(nil, "ResponseWriter does not implement http.Flusher, cannot flush")
m.log.Error().Caller().Msg("responseWriter does not implement http.Flusher, cannot flush")
}
}

log.Trace().
m.log.Trace().
Caller().
Str("node.name", m.node.Hostname).
Uint64("node.id", m.node.ID.Uint64()).
Str("chan", fmt.Sprintf("%p", m.ch)).
Str(zf.Chan, fmt.Sprintf("%p", m.ch)).
TimeDiff("timeSpent", time.Now(), startWrite).
Str("machine.key", m.node.MachineKey.String()).
Str(zf.MachineKey, m.node.MachineKey.String()).
Bool("keepalive", msg.KeepAlive).
Msgf("finished writing mapresp to node chan(%p)", m.ch)
Msg("finished writing mapresp to node")

return nil
}
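One small but deliberate change in writeMap above: the response buffer is now allocated with length reservedResponseHeaderSize and capacity header-plus-body, so the subsequent append reuses the same backing array instead of growing it. A standalone sketch of the same length-prefixed framing, assuming a 4-byte header (which is the minimum PutUint32 requires); the real constant lives in headscale:

```go
package main

import (
	"encoding/binary"
	"fmt"
)

// Assumed value; the real reservedResponseHeaderSize is defined in hscontrol.
const reservedResponseHeaderSize = 4

func frame(jsonBody []byte) []byte {
	// Length covers only the header, capacity covers header+body, so the
	// append below fills the existing backing array without reallocating.
	data := make([]byte, reservedResponseHeaderSize, reservedResponseHeaderSize+len(jsonBody))
	binary.LittleEndian.PutUint32(data, uint32(len(jsonBody)))

	return append(data, jsonBody...)
}

func main() {
	out := frame([]byte(`{"keepAlive":true}`))
	fmt.Println(len(out), cap(out)) // 22 22 for an 18-byte body
}
```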
@@ -318,23 +334,3 @@ func (m *mapSession) writeMap(msg *tailcfg.MapResponse) error {
var keepAlive = tailcfg.MapResponse{
KeepAlive: true,
}

// logf adds common mapSession context to a zerolog event.
func (m *mapSession) logf(event *zerolog.Event) *zerolog.Event {
return event.
Bool("omitPeers", m.req.OmitPeers).
Bool("stream", m.req.Stream).
Uint64("node.id", m.node.ID.Uint64()).
Str("node.name", m.node.Hostname)
}

//nolint:zerologlint // logf returns *zerolog.Event which is properly terminated with Msgf
func (m *mapSession) infof(msg string, a ...any) { m.logf(log.Info().Caller()).Msgf(msg, a...) }

//nolint:zerologlint // logf returns *zerolog.Event which is properly terminated with Msgf
func (m *mapSession) tracef(msg string, a ...any) { m.logf(log.Trace().Caller()).Msgf(msg, a...) }

//nolint:zerologlint // logf returns *zerolog.Event which is properly terminated with Msgf
func (m *mapSession) errf(err error, msg string, a ...any) {
m.logf(log.Error().Caller()).Msgf(msg, a...)
}
Some files were not shown because too many files have changed in this diff.