mirror of
https://github.com/juanfont/headscale.git
synced 2026-02-19 14:17:48 +01:00
Compare commits
54 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d15ec28799 | ||
|
|
eccf64eb58 | ||
|
|
43afeedde2 | ||
|
|
73613d7f53 | ||
|
|
30d18575be | ||
|
|
70f8141abd | ||
|
|
82958835ce | ||
|
|
9c3a3c5837 | ||
|
|
faf55f5e8f | ||
|
|
e3323b65e5 | ||
|
|
8f60b819ec | ||
|
|
c29bcd2eaf | ||
|
|
890a044ef6 | ||
|
|
8028fa5483 | ||
|
|
a7f981e30e | ||
|
|
e0d8c3c877 | ||
|
|
c1b468f9f4 | ||
|
|
900f4b7b75 | ||
|
|
64f23136a2 | ||
|
|
0f6d312ada | ||
|
|
20dff82f95 | ||
|
|
31c4331a91 | ||
|
|
ce580f8245 | ||
|
|
bfb6fd80df | ||
|
|
3acce2da87 | ||
|
|
4a9a329339 | ||
|
|
dd16567c52 | ||
|
|
e0a436cefc | ||
|
|
53cdeff129 | ||
|
|
7148a690d0 | ||
|
|
4e73133b9f | ||
|
|
4f8724151e | ||
|
|
91730e2a1d | ||
|
|
b5090a01ec | ||
|
|
27f5641341 | ||
|
|
cf3d30b6f6 | ||
|
|
58020696fe | ||
|
|
e44b402fe4 | ||
|
|
835b7eb960 | ||
|
|
95b1fd636e | ||
|
|
834ac27779 | ||
|
|
4a4032a4b0 | ||
|
|
29aa08df0e | ||
|
|
0b1727c337 | ||
|
|
08fe2e4d6c | ||
|
|
cb29cade46 | ||
|
|
f27298c759 | ||
|
|
8baa14ef4a | ||
|
|
ebdbe03639 | ||
|
|
f735502eae | ||
|
|
53d17aa321 | ||
|
|
14f833bdb9 | ||
|
|
9e50071df9 | ||
|
|
c907b0d323 |
6
.github/ISSUE_TEMPLATE/bug_report.yaml
vendored
6
.github/ISSUE_TEMPLATE/bug_report.yaml
vendored
@@ -6,8 +6,7 @@ body:
|
|||||||
- type: checkboxes
|
- type: checkboxes
|
||||||
attributes:
|
attributes:
|
||||||
label: Is this a support request?
|
label: Is this a support request?
|
||||||
description:
|
description: This issue tracker is for bugs and feature requests only. If you need
|
||||||
This issue tracker is for bugs and feature requests only. If you need
|
|
||||||
help, please use ask in our Discord community
|
help, please use ask in our Discord community
|
||||||
options:
|
options:
|
||||||
- label: This is not a support request
|
- label: This is not a support request
|
||||||
@@ -15,8 +14,7 @@ body:
|
|||||||
- type: checkboxes
|
- type: checkboxes
|
||||||
attributes:
|
attributes:
|
||||||
label: Is there an existing issue for this?
|
label: Is there an existing issue for this?
|
||||||
description:
|
description: Please search to see if an issue already exists for the bug you
|
||||||
Please search to see if an issue already exists for the bug you
|
|
||||||
encountered.
|
encountered.
|
||||||
options:
|
options:
|
||||||
- label: I have searched the existing issues
|
- label: I have searched the existing issues
|
||||||
|
|||||||
8
.github/ISSUE_TEMPLATE/config.yml
vendored
8
.github/ISSUE_TEMPLATE/config.yml
vendored
@@ -3,9 +3,9 @@ blank_issues_enabled: false
|
|||||||
|
|
||||||
# Contact links
|
# Contact links
|
||||||
contact_links:
|
contact_links:
|
||||||
- name: "headscale usage documentation"
|
|
||||||
url: "https://github.com/juanfont/headscale/blob/main/docs"
|
|
||||||
about: "Find documentation about how to configure and run headscale."
|
|
||||||
- name: "headscale Discord community"
|
- name: "headscale Discord community"
|
||||||
url: "https://discord.gg/xGj2TuqyxY"
|
url: "https://discord.gg/c84AZQhmpx"
|
||||||
about: "Please ask and answer questions about usage of headscale here."
|
about: "Please ask and answer questions about usage of headscale here."
|
||||||
|
- name: "headscale usage documentation"
|
||||||
|
url: "https://headscale.net/"
|
||||||
|
about: "Find documentation about how to configure and run headscale."
|
||||||
|
|||||||
80
.github/label-response/needs-more-info.md
vendored
Normal file
80
.github/label-response/needs-more-info.md
vendored
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
Thank you for taking the time to report this issue.
|
||||||
|
|
||||||
|
To help us investigate and resolve this, we need more information. Please provide the following:
|
||||||
|
|
||||||
|
> [!TIP]
|
||||||
|
> Most issues turn out to be configuration errors rather than bugs. We encourage you to discuss your problem in our [Discord community](https://discord.gg/c84AZQhmpx) **before** opening an issue. The community can often help identify misconfigurations quickly, saving everyone time.
|
||||||
|
|
||||||
|
## Required Information
|
||||||
|
|
||||||
|
### Environment Details
|
||||||
|
|
||||||
|
- **Headscale version**: (run `headscale version`)
|
||||||
|
- **Tailscale client version**: (run `tailscale version`)
|
||||||
|
- **Operating System**: (e.g., Ubuntu 24.04, macOS 14, Windows 11)
|
||||||
|
- **Deployment method**: (binary, Docker, Kubernetes, etc.)
|
||||||
|
- **Reverse proxy**: (if applicable: nginx, Traefik, Caddy, etc. - include configuration)
|
||||||
|
|
||||||
|
### Debug Information
|
||||||
|
|
||||||
|
Please follow our [Debugging and Troubleshooting Guide](https://headscale.net/stable/ref/debug/) and provide:
|
||||||
|
|
||||||
|
1. **Client netmap dump** (from affected Tailscale client):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
tailscale debug netmap > netmap.json
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **Client status dump** (from affected Tailscale client):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
tailscale status --json > status.json
|
||||||
|
```
|
||||||
|
|
||||||
|
3. **Tailscale client logs** (if experiencing client issues):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
tailscale debug daemon-logs
|
||||||
|
```
|
||||||
|
|
||||||
|
> [!IMPORTANT]
|
||||||
|
> We need logs from **multiple nodes** to understand the full picture:
|
||||||
|
>
|
||||||
|
> - The node(s) initiating connections
|
||||||
|
> - The node(s) being connected to
|
||||||
|
>
|
||||||
|
> Without logs from both sides, we cannot diagnose connectivity issues.
|
||||||
|
|
||||||
|
4. **Headscale server logs** with `log.level: trace` enabled
|
||||||
|
|
||||||
|
5. **Headscale configuration** (with sensitive values redacted - see rules below)
|
||||||
|
|
||||||
|
6. **ACL/Policy configuration** (if using ACLs)
|
||||||
|
|
||||||
|
7. **Proxy/Docker configuration** (if applicable - nginx.conf, docker-compose.yml, Traefik config, etc.)
|
||||||
|
|
||||||
|
## Formatting Requirements
|
||||||
|
|
||||||
|
- **Attach long files** - Do not paste large logs or configurations inline. Use GitHub file attachments or GitHub Gists.
|
||||||
|
- **Use proper Markdown** - Format code blocks, logs, and configurations with appropriate syntax highlighting.
|
||||||
|
- **Structure your response** - Use the headings above to organize your information clearly.
|
||||||
|
|
||||||
|
## Redaction Rules
|
||||||
|
|
||||||
|
> [!CAUTION]
|
||||||
|
> **Replace, do not remove.** Removing information makes debugging impossible.
|
||||||
|
|
||||||
|
When redacting sensitive information:
|
||||||
|
|
||||||
|
- ✅ **Replace consistently** - If you change `alice@company.com` to `user1@example.com`, use `user1@example.com` everywhere (logs, config, policy, etc.)
|
||||||
|
- ✅ **Use meaningful placeholders** - `user1@example.com`, `bob@example.com`, `my-secret-key` are acceptable
|
||||||
|
- ❌ **Never remove information** - Gaps in data prevent us from correlating events across logs
|
||||||
|
- ❌ **Never redact IP addresses** - We need the actual IPs to trace network paths and identify issues
|
||||||
|
|
||||||
|
**If redaction rules are not followed, we will be unable to debug the issue and will have to close it.**
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
**Note:** This issue will be automatically closed in 3 days if no additional information is provided. Once you reply with the requested information, the `needs-more-info` label will be removed automatically.
|
||||||
|
|
||||||
|
If you need help gathering this information, please visit our [Discord community](https://discord.gg/c84AZQhmpx).
|
||||||
15
.github/label-response/support-request.md
vendored
Normal file
15
.github/label-response/support-request.md
vendored
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
Thank you for reaching out.
|
||||||
|
|
||||||
|
This issue tracker is used for **bug reports and feature requests** only. Your question appears to be a support or configuration question rather than a bug report.
|
||||||
|
|
||||||
|
For help with setup, configuration, or general questions, please visit our [Discord community](https://discord.gg/c84AZQhmpx) where the community and maintainers can assist you in real-time.
|
||||||
|
|
||||||
|
**Before posting in Discord, please check:**
|
||||||
|
|
||||||
|
- [Documentation](https://headscale.net/)
|
||||||
|
- [FAQ](https://headscale.net/stable/faq/)
|
||||||
|
- [Debugging and Troubleshooting Guide](https://headscale.net/stable/ref/debug/)
|
||||||
|
|
||||||
|
If after troubleshooting you determine this is actually a bug, please open a new issue with the required debug information from the troubleshooting guide.
|
||||||
|
|
||||||
|
This issue has been automatically closed.
|
||||||
18
.github/workflows/integration-test-template.yml
vendored
18
.github/workflows/integration-test-template.yml
vendored
@@ -67,6 +67,24 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
name: postgres-image
|
name: postgres-image
|
||||||
path: /tmp/artifacts
|
path: /tmp/artifacts
|
||||||
|
- name: Pin Docker to v28 (avoid v29 breaking changes)
|
||||||
|
run: |
|
||||||
|
# Docker 29 breaks docker build via Go client libraries and
|
||||||
|
# docker load/save with certain tarball formats.
|
||||||
|
# Pin to Docker 28.x until our tooling is updated.
|
||||||
|
# https://github.com/actions/runner-images/issues/13474
|
||||||
|
sudo install -m 0755 -d /etc/apt/keyrings
|
||||||
|
curl -fsSL https://download.docker.com/linux/ubuntu/gpg \
|
||||||
|
| sudo gpg --dearmor -o /etc/apt/keyrings/docker.gpg
|
||||||
|
echo "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.gpg] \
|
||||||
|
https://download.docker.com/linux/ubuntu $(. /etc/os-release && echo "$VERSION_CODENAME") stable" \
|
||||||
|
| sudo tee /etc/apt/sources.list.d/docker.list > /dev/null
|
||||||
|
sudo apt-get update -qq
|
||||||
|
VERSION=$(apt-cache madison docker-ce | grep '28\.5' | head -1 | awk '{print $3}')
|
||||||
|
sudo apt-get install -y --allow-downgrades \
|
||||||
|
"docker-ce=${VERSION}" "docker-ce-cli=${VERSION}"
|
||||||
|
sudo systemctl restart docker
|
||||||
|
docker version
|
||||||
- name: Load Docker images, Go cache, and prepare binary
|
- name: Load Docker images, Go cache, and prepare binary
|
||||||
run: |
|
run: |
|
||||||
gunzip -c /tmp/artifacts/headscale-image.tar.gz | docker load
|
gunzip -c /tmp/artifacts/headscale-image.tar.gz | docker load
|
||||||
|
|||||||
28
.github/workflows/needs-more-info-comment.yml
vendored
Normal file
28
.github/workflows/needs-more-info-comment.yml
vendored
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
name: Needs More Info - Post Comment
|
||||||
|
|
||||||
|
on:
|
||||||
|
issues:
|
||||||
|
types: [labeled]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
post-comment:
|
||||||
|
if: >-
|
||||||
|
github.event.label.name == 'needs-more-info' &&
|
||||||
|
github.repository == 'juanfont/headscale'
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
permissions:
|
||||||
|
issues: write
|
||||||
|
contents: read
|
||||||
|
steps:
|
||||||
|
- name: Checkout repository
|
||||||
|
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||||
|
with:
|
||||||
|
sparse-checkout: .github/label-response/needs-more-info.md
|
||||||
|
sparse-checkout-cone-mode: false
|
||||||
|
|
||||||
|
- name: Post instruction comment
|
||||||
|
run: gh issue comment "$NUMBER" --body-file .github/label-response/needs-more-info.md
|
||||||
|
env:
|
||||||
|
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
GH_REPO: ${{ github.repository }}
|
||||||
|
NUMBER: ${{ github.event.issue.number }}
|
||||||
98
.github/workflows/needs-more-info-timer.yml
vendored
Normal file
98
.github/workflows/needs-more-info-timer.yml
vendored
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
name: Needs More Info - Timer
|
||||||
|
|
||||||
|
on:
|
||||||
|
schedule:
|
||||||
|
- cron: "0 0 * * *" # Daily at midnight UTC
|
||||||
|
issue_comment:
|
||||||
|
types: [created]
|
||||||
|
workflow_dispatch:
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
# When a non-bot user comments on a needs-more-info issue, remove the label.
|
||||||
|
remove-label-on-response:
|
||||||
|
if: >-
|
||||||
|
github.repository == 'juanfont/headscale' &&
|
||||||
|
github.event_name == 'issue_comment' &&
|
||||||
|
github.event.comment.user.type != 'Bot' &&
|
||||||
|
contains(github.event.issue.labels.*.name, 'needs-more-info')
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
permissions:
|
||||||
|
issues: write
|
||||||
|
steps:
|
||||||
|
- name: Remove needs-more-info label
|
||||||
|
run: gh issue edit "$NUMBER" --remove-label needs-more-info
|
||||||
|
env:
|
||||||
|
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
GH_REPO: ${{ github.repository }}
|
||||||
|
NUMBER: ${{ github.event.issue.number }}
|
||||||
|
|
||||||
|
# On schedule, close issues that have had no human response for 3 days.
|
||||||
|
close-stale:
|
||||||
|
if: >-
|
||||||
|
github.repository == 'juanfont/headscale' &&
|
||||||
|
github.event_name != 'issue_comment'
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
permissions:
|
||||||
|
issues: write
|
||||||
|
steps:
|
||||||
|
- uses: hustcer/setup-nu@920172d92eb04671776f3ba69d605d3b09351c30 # v3.22
|
||||||
|
with:
|
||||||
|
version: "*"
|
||||||
|
|
||||||
|
- name: Close stale needs-more-info issues
|
||||||
|
shell: nu {0}
|
||||||
|
env:
|
||||||
|
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
GH_REPO: ${{ github.repository }}
|
||||||
|
run: |
|
||||||
|
let issues = (gh issue list
|
||||||
|
--repo $env.GH_REPO
|
||||||
|
--label "needs-more-info"
|
||||||
|
--state open
|
||||||
|
--json number
|
||||||
|
| from json)
|
||||||
|
|
||||||
|
for issue in $issues {
|
||||||
|
let number = $issue.number
|
||||||
|
print $"Checking issue #($number)"
|
||||||
|
|
||||||
|
# Find when needs-more-info was last added
|
||||||
|
let events = (gh api $"repos/($env.GH_REPO)/issues/($number)/events"
|
||||||
|
--paginate | from json | flatten)
|
||||||
|
let label_event = ($events
|
||||||
|
| where event == "labeled" and label.name == "needs-more-info"
|
||||||
|
| last)
|
||||||
|
let label_added_at = ($label_event.created_at | into datetime)
|
||||||
|
|
||||||
|
# Check for non-bot comments after the label was added
|
||||||
|
let comments = (gh api $"repos/($env.GH_REPO)/issues/($number)/comments"
|
||||||
|
--paginate | from json | flatten)
|
||||||
|
let human_responses = ($comments
|
||||||
|
| where user.type != "Bot"
|
||||||
|
| where { ($in.created_at | into datetime) > $label_added_at })
|
||||||
|
|
||||||
|
if ($human_responses | length) > 0 {
|
||||||
|
print $" Human responded, removing label"
|
||||||
|
gh issue edit $number --repo $env.GH_REPO --remove-label needs-more-info
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
# Check if 3 days have passed
|
||||||
|
let elapsed = (date now) - $label_added_at
|
||||||
|
if $elapsed < 3day {
|
||||||
|
print $" Only ($elapsed | format duration day) elapsed, skipping"
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
print $" No response for ($elapsed | format duration day), closing"
|
||||||
|
let message = [
|
||||||
|
"This issue has been automatically closed because no additional information was provided within 3 days."
|
||||||
|
""
|
||||||
|
"If you have the requested information, please open a new issue and include the debug information requested above."
|
||||||
|
""
|
||||||
|
"Thank you for your understanding."
|
||||||
|
] | str join "\n"
|
||||||
|
gh issue comment $number --repo $env.GH_REPO --body $message
|
||||||
|
gh issue close $number --repo $env.GH_REPO --reason "not planned"
|
||||||
|
gh issue edit $number --repo $env.GH_REPO --remove-label needs-more-info
|
||||||
|
}
|
||||||
19
.github/workflows/release.yml
vendored
19
.github/workflows/release.yml
vendored
@@ -17,6 +17,25 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
|
|
||||||
|
- name: Pin Docker to v28 (avoid v29 breaking changes)
|
||||||
|
run: |
|
||||||
|
# Docker 29 breaks docker build via Go client libraries and
|
||||||
|
# docker load/save with certain tarball formats.
|
||||||
|
# Pin to Docker 28.x until our tooling is updated.
|
||||||
|
# https://github.com/actions/runner-images/issues/13474
|
||||||
|
sudo install -m 0755 -d /etc/apt/keyrings
|
||||||
|
curl -fsSL https://download.docker.com/linux/ubuntu/gpg \
|
||||||
|
| sudo gpg --dearmor -o /etc/apt/keyrings/docker.gpg
|
||||||
|
echo "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.gpg] \
|
||||||
|
https://download.docker.com/linux/ubuntu $(. /etc/os-release && echo "$VERSION_CODENAME") stable" \
|
||||||
|
| sudo tee /etc/apt/sources.list.d/docker.list > /dev/null
|
||||||
|
sudo apt-get update -qq
|
||||||
|
VERSION=$(apt-cache madison docker-ce | grep '28\.5' | head -1 | awk '{print $3}')
|
||||||
|
sudo apt-get install -y --allow-downgrades \
|
||||||
|
"docker-ce=${VERSION}" "docker-ce-cli=${VERSION}"
|
||||||
|
sudo systemctl restart docker
|
||||||
|
docker version
|
||||||
|
|
||||||
- name: Login to DockerHub
|
- name: Login to DockerHub
|
||||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # v3.6.0
|
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # v3.6.0
|
||||||
with:
|
with:
|
||||||
|
|||||||
2
.github/workflows/stale.yml
vendored
2
.github/workflows/stale.yml
vendored
@@ -23,5 +23,5 @@ jobs:
|
|||||||
since being marked as stale."
|
since being marked as stale."
|
||||||
days-before-pr-stale: -1
|
days-before-pr-stale: -1
|
||||||
days-before-pr-close: -1
|
days-before-pr-close: -1
|
||||||
exempt-issue-labels: "no-stale-bot"
|
exempt-issue-labels: "no-stale-bot,needs-more-info"
|
||||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
|||||||
30
.github/workflows/support-request.yml
vendored
Normal file
30
.github/workflows/support-request.yml
vendored
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
name: Support Request - Close Issue
|
||||||
|
|
||||||
|
on:
|
||||||
|
issues:
|
||||||
|
types: [labeled]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
close-support-request:
|
||||||
|
if: >-
|
||||||
|
github.event.label.name == 'support-request' &&
|
||||||
|
github.repository == 'juanfont/headscale'
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
permissions:
|
||||||
|
issues: write
|
||||||
|
contents: read
|
||||||
|
steps:
|
||||||
|
- name: Checkout repository
|
||||||
|
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||||
|
with:
|
||||||
|
sparse-checkout: .github/label-response/support-request.md
|
||||||
|
sparse-checkout-cone-mode: false
|
||||||
|
|
||||||
|
- name: Post comment and close issue
|
||||||
|
run: |
|
||||||
|
gh issue comment "$NUMBER" --body-file .github/label-response/support-request.md
|
||||||
|
gh issue close "$NUMBER" --reason "not planned"
|
||||||
|
env:
|
||||||
|
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
GH_REPO: ${{ github.repository }}
|
||||||
|
NUMBER: ${{ github.event.issue.number }}
|
||||||
37
.github/workflows/test-integration.yaml
vendored
37
.github/workflows/test-integration.yaml
vendored
@@ -69,6 +69,25 @@ jobs:
|
|||||||
name: go-cache
|
name: go-cache
|
||||||
path: go-cache.tar.gz
|
path: go-cache.tar.gz
|
||||||
retention-days: 10
|
retention-days: 10
|
||||||
|
- name: Pin Docker to v28 (avoid v29 breaking changes)
|
||||||
|
if: steps.changed-files.outputs.files == 'true'
|
||||||
|
run: |
|
||||||
|
# Docker 29 breaks docker build via Go client libraries and
|
||||||
|
# docker load/save with certain tarball formats.
|
||||||
|
# Pin to Docker 28.x until our tooling is updated.
|
||||||
|
# https://github.com/actions/runner-images/issues/13474
|
||||||
|
sudo install -m 0755 -d /etc/apt/keyrings
|
||||||
|
curl -fsSL https://download.docker.com/linux/ubuntu/gpg \
|
||||||
|
| sudo gpg --dearmor -o /etc/apt/keyrings/docker.gpg
|
||||||
|
echo "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.gpg] \
|
||||||
|
https://download.docker.com/linux/ubuntu $(. /etc/os-release && echo "$VERSION_CODENAME") stable" \
|
||||||
|
| sudo tee /etc/apt/sources.list.d/docker.list > /dev/null
|
||||||
|
sudo apt-get update -qq
|
||||||
|
VERSION=$(apt-cache madison docker-ce | grep '28\.5' | head -1 | awk '{print $3}')
|
||||||
|
sudo apt-get install -y --allow-downgrades \
|
||||||
|
"docker-ce=${VERSION}" "docker-ce-cli=${VERSION}"
|
||||||
|
sudo systemctl restart docker
|
||||||
|
docker version
|
||||||
- name: Build headscale image
|
- name: Build headscale image
|
||||||
if: steps.changed-files.outputs.files == 'true'
|
if: steps.changed-files.outputs.files == 'true'
|
||||||
run: |
|
run: |
|
||||||
@@ -104,6 +123,24 @@ jobs:
|
|||||||
needs: build
|
needs: build
|
||||||
if: needs.build.outputs.files-changed == 'true'
|
if: needs.build.outputs.files-changed == 'true'
|
||||||
steps:
|
steps:
|
||||||
|
- name: Pin Docker to v28 (avoid v29 breaking changes)
|
||||||
|
run: |
|
||||||
|
# Docker 29 breaks docker build via Go client libraries and
|
||||||
|
# docker load/save with certain tarball formats.
|
||||||
|
# Pin to Docker 28.x until our tooling is updated.
|
||||||
|
# https://github.com/actions/runner-images/issues/13474
|
||||||
|
sudo install -m 0755 -d /etc/apt/keyrings
|
||||||
|
curl -fsSL https://download.docker.com/linux/ubuntu/gpg \
|
||||||
|
| sudo gpg --dearmor -o /etc/apt/keyrings/docker.gpg
|
||||||
|
echo "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.gpg] \
|
||||||
|
https://download.docker.com/linux/ubuntu $(. /etc/os-release && echo "$VERSION_CODENAME") stable" \
|
||||||
|
| sudo tee /etc/apt/sources.list.d/docker.list > /dev/null
|
||||||
|
sudo apt-get update -qq
|
||||||
|
VERSION=$(apt-cache madison docker-ce | grep '28\.5' | head -1 | awk '{print $3}')
|
||||||
|
sudo apt-get install -y --allow-downgrades \
|
||||||
|
"docker-ce=${VERSION}" "docker-ce-cli=${VERSION}"
|
||||||
|
sudo systemctl restart docker
|
||||||
|
docker version
|
||||||
- name: Pull and save postgres image
|
- name: Pull and save postgres image
|
||||||
run: |
|
run: |
|
||||||
docker pull postgres:latest
|
docker pull postgres:latest
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ linters:
|
|||||||
- lll
|
- lll
|
||||||
- maintidx
|
- maintidx
|
||||||
- makezero
|
- makezero
|
||||||
|
- mnd
|
||||||
- musttag
|
- musttag
|
||||||
- nestif
|
- nestif
|
||||||
- nolintlint
|
- nolintlint
|
||||||
@@ -37,6 +38,23 @@ linters:
|
|||||||
time.Sleep is forbidden.
|
time.Sleep is forbidden.
|
||||||
In tests: use assert.EventuallyWithT for polling/waiting patterns.
|
In tests: use assert.EventuallyWithT for polling/waiting patterns.
|
||||||
In production code: use a backoff strategy (e.g., cenkalti/backoff) or proper synchronization primitives.
|
In production code: use a backoff strategy (e.g., cenkalti/backoff) or proper synchronization primitives.
|
||||||
|
# Forbid inline string literals in zerolog field methods - use zf.* constants
|
||||||
|
- pattern: '\.(Str|Int|Int8|Int16|Int32|Int64|Uint|Uint8|Uint16|Uint32|Uint64|Float32|Float64|Bool|Dur|Time|TimeDiff|Strs|Ints|Uints|Floats|Bools|Any|Interface)\("[^"]+"'
|
||||||
|
msg: >-
|
||||||
|
Use zf.* constants for zerolog field names instead of string literals.
|
||||||
|
Import "github.com/juanfont/headscale/hscontrol/util/zlog/zf" and use
|
||||||
|
constants like zf.NodeID, zf.UserName, etc. Add new constants to
|
||||||
|
hscontrol/util/zlog/zf/fields.go if needed.
|
||||||
|
# Forbid ptr.To - use Go 1.26 new(expr) instead
|
||||||
|
- pattern: 'ptr\.To\('
|
||||||
|
msg: >-
|
||||||
|
ptr.To is forbidden. Use Go 1.26's new(expr) syntax instead.
|
||||||
|
Example: ptr.To(value) → new(value)
|
||||||
|
# Forbid tsaddr.SortPrefixes - use slices.SortFunc with netip.Prefix.Compare
|
||||||
|
- pattern: 'tsaddr\.SortPrefixes'
|
||||||
|
msg: >-
|
||||||
|
tsaddr.SortPrefixes is forbidden. Use Go 1.26's netip.Prefix.Compare instead.
|
||||||
|
Example: slices.SortFunc(prefixes, netip.Prefix.Compare)
|
||||||
analyze-types: true
|
analyze-types: true
|
||||||
gocritic:
|
gocritic:
|
||||||
disabled-checks:
|
disabled-checks:
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
version: 2
|
version: 2
|
||||||
before:
|
before:
|
||||||
hooks:
|
hooks:
|
||||||
- go mod tidy -compat=1.25
|
- go mod tidy -compat=1.26
|
||||||
- go mod vendor
|
- go mod vendor
|
||||||
|
|
||||||
release:
|
release:
|
||||||
|
|||||||
@@ -43,26 +43,12 @@ repos:
|
|||||||
entry: prettier --write --list-different
|
entry: prettier --write --list-different
|
||||||
language: system
|
language: system
|
||||||
exclude: ^docs/
|
exclude: ^docs/
|
||||||
types_or:
|
types_or: [javascript, jsx, ts, tsx, yaml, json, toml, html, css, scss, sass, markdown]
|
||||||
[
|
|
||||||
javascript,
|
|
||||||
jsx,
|
|
||||||
ts,
|
|
||||||
tsx,
|
|
||||||
yaml,
|
|
||||||
json,
|
|
||||||
toml,
|
|
||||||
html,
|
|
||||||
css,
|
|
||||||
scss,
|
|
||||||
sass,
|
|
||||||
markdown,
|
|
||||||
]
|
|
||||||
|
|
||||||
# golangci-lint for Go code quality
|
# golangci-lint for Go code quality
|
||||||
- id: golangci-lint
|
- id: golangci-lint
|
||||||
name: golangci-lint
|
name: golangci-lint
|
||||||
entry: nix develop --command golangci-lint run --new-from-rev=HEAD~1 --timeout=5m --fix
|
entry: golangci-lint run --new-from-rev=HEAD~1 --timeout=5m --fix
|
||||||
language: system
|
language: system
|
||||||
types: [go]
|
types: [go]
|
||||||
pass_filenames: false
|
pass_filenames: false
|
||||||
|
|||||||
32
CHANGELOG.md
32
CHANGELOG.md
@@ -2,6 +2,38 @@
|
|||||||
|
|
||||||
## 0.29.0 (202x-xx-xx)
|
## 0.29.0 (202x-xx-xx)
|
||||||
|
|
||||||
|
**Minimum supported Tailscale client version: v1.76.0**
|
||||||
|
|
||||||
|
### Tailscale ACL compatibility improvements
|
||||||
|
|
||||||
|
Extensive test cases were systematically generated using Tailscale clients and the official SaaS
|
||||||
|
to understand how the packet filter should be generated. We discovered a few differences, but
|
||||||
|
overall our implementation was very close.
|
||||||
|
[#3036](https://github.com/juanfont/headscale/pull/3036)
|
||||||
|
|
||||||
|
### BREAKING
|
||||||
|
|
||||||
|
- **ACL Policy**: Wildcard (`*`) in ACL sources and destinations now resolves to Tailscale's CGNAT range (`100.64.0.0/10`) and ULA range (`fd7a:115c:a1e0::/48`) instead of all IPs (`0.0.0.0/0` and `::/0`) [#3036](https://github.com/juanfont/headscale/pull/3036)
|
||||||
|
- This better matches Tailscale's security model where `*` means "any node in the tailnet" rather than "any IP address"
|
||||||
|
- Policies relying on wildcard to match non-Tailscale IPs will need to use explicit CIDR ranges instead
|
||||||
|
- **Note**: Users with non-standard IP ranges configured in `prefixes.ipv4` or `prefixes.ipv6` (which is unsupported and produces a warning) will need to explicitly specify their CIDR ranges in ACL rules instead of using `*`
|
||||||
|
- **ACL Policy**: Validate autogroup:self source restrictions matching Tailscale behavior - tags, hosts, and IPs are rejected as sources for autogroup:self destinations [#3036](https://github.com/juanfont/headscale/pull/3036)
|
||||||
|
- Policies using tags, hosts, or IP addresses as sources for autogroup:self destinations will now fail validation
|
||||||
|
- **Upgrade path**: Headscale now enforces a strict version upgrade path [#3083](https://github.com/juanfont/headscale/pull/3083)
|
||||||
|
- Skipping minor versions (e.g. 0.27 → 0.29) is blocked; upgrade one minor version at a time
|
||||||
|
- Downgrading to a previous minor version is blocked
|
||||||
|
- Patch version changes within the same minor are always allowed
|
||||||
|
- **ACL Policy**: The `proto:icmp` protocol name now only includes ICMPv4 (protocol 1), matching Tailscale behavior [#3036](https://github.com/juanfont/headscale/pull/3036)
|
||||||
|
- Previously, `proto:icmp` included both ICMPv4 and ICMPv6
|
||||||
|
- Use `proto:ipv6-icmp` or protocol number `58` explicitly for ICMPv6
|
||||||
|
|
||||||
|
### Changes
|
||||||
|
|
||||||
|
- **ACL Policy**: Add ICMP and IPv6-ICMP protocols to default filter rules when no protocol is specified [#3036](https://github.com/juanfont/headscale/pull/3036)
|
||||||
|
- **ACL Policy**: Fix autogroup:self handling for tagged nodes - tagged nodes no longer incorrectly receive autogroup:self filter rules [#3036](https://github.com/juanfont/headscale/pull/3036)
|
||||||
|
- **ACL Policy**: Use CIDR format for autogroup:self destination IPs matching Tailscale behavior [#3036](https://github.com/juanfont/headscale/pull/3036)
|
||||||
|
- **ACL Policy**: Merge filter rules with identical SrcIPs and IPProto matching Tailscale behavior - multiple ACL rules with the same source now produce a single FilterRule with combined DstPorts [#3036](https://github.com/juanfont/headscale/pull/3036)
|
||||||
|
|
||||||
## 0.28.0 (2026-02-04)
|
## 0.28.0 (2026-02-04)
|
||||||
|
|
||||||
**Minimum supported Tailscale client version: v1.74.0**
|
**Minimum supported Tailscale client version: v1.74.0**
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
# For testing purposes only
|
# For testing purposes only
|
||||||
|
|
||||||
FROM golang:alpine AS build-env
|
FROM golang:1.26.0-alpine AS build-env
|
||||||
|
|
||||||
WORKDIR /go/src
|
WORKDIR /go/src
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
# and are in no way endorsed by Headscale's maintainers as an
|
# and are in no way endorsed by Headscale's maintainers as an
|
||||||
# official nor supported release or distribution.
|
# official nor supported release or distribution.
|
||||||
|
|
||||||
FROM docker.io/golang:1.25-trixie AS builder
|
FROM docker.io/golang:1.26.0-trixie AS builder
|
||||||
ARG VERSION=dev
|
ARG VERSION=dev
|
||||||
ENV GOPATH /go
|
ENV GOPATH /go
|
||||||
WORKDIR /go/src/headscale
|
WORKDIR /go/src/headscale
|
||||||
|
|||||||
@@ -4,7 +4,7 @@
|
|||||||
# This Dockerfile is more or less lifted from tailscale/tailscale
|
# This Dockerfile is more or less lifted from tailscale/tailscale
|
||||||
# to ensure a similar build process when testing the HEAD of tailscale.
|
# to ensure a similar build process when testing the HEAD of tailscale.
|
||||||
|
|
||||||
FROM golang:1.25-alpine AS build-env
|
FROM golang:1.26.0-alpine AS build-env
|
||||||
|
|
||||||
WORKDIR /go/src
|
WORKDIR /go/src
|
||||||
|
|
||||||
|
|||||||
@@ -67,6 +67,8 @@ For NixOS users, a module is available in [`nix/`](./nix/).
|
|||||||
|
|
||||||
## Talks
|
## Talks
|
||||||
|
|
||||||
|
- Fosdem 2026 (video): [Headscale & Tailscale: The complementary open source clone](https://fosdem.org/2026/schedule/event/KYQ3LL-headscale-the-complementary-open-source-clone/)
|
||||||
|
- presented by Kristoffer Dalby
|
||||||
- Fosdem 2023 (video): [Headscale: How we are using integration testing to reimplement Tailscale](https://fosdem.org/2023/schedule/event/goheadscale/)
|
- Fosdem 2023 (video): [Headscale: How we are using integration testing to reimplement Tailscale](https://fosdem.org/2023/schedule/event/goheadscale/)
|
||||||
- presented by Juan Font Alonso and Kristoffer Dalby
|
- presented by Juan Font Alonso and Kristoffer Dalby
|
||||||
|
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// 90 days.
|
// DefaultAPIKeyExpiry is 90 days.
|
||||||
DefaultAPIKeyExpiry = "90d"
|
DefaultAPIKeyExpiry = "90d"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -71,6 +71,7 @@ var listAPIKeys = &cobra.Command{
|
|||||||
tableData := pterm.TableData{
|
tableData := pterm.TableData{
|
||||||
{"ID", "Prefix", "Expiration", "Created"},
|
{"ID", "Prefix", "Expiration", "Created"},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, key := range response.GetApiKeys() {
|
for _, key := range response.GetApiKeys() {
|
||||||
expiration := "-"
|
expiration := "-"
|
||||||
|
|
||||||
@@ -84,8 +85,8 @@ var listAPIKeys = &cobra.Command{
|
|||||||
expiration,
|
expiration,
|
||||||
key.GetCreatedAt().AsTime().Format(HeadscaleDateTimeFormat),
|
key.GetCreatedAt().AsTime().Format(HeadscaleDateTimeFormat),
|
||||||
})
|
})
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
|
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorOutput(
|
ErrorOutput(
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ var configTestCmd = &cobra.Command{
|
|||||||
Run: func(cmd *cobra.Command, args []string) {
|
Run: func(cmd *cobra.Command, args []string) {
|
||||||
_, err := newHeadscaleServerWithConfig()
|
_, err := newHeadscaleServerWithConfig()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal().Caller().Err(err).Msg("Error initializing")
|
log.Fatal().Caller().Err(err).Msg("error initializing")
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -19,10 +19,12 @@ func init() {
|
|||||||
rootCmd.AddCommand(debugCmd)
|
rootCmd.AddCommand(debugCmd)
|
||||||
|
|
||||||
createNodeCmd.Flags().StringP("name", "", "", "Name")
|
createNodeCmd.Flags().StringP("name", "", "", "Name")
|
||||||
|
|
||||||
err := createNodeCmd.MarkFlagRequired("name")
|
err := createNodeCmd.MarkFlagRequired("name")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal().Err(err).Msg("")
|
log.Fatal().Err(err).Msg("")
|
||||||
}
|
}
|
||||||
|
|
||||||
createNodeCmd.Flags().StringP("user", "u", "", "User")
|
createNodeCmd.Flags().StringP("user", "u", "", "User")
|
||||||
|
|
||||||
createNodeCmd.Flags().StringP("namespace", "n", "", "User")
|
createNodeCmd.Flags().StringP("namespace", "n", "", "User")
|
||||||
@@ -34,11 +36,14 @@ func init() {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal().Err(err).Msg("")
|
log.Fatal().Err(err).Msg("")
|
||||||
}
|
}
|
||||||
|
|
||||||
createNodeCmd.Flags().StringP("key", "k", "", "Key")
|
createNodeCmd.Flags().StringP("key", "k", "", "Key")
|
||||||
|
|
||||||
err = createNodeCmd.MarkFlagRequired("key")
|
err = createNodeCmd.MarkFlagRequired("key")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal().Err(err).Msg("")
|
log.Fatal().Err(err).Msg("")
|
||||||
}
|
}
|
||||||
|
|
||||||
createNodeCmd.Flags().
|
createNodeCmd.Flags().
|
||||||
StringSliceP("route", "r", []string{}, "List (or repeated flags) of routes to advertise")
|
StringSliceP("route", "r", []string{}, "List (or repeated flags) of routes to advertise")
|
||||||
|
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ var healthCmd = &cobra.Command{
|
|||||||
Long: "Check the health of the Headscale server. This command will return an exit code of 0 if the server is healthy, or 1 if it is not.",
|
Long: "Check the health of the Headscale server. This command will return an exit code of 0 if the server is healthy, or 1 if it is not.",
|
||||||
Run: func(cmd *cobra.Command, args []string) {
|
Run: func(cmd *cobra.Command, args []string) {
|
||||||
output, _ := cmd.Flags().GetString("output")
|
output, _ := cmd.Flags().GetString("output")
|
||||||
|
|
||||||
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
|
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
|
||||||
defer cancel()
|
defer cancel()
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
package cli
|
package cli
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/juanfont/headscale/hscontrol/util/zlog/zf"
|
||||||
"github.com/oauth2-proxy/mockoidc"
|
"github.com/oauth2-proxy/mockoidc"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
@@ -19,6 +20,7 @@ const (
|
|||||||
errMockOidcClientIDNotDefined = Error("MOCKOIDC_CLIENT_ID not defined")
|
errMockOidcClientIDNotDefined = Error("MOCKOIDC_CLIENT_ID not defined")
|
||||||
errMockOidcClientSecretNotDefined = Error("MOCKOIDC_CLIENT_SECRET not defined")
|
errMockOidcClientSecretNotDefined = Error("MOCKOIDC_CLIENT_SECRET not defined")
|
||||||
errMockOidcPortNotDefined = Error("MOCKOIDC_PORT not defined")
|
errMockOidcPortNotDefined = Error("MOCKOIDC_PORT not defined")
|
||||||
|
errMockOidcUsersNotDefined = Error("MOCKOIDC_USERS not defined")
|
||||||
refreshTTL = 60 * time.Minute
|
refreshTTL = 60 * time.Minute
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -35,7 +37,7 @@ var mockOidcCmd = &cobra.Command{
|
|||||||
Run: func(cmd *cobra.Command, args []string) {
|
Run: func(cmd *cobra.Command, args []string) {
|
||||||
err := mockOIDC()
|
err := mockOIDC()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Msgf("Error running mock OIDC server")
|
log.Error().Err(err).Msgf("error running mock OIDC server")
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@@ -46,41 +48,47 @@ func mockOIDC() error {
|
|||||||
if clientID == "" {
|
if clientID == "" {
|
||||||
return errMockOidcClientIDNotDefined
|
return errMockOidcClientIDNotDefined
|
||||||
}
|
}
|
||||||
|
|
||||||
clientSecret := os.Getenv("MOCKOIDC_CLIENT_SECRET")
|
clientSecret := os.Getenv("MOCKOIDC_CLIENT_SECRET")
|
||||||
if clientSecret == "" {
|
if clientSecret == "" {
|
||||||
return errMockOidcClientSecretNotDefined
|
return errMockOidcClientSecretNotDefined
|
||||||
}
|
}
|
||||||
|
|
||||||
addrStr := os.Getenv("MOCKOIDC_ADDR")
|
addrStr := os.Getenv("MOCKOIDC_ADDR")
|
||||||
if addrStr == "" {
|
if addrStr == "" {
|
||||||
return errMockOidcPortNotDefined
|
return errMockOidcPortNotDefined
|
||||||
}
|
}
|
||||||
|
|
||||||
portStr := os.Getenv("MOCKOIDC_PORT")
|
portStr := os.Getenv("MOCKOIDC_PORT")
|
||||||
if portStr == "" {
|
if portStr == "" {
|
||||||
return errMockOidcPortNotDefined
|
return errMockOidcPortNotDefined
|
||||||
}
|
}
|
||||||
|
|
||||||
accessTTLOverride := os.Getenv("MOCKOIDC_ACCESS_TTL")
|
accessTTLOverride := os.Getenv("MOCKOIDC_ACCESS_TTL")
|
||||||
if accessTTLOverride != "" {
|
if accessTTLOverride != "" {
|
||||||
newTTL, err := time.ParseDuration(accessTTLOverride)
|
newTTL, err := time.ParseDuration(accessTTLOverride)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
accessTTL = newTTL
|
accessTTL = newTTL
|
||||||
}
|
}
|
||||||
|
|
||||||
userStr := os.Getenv("MOCKOIDC_USERS")
|
userStr := os.Getenv("MOCKOIDC_USERS")
|
||||||
if userStr == "" {
|
if userStr == "" {
|
||||||
return errors.New("MOCKOIDC_USERS not defined")
|
return errMockOidcUsersNotDefined
|
||||||
}
|
}
|
||||||
|
|
||||||
var users []mockoidc.MockUser
|
var users []mockoidc.MockUser
|
||||||
|
|
||||||
err := json.Unmarshal([]byte(userStr), &users)
|
err := json.Unmarshal([]byte(userStr), &users)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unmarshalling users: %w", err)
|
return fmt.Errorf("unmarshalling users: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Info().Interface("users", users).Msg("loading users from JSON")
|
log.Info().Interface(zf.Users, users).Msg("loading users from JSON")
|
||||||
|
|
||||||
log.Info().Msgf("Access token TTL: %s", accessTTL)
|
log.Info().Msgf("access token TTL: %s", accessTTL)
|
||||||
|
|
||||||
port, err := strconv.Atoi(portStr)
|
port, err := strconv.Atoi(portStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -92,7 +100,7 @@ func mockOIDC() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
listener, err := net.Listen("tcp", fmt.Sprintf("%s:%d", addrStr, port))
|
listener, err := new(net.ListenConfig).Listen(context.Background(), "tcp", fmt.Sprintf("%s:%d", addrStr, port))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -101,8 +109,10 @@ func mockOIDC() error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
log.Info().Msgf("Mock OIDC server listening on %s", listener.Addr().String())
|
|
||||||
log.Info().Msgf("Issuer: %s", mock.Issuer())
|
log.Info().Msgf("mock OIDC server listening on %s", listener.Addr().String())
|
||||||
|
log.Info().Msgf("issuer: %s", mock.Issuer())
|
||||||
|
|
||||||
c := make(chan struct{})
|
c := make(chan struct{})
|
||||||
<-c
|
<-c
|
||||||
|
|
||||||
@@ -133,12 +143,13 @@ func getMockOIDC(clientID string, clientSecret string, users []mockoidc.MockUser
|
|||||||
ErrorQueue: &mockoidc.ErrorQueue{},
|
ErrorQueue: &mockoidc.ErrorQueue{},
|
||||||
}
|
}
|
||||||
|
|
||||||
mock.AddMiddleware(func(h http.Handler) http.Handler {
|
_ = mock.AddMiddleware(func(h http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
log.Info().Msgf("Request: %+v", r)
|
log.Info().Msgf("request: %+v", r)
|
||||||
h.ServeHTTP(w, r)
|
h.ServeHTTP(w, r)
|
||||||
|
|
||||||
if r.Response != nil {
|
if r.Response != nil {
|
||||||
log.Info().Msgf("Response: %+v", r.Response)
|
log.Info().Msgf("response: %+v", r.Response)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ func init() {
|
|||||||
listNodesNamespaceFlag := listNodesCmd.Flags().Lookup("namespace")
|
listNodesNamespaceFlag := listNodesCmd.Flags().Lookup("namespace")
|
||||||
listNodesNamespaceFlag.Deprecated = deprecateNamespaceMessage
|
listNodesNamespaceFlag.Deprecated = deprecateNamespaceMessage
|
||||||
listNodesNamespaceFlag.Hidden = true
|
listNodesNamespaceFlag.Hidden = true
|
||||||
|
|
||||||
nodeCmd.AddCommand(listNodesCmd)
|
nodeCmd.AddCommand(listNodesCmd)
|
||||||
|
|
||||||
listNodeRoutesCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)")
|
listNodeRoutesCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)")
|
||||||
@@ -42,42 +43,51 @@ func init() {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err.Error())
|
log.Fatal(err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
registerNodeCmd.Flags().StringP("key", "k", "", "Key")
|
registerNodeCmd.Flags().StringP("key", "k", "", "Key")
|
||||||
|
|
||||||
err = registerNodeCmd.MarkFlagRequired("key")
|
err = registerNodeCmd.MarkFlagRequired("key")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err.Error())
|
log.Fatal(err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
nodeCmd.AddCommand(registerNodeCmd)
|
nodeCmd.AddCommand(registerNodeCmd)
|
||||||
|
|
||||||
expireNodeCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)")
|
expireNodeCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)")
|
||||||
expireNodeCmd.Flags().StringP("expiry", "e", "", "Set expire to (RFC3339 format, e.g. 2025-08-27T10:00:00Z), or leave empty to expire immediately.")
|
expireNodeCmd.Flags().StringP("expiry", "e", "", "Set expire to (RFC3339 format, e.g. 2025-08-27T10:00:00Z), or leave empty to expire immediately.")
|
||||||
|
|
||||||
err = expireNodeCmd.MarkFlagRequired("identifier")
|
err = expireNodeCmd.MarkFlagRequired("identifier")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err.Error())
|
log.Fatal(err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
nodeCmd.AddCommand(expireNodeCmd)
|
nodeCmd.AddCommand(expireNodeCmd)
|
||||||
|
|
||||||
renameNodeCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)")
|
renameNodeCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)")
|
||||||
|
|
||||||
err = renameNodeCmd.MarkFlagRequired("identifier")
|
err = renameNodeCmd.MarkFlagRequired("identifier")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err.Error())
|
log.Fatal(err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
nodeCmd.AddCommand(renameNodeCmd)
|
nodeCmd.AddCommand(renameNodeCmd)
|
||||||
|
|
||||||
deleteNodeCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)")
|
deleteNodeCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)")
|
||||||
|
|
||||||
err = deleteNodeCmd.MarkFlagRequired("identifier")
|
err = deleteNodeCmd.MarkFlagRequired("identifier")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err.Error())
|
log.Fatal(err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
nodeCmd.AddCommand(deleteNodeCmd)
|
nodeCmd.AddCommand(deleteNodeCmd)
|
||||||
|
|
||||||
tagCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)")
|
tagCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)")
|
||||||
tagCmd.MarkFlagRequired("identifier")
|
_ = tagCmd.MarkFlagRequired("identifier")
|
||||||
tagCmd.Flags().StringSliceP("tags", "t", []string{}, "List of tags to add to the node")
|
tagCmd.Flags().StringSliceP("tags", "t", []string{}, "List of tags to add to the node")
|
||||||
nodeCmd.AddCommand(tagCmd)
|
nodeCmd.AddCommand(tagCmd)
|
||||||
|
|
||||||
approveRoutesCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)")
|
approveRoutesCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)")
|
||||||
approveRoutesCmd.MarkFlagRequired("identifier")
|
_ = approveRoutesCmd.MarkFlagRequired("identifier")
|
||||||
approveRoutesCmd.Flags().StringSliceP("routes", "r", []string{}, `List of routes that will be approved (comma-separated, e.g. "10.0.0.0/8,192.168.0.0/24" or empty string to remove all approved routes)`)
|
approveRoutesCmd.Flags().StringSliceP("routes", "r", []string{}, `List of routes that will be approved (comma-separated, e.g. "10.0.0.0/8,192.168.0.0/24" or empty string to remove all approved routes)`)
|
||||||
nodeCmd.AddCommand(approveRoutesCmd)
|
nodeCmd.AddCommand(approveRoutesCmd)
|
||||||
|
|
||||||
@@ -95,6 +105,7 @@ var registerNodeCmd = &cobra.Command{
|
|||||||
Short: "Registers a node to your network",
|
Short: "Registers a node to your network",
|
||||||
Run: func(cmd *cobra.Command, args []string) {
|
Run: func(cmd *cobra.Command, args []string) {
|
||||||
output, _ := cmd.Flags().GetString("output")
|
output, _ := cmd.Flags().GetString("output")
|
||||||
|
|
||||||
user, err := cmd.Flags().GetString("user")
|
user, err := cmd.Flags().GetString("user")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output)
|
ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output)
|
||||||
@@ -142,6 +153,7 @@ var listNodesCmd = &cobra.Command{
|
|||||||
Aliases: []string{"ls", "show"},
|
Aliases: []string{"ls", "show"},
|
||||||
Run: func(cmd *cobra.Command, args []string) {
|
Run: func(cmd *cobra.Command, args []string) {
|
||||||
output, _ := cmd.Flags().GetString("output")
|
output, _ := cmd.Flags().GetString("output")
|
||||||
|
|
||||||
user, err := cmd.Flags().GetString("user")
|
user, err := cmd.Flags().GetString("user")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output)
|
ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output)
|
||||||
@@ -190,6 +202,7 @@ var listNodeRoutesCmd = &cobra.Command{
|
|||||||
Aliases: []string{"lsr", "routes"},
|
Aliases: []string{"lsr", "routes"},
|
||||||
Run: func(cmd *cobra.Command, args []string) {
|
Run: func(cmd *cobra.Command, args []string) {
|
||||||
output, _ := cmd.Flags().GetString("output")
|
output, _ := cmd.Flags().GetString("output")
|
||||||
|
|
||||||
identifier, err := cmd.Flags().GetUint64("identifier")
|
identifier, err := cmd.Flags().GetUint64("identifier")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorOutput(
|
ErrorOutput(
|
||||||
@@ -233,10 +246,7 @@ var listNodeRoutesCmd = &cobra.Command{
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
tableData, err := nodeRoutesToPtables(nodes)
|
tableData := nodeRoutesToPtables(nodes)
|
||||||
if err != nil {
|
|
||||||
ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
|
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -276,7 +286,9 @@ var expireNodeCmd = &cobra.Command{
|
|||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
|
||||||
expiryTime := now
|
expiryTime := now
|
||||||
if expiry != "" {
|
if expiry != "" {
|
||||||
expiryTime, err = time.Parse(time.RFC3339, expiry)
|
expiryTime, err = time.Parse(time.RFC3339, expiry)
|
||||||
@@ -343,6 +355,7 @@ var renameNodeCmd = &cobra.Command{
|
|||||||
if len(args) > 0 {
|
if len(args) > 0 {
|
||||||
newName = args[0]
|
newName = args[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
request := &v1.RenameNodeRequest{
|
request := &v1.RenameNodeRequest{
|
||||||
NodeId: identifier,
|
NodeId: identifier,
|
||||||
NewName: newName,
|
NewName: newName,
|
||||||
@@ -402,6 +415,7 @@ var deleteNodeCmd = &cobra.Command{
|
|||||||
}
|
}
|
||||||
|
|
||||||
confirm := false
|
confirm := false
|
||||||
|
|
||||||
force, _ := cmd.Flags().GetBool("force")
|
force, _ := cmd.Flags().GetBool("force")
|
||||||
if !force {
|
if !force {
|
||||||
confirm = util.YesNo(fmt.Sprintf(
|
confirm = util.YesNo(fmt.Sprintf(
|
||||||
@@ -417,6 +431,7 @@ var deleteNodeCmd = &cobra.Command{
|
|||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorOutput(
|
ErrorOutput(
|
||||||
err,
|
err,
|
||||||
@@ -424,6 +439,7 @@ var deleteNodeCmd = &cobra.Command{
|
|||||||
output,
|
output,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
SuccessOutput(
|
SuccessOutput(
|
||||||
map[string]string{"Result": "Node deleted"},
|
map[string]string{"Result": "Node deleted"},
|
||||||
"Node deleted",
|
"Node deleted",
|
||||||
@@ -506,15 +522,21 @@ func nodesToPtables(
|
|||||||
ephemeral = true
|
ephemeral = true
|
||||||
}
|
}
|
||||||
|
|
||||||
var lastSeen time.Time
|
var (
|
||||||
var lastSeenTime string
|
lastSeen time.Time
|
||||||
|
lastSeenTime string
|
||||||
|
)
|
||||||
|
|
||||||
if node.GetLastSeen() != nil {
|
if node.GetLastSeen() != nil {
|
||||||
lastSeen = node.GetLastSeen().AsTime()
|
lastSeen = node.GetLastSeen().AsTime()
|
||||||
lastSeenTime = lastSeen.Format("2006-01-02 15:04:05")
|
lastSeenTime = lastSeen.Format("2006-01-02 15:04:05")
|
||||||
}
|
}
|
||||||
|
|
||||||
var expiry time.Time
|
var (
|
||||||
var expiryTime string
|
expiry time.Time
|
||||||
|
expiryTime string
|
||||||
|
)
|
||||||
|
|
||||||
if node.GetExpiry() != nil {
|
if node.GetExpiry() != nil {
|
||||||
expiry = node.GetExpiry().AsTime()
|
expiry = node.GetExpiry().AsTime()
|
||||||
expiryTime = expiry.Format("2006-01-02 15:04:05")
|
expiryTime = expiry.Format("2006-01-02 15:04:05")
|
||||||
@@ -523,6 +545,7 @@ func nodesToPtables(
|
|||||||
}
|
}
|
||||||
|
|
||||||
var machineKey key.MachinePublic
|
var machineKey key.MachinePublic
|
||||||
|
|
||||||
err := machineKey.UnmarshalText(
|
err := machineKey.UnmarshalText(
|
||||||
[]byte(node.GetMachineKey()),
|
[]byte(node.GetMachineKey()),
|
||||||
)
|
)
|
||||||
@@ -531,6 +554,7 @@ func nodesToPtables(
|
|||||||
}
|
}
|
||||||
|
|
||||||
var nodeKey key.NodePublic
|
var nodeKey key.NodePublic
|
||||||
|
|
||||||
err = nodeKey.UnmarshalText(
|
err = nodeKey.UnmarshalText(
|
||||||
[]byte(node.GetNodeKey()),
|
[]byte(node.GetNodeKey()),
|
||||||
)
|
)
|
||||||
@@ -572,8 +596,11 @@ func nodesToPtables(
|
|||||||
user = pterm.LightYellow(node.GetUser().GetName())
|
user = pterm.LightYellow(node.GetUser().GetName())
|
||||||
}
|
}
|
||||||
|
|
||||||
var IPV4Address string
|
var (
|
||||||
var IPV6Address string
|
IPV4Address string
|
||||||
|
IPV6Address string
|
||||||
|
)
|
||||||
|
|
||||||
for _, addr := range node.GetIpAddresses() {
|
for _, addr := range node.GetIpAddresses() {
|
||||||
if netip.MustParseAddr(addr).Is4() {
|
if netip.MustParseAddr(addr).Is4() {
|
||||||
IPV4Address = addr
|
IPV4Address = addr
|
||||||
@@ -608,7 +635,7 @@ func nodesToPtables(
|
|||||||
|
|
||||||
func nodeRoutesToPtables(
|
func nodeRoutesToPtables(
|
||||||
nodes []*v1.Node,
|
nodes []*v1.Node,
|
||||||
) (pterm.TableData, error) {
|
) pterm.TableData {
|
||||||
tableHeader := []string{
|
tableHeader := []string{
|
||||||
"ID",
|
"ID",
|
||||||
"Hostname",
|
"Hostname",
|
||||||
@@ -632,7 +659,7 @@ func nodeRoutesToPtables(
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
return tableData, nil
|
return tableData
|
||||||
}
|
}
|
||||||
|
|
||||||
var tagCmd = &cobra.Command{
|
var tagCmd = &cobra.Command{
|
||||||
@@ -641,6 +668,7 @@ var tagCmd = &cobra.Command{
|
|||||||
Aliases: []string{"tags", "t"},
|
Aliases: []string{"tags", "t"},
|
||||||
Run: func(cmd *cobra.Command, args []string) {
|
Run: func(cmd *cobra.Command, args []string) {
|
||||||
output, _ := cmd.Flags().GetString("output")
|
output, _ := cmd.Flags().GetString("output")
|
||||||
|
|
||||||
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
|
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
|
||||||
defer cancel()
|
defer cancel()
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
@@ -654,6 +682,7 @@ var tagCmd = &cobra.Command{
|
|||||||
output,
|
output,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
tagsToSet, err := cmd.Flags().GetStringSlice("tags")
|
tagsToSet, err := cmd.Flags().GetStringSlice("tags")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorOutput(
|
ErrorOutput(
|
||||||
@@ -668,6 +697,7 @@ var tagCmd = &cobra.Command{
|
|||||||
NodeId: identifier,
|
NodeId: identifier,
|
||||||
Tags: tagsToSet,
|
Tags: tagsToSet,
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := client.SetTags(ctx, request)
|
resp, err := client.SetTags(ctx, request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorOutput(
|
ErrorOutput(
|
||||||
@@ -692,6 +722,7 @@ var approveRoutesCmd = &cobra.Command{
|
|||||||
Short: "Manage the approved routes of a node",
|
Short: "Manage the approved routes of a node",
|
||||||
Run: func(cmd *cobra.Command, args []string) {
|
Run: func(cmd *cobra.Command, args []string) {
|
||||||
output, _ := cmd.Flags().GetString("output")
|
output, _ := cmd.Flags().GetString("output")
|
||||||
|
|
||||||
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
|
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
|
||||||
defer cancel()
|
defer cancel()
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
@@ -705,6 +736,7 @@ var approveRoutesCmd = &cobra.Command{
|
|||||||
output,
|
output,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
routes, err := cmd.Flags().GetStringSlice("routes")
|
routes, err := cmd.Flags().GetStringSlice("routes")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorOutput(
|
ErrorOutput(
|
||||||
@@ -719,6 +751,7 @@ var approveRoutesCmd = &cobra.Command{
|
|||||||
NodeId: identifier,
|
NodeId: identifier,
|
||||||
Routes: routes,
|
Routes: routes,
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := client.SetApprovedRoutes(ctx, request)
|
resp, err := client.SetApprovedRoutes(ctx, request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorOutput(
|
ErrorOutput(
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
bypassFlag = "bypass-grpc-and-access-database-directly"
|
bypassFlag = "bypass-grpc-and-access-database-directly" //nolint:gosec // not a credential
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
@@ -26,16 +26,22 @@ func init() {
|
|||||||
policyCmd.AddCommand(getPolicy)
|
policyCmd.AddCommand(getPolicy)
|
||||||
|
|
||||||
setPolicy.Flags().StringP("file", "f", "", "Path to a policy file in HuJSON format")
|
setPolicy.Flags().StringP("file", "f", "", "Path to a policy file in HuJSON format")
|
||||||
if err := setPolicy.MarkFlagRequired("file"); err != nil {
|
|
||||||
|
err := setPolicy.MarkFlagRequired("file")
|
||||||
|
if err != nil {
|
||||||
log.Fatal().Err(err).Msg("")
|
log.Fatal().Err(err).Msg("")
|
||||||
}
|
}
|
||||||
|
|
||||||
setPolicy.Flags().BoolP(bypassFlag, "", false, "Uses the headscale config to directly access the database, bypassing gRPC and does not require the server to be running")
|
setPolicy.Flags().BoolP(bypassFlag, "", false, "Uses the headscale config to directly access the database, bypassing gRPC and does not require the server to be running")
|
||||||
policyCmd.AddCommand(setPolicy)
|
policyCmd.AddCommand(setPolicy)
|
||||||
|
|
||||||
checkPolicy.Flags().StringP("file", "f", "", "Path to a policy file in HuJSON format")
|
checkPolicy.Flags().StringP("file", "f", "", "Path to a policy file in HuJSON format")
|
||||||
if err := checkPolicy.MarkFlagRequired("file"); err != nil {
|
|
||||||
|
err = checkPolicy.MarkFlagRequired("file")
|
||||||
|
if err != nil {
|
||||||
log.Fatal().Err(err).Msg("")
|
log.Fatal().Err(err).Msg("")
|
||||||
}
|
}
|
||||||
|
|
||||||
policyCmd.AddCommand(checkPolicy)
|
policyCmd.AddCommand(checkPolicy)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -50,9 +56,12 @@ var getPolicy = &cobra.Command{
|
|||||||
Aliases: []string{"show", "view", "fetch"},
|
Aliases: []string{"show", "view", "fetch"},
|
||||||
Run: func(cmd *cobra.Command, args []string) {
|
Run: func(cmd *cobra.Command, args []string) {
|
||||||
output, _ := cmd.Flags().GetString("output")
|
output, _ := cmd.Flags().GetString("output")
|
||||||
|
|
||||||
var policy string
|
var policy string
|
||||||
|
|
||||||
if bypass, _ := cmd.Flags().GetBool(bypassFlag); bypass {
|
if bypass, _ := cmd.Flags().GetBool(bypassFlag); bypass {
|
||||||
confirm := false
|
confirm := false
|
||||||
|
|
||||||
force, _ := cmd.Flags().GetBool("force")
|
force, _ := cmd.Flags().GetBool("force")
|
||||||
if !force {
|
if !force {
|
||||||
confirm = util.YesNo("DO NOT run this command if an instance of headscale is running, are you sure headscale is not running?")
|
confirm = util.YesNo("DO NOT run this command if an instance of headscale is running, are you sure headscale is not running?")
|
||||||
@@ -128,6 +137,7 @@ var setPolicy = &cobra.Command{
|
|||||||
|
|
||||||
if bypass, _ := cmd.Flags().GetBool(bypassFlag); bypass {
|
if bypass, _ := cmd.Flags().GetBool(bypassFlag); bypass {
|
||||||
confirm := false
|
confirm := false
|
||||||
|
|
||||||
force, _ := cmd.Flags().GetBool("force")
|
force, _ := cmd.Flags().GetBool("force")
|
||||||
if !force {
|
if !force {
|
||||||
confirm = util.YesNo("DO NOT run this command if an instance of headscale is running, are you sure headscale is not running?")
|
confirm = util.YesNo("DO NOT run this command if an instance of headscale is running, are you sure headscale is not running?")
|
||||||
@@ -173,7 +183,7 @@ var setPolicy = &cobra.Command{
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
if _, err := client.SetPolicy(ctx, request); err != nil {
|
if _, err := client.SetPolicy(ctx, request); err != nil { //nolint:noinlineerr
|
||||||
ErrorOutput(err, fmt.Sprintf("Failed to set ACL Policy: %s", err), output)
|
ErrorOutput(err, fmt.Sprintf("Failed to set ACL Policy: %s", err), output)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -80,6 +80,7 @@ var listPreAuthKeys = &cobra.Command{
|
|||||||
"Owner",
|
"Owner",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, key := range response.GetPreAuthKeys() {
|
for _, key := range response.GetPreAuthKeys() {
|
||||||
expiration := "-"
|
expiration := "-"
|
||||||
if key.GetExpiration() != nil {
|
if key.GetExpiration() != nil {
|
||||||
@@ -105,8 +106,8 @@ var listPreAuthKeys = &cobra.Command{
|
|||||||
key.GetCreatedAt().AsTime().Format("2006-01-02 15:04:05"),
|
key.GetCreatedAt().AsTime().Format("2006-01-02 15:04:05"),
|
||||||
owner,
|
owner,
|
||||||
})
|
})
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
|
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorOutput(
|
ErrorOutput(
|
||||||
|
|||||||
@@ -45,15 +45,16 @@ func initConfig() {
|
|||||||
if cfgFile == "" {
|
if cfgFile == "" {
|
||||||
cfgFile = os.Getenv("HEADSCALE_CONFIG")
|
cfgFile = os.Getenv("HEADSCALE_CONFIG")
|
||||||
}
|
}
|
||||||
|
|
||||||
if cfgFile != "" {
|
if cfgFile != "" {
|
||||||
err := types.LoadConfig(cfgFile, true)
|
err := types.LoadConfig(cfgFile, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal().Caller().Err(err).Msgf("Error loading config file %s", cfgFile)
|
log.Fatal().Caller().Err(err).Msgf("error loading config file %s", cfgFile)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
err := types.LoadConfig("", false)
|
err := types.LoadConfig("", false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal().Caller().Err(err).Msgf("Error loading config")
|
log.Fatal().Caller().Err(err).Msgf("error loading config")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -80,6 +81,7 @@ func initConfig() {
|
|||||||
Repository: "headscale",
|
Repository: "headscale",
|
||||||
TagFilterFunc: filterPreReleasesIfStable(func() string { return versionInfo.Version }),
|
TagFilterFunc: filterPreReleasesIfStable(func() string { return versionInfo.Version }),
|
||||||
}
|
}
|
||||||
|
|
||||||
res, err := latest.Check(githubTag, versionInfo.Version)
|
res, err := latest.Check(githubTag, versionInfo.Version)
|
||||||
if err == nil && res.Outdated {
|
if err == nil && res.Outdated {
|
||||||
//nolint
|
//nolint
|
||||||
@@ -101,6 +103,7 @@ func isPreReleaseVersion(version string) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -140,7 +143,8 @@ https://github.com/juanfont/headscale`,
|
|||||||
}
|
}
|
||||||
|
|
||||||
func Execute() {
|
func Execute() {
|
||||||
if err := rootCmd.Execute(); err != nil {
|
err := rootCmd.Execute()
|
||||||
|
if err != nil {
|
||||||
fmt.Fprintln(os.Stderr, err)
|
fmt.Fprintln(os.Stderr, err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -23,18 +23,17 @@ var serveCmd = &cobra.Command{
|
|||||||
Run: func(cmd *cobra.Command, args []string) {
|
Run: func(cmd *cobra.Command, args []string) {
|
||||||
app, err := newHeadscaleServerWithConfig()
|
app, err := newHeadscaleServerWithConfig()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
var squibbleErr squibble.ValidationError
|
if squibbleErr, ok := errors.AsType[squibble.ValidationError](err); ok {
|
||||||
if errors.As(err, &squibbleErr) {
|
|
||||||
fmt.Printf("SQLite schema failed to validate:\n")
|
fmt.Printf("SQLite schema failed to validate:\n")
|
||||||
fmt.Println(squibbleErr.Diff)
|
fmt.Println(squibbleErr.Diff)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Fatal().Caller().Err(err).Msg("Error initializing")
|
log.Fatal().Caller().Err(err).Msg("error initializing")
|
||||||
}
|
}
|
||||||
|
|
||||||
err = app.Serve()
|
err = app.Serve()
|
||||||
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||||
log.Fatal().Caller().Err(err).Msg("Headscale ran into an error and had to shut down.")
|
log.Fatal().Caller().Err(err).Msg("headscale ran into an error and had to shut down")
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,12 +8,19 @@ import (
|
|||||||
|
|
||||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||||
"github.com/juanfont/headscale/hscontrol/util"
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/util/zlog/zf"
|
||||||
"github.com/pterm/pterm"
|
"github.com/pterm/pterm"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// CLI user errors.
|
||||||
|
var (
|
||||||
|
errFlagRequired = errors.New("--name or --identifier flag is required")
|
||||||
|
errMultipleUsersMatch = errors.New("multiple users match query, specify an ID")
|
||||||
|
)
|
||||||
|
|
||||||
func usernameAndIDFlag(cmd *cobra.Command) {
|
func usernameAndIDFlag(cmd *cobra.Command) {
|
||||||
cmd.Flags().Int64P("identifier", "i", -1, "User identifier (ID)")
|
cmd.Flags().Int64P("identifier", "i", -1, "User identifier (ID)")
|
||||||
cmd.Flags().StringP("name", "n", "", "Username")
|
cmd.Flags().StringP("name", "n", "", "Username")
|
||||||
@@ -23,12 +30,12 @@ func usernameAndIDFlag(cmd *cobra.Command) {
|
|||||||
// If both are empty, it will exit the program with an error.
|
// If both are empty, it will exit the program with an error.
|
||||||
func usernameAndIDFromFlag(cmd *cobra.Command) (uint64, string) {
|
func usernameAndIDFromFlag(cmd *cobra.Command) (uint64, string) {
|
||||||
username, _ := cmd.Flags().GetString("name")
|
username, _ := cmd.Flags().GetString("name")
|
||||||
|
|
||||||
identifier, _ := cmd.Flags().GetInt64("identifier")
|
identifier, _ := cmd.Flags().GetInt64("identifier")
|
||||||
if username == "" && identifier < 0 {
|
if username == "" && identifier < 0 {
|
||||||
err := errors.New("--name or --identifier flag is required")
|
|
||||||
ErrorOutput(
|
ErrorOutput(
|
||||||
err,
|
errFlagRequired,
|
||||||
"Cannot rename user: "+status.Convert(err).Message(),
|
"Cannot rename user: "+status.Convert(errFlagRequired).Message(),
|
||||||
"",
|
"",
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@@ -50,7 +57,8 @@ func init() {
|
|||||||
userCmd.AddCommand(renameUserCmd)
|
userCmd.AddCommand(renameUserCmd)
|
||||||
usernameAndIDFlag(renameUserCmd)
|
usernameAndIDFlag(renameUserCmd)
|
||||||
renameUserCmd.Flags().StringP("new-name", "r", "", "New username")
|
renameUserCmd.Flags().StringP("new-name", "r", "", "New username")
|
||||||
renameNodeCmd.MarkFlagRequired("new-name")
|
|
||||||
|
_ = renameNodeCmd.MarkFlagRequired("new-name")
|
||||||
}
|
}
|
||||||
|
|
||||||
var errMissingParameter = errors.New("missing parameters")
|
var errMissingParameter = errors.New("missing parameters")
|
||||||
@@ -81,7 +89,7 @@ var createUserCmd = &cobra.Command{
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
log.Trace().Interface("client", client).Msg("Obtained gRPC client")
|
log.Trace().Interface(zf.Client, client).Msg("obtained gRPC client")
|
||||||
|
|
||||||
request := &v1.CreateUserRequest{Name: userName}
|
request := &v1.CreateUserRequest{Name: userName}
|
||||||
|
|
||||||
@@ -94,7 +102,7 @@ var createUserCmd = &cobra.Command{
|
|||||||
}
|
}
|
||||||
|
|
||||||
if pictureURL, _ := cmd.Flags().GetString("picture-url"); pictureURL != "" {
|
if pictureURL, _ := cmd.Flags().GetString("picture-url"); pictureURL != "" {
|
||||||
if _, err := url.Parse(pictureURL); err != nil {
|
if _, err := url.Parse(pictureURL); err != nil { //nolint:noinlineerr
|
||||||
ErrorOutput(
|
ErrorOutput(
|
||||||
err,
|
err,
|
||||||
fmt.Sprintf(
|
fmt.Sprintf(
|
||||||
@@ -104,10 +112,12 @@ var createUserCmd = &cobra.Command{
|
|||||||
output,
|
output,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
request.PictureUrl = pictureURL
|
request.PictureUrl = pictureURL
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Trace().Interface("request", request).Msg("Sending CreateUser request")
|
log.Trace().Interface(zf.Request, request).Msg("sending CreateUser request")
|
||||||
|
|
||||||
response, err := client.CreateUser(ctx, request)
|
response, err := client.CreateUser(ctx, request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorOutput(
|
ErrorOutput(
|
||||||
@@ -148,7 +158,7 @@ var destroyUserCmd = &cobra.Command{
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(users.GetUsers()) != 1 {
|
if len(users.GetUsers()) != 1 {
|
||||||
err := errors.New("Unable to determine user to delete, query returned multiple users, use ID")
|
err := errMultipleUsersMatch
|
||||||
ErrorOutput(
|
ErrorOutput(
|
||||||
err,
|
err,
|
||||||
"Error: "+status.Convert(err).Message(),
|
"Error: "+status.Convert(err).Message(),
|
||||||
@@ -159,6 +169,7 @@ var destroyUserCmd = &cobra.Command{
|
|||||||
user := users.GetUsers()[0]
|
user := users.GetUsers()[0]
|
||||||
|
|
||||||
confirm := false
|
confirm := false
|
||||||
|
|
||||||
force, _ := cmd.Flags().GetBool("force")
|
force, _ := cmd.Flags().GetBool("force")
|
||||||
if !force {
|
if !force {
|
||||||
confirm = util.YesNo(fmt.Sprintf(
|
confirm = util.YesNo(fmt.Sprintf(
|
||||||
@@ -178,6 +189,7 @@ var destroyUserCmd = &cobra.Command{
|
|||||||
output,
|
output,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
SuccessOutput(response, "User destroyed", output)
|
SuccessOutput(response, "User destroyed", output)
|
||||||
} else {
|
} else {
|
||||||
SuccessOutput(map[string]string{"Result": "User not destroyed"}, "User not destroyed", output)
|
SuccessOutput(map[string]string{"Result": "User not destroyed"}, "User not destroyed", output)
|
||||||
@@ -238,6 +250,7 @@ var listUsersCmd = &cobra.Command{
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
|
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ErrorOutput(
|
ErrorOutput(
|
||||||
@@ -276,7 +289,7 @@ var renameUserCmd = &cobra.Command{
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(users.GetUsers()) != 1 {
|
if len(users.GetUsers()) != 1 {
|
||||||
err := errors.New("Unable to determine user to delete, query returned multiple users, use ID")
|
err := errMultipleUsersMatch
|
||||||
ErrorOutput(
|
ErrorOutput(
|
||||||
err,
|
err,
|
||||||
"Error: "+status.Convert(err).Message(),
|
"Error: "+status.Convert(err).Message(),
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
"github.com/juanfont/headscale/hscontrol"
|
"github.com/juanfont/headscale/hscontrol"
|
||||||
"github.com/juanfont/headscale/hscontrol/types"
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
"github.com/juanfont/headscale/hscontrol/util"
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/util/zlog/zf"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/credentials"
|
"google.golang.org/grpc/credentials"
|
||||||
@@ -57,7 +58,7 @@ func newHeadscaleCLIWithConfig() (context.Context, v1.HeadscaleServiceClient, *g
|
|||||||
ctx, cancel := context.WithTimeout(context.Background(), cfg.CLI.Timeout)
|
ctx, cancel := context.WithTimeout(context.Background(), cfg.CLI.Timeout)
|
||||||
|
|
||||||
grpcOptions := []grpc.DialOption{
|
grpcOptions := []grpc.DialOption{
|
||||||
grpc.WithBlock(),
|
grpc.WithBlock(), //nolint:staticcheck // SA1019: deprecated but supported in 1.x
|
||||||
}
|
}
|
||||||
|
|
||||||
address := cfg.CLI.Address
|
address := cfg.CLI.Address
|
||||||
@@ -81,6 +82,7 @@ func newHeadscaleCLIWithConfig() (context.Context, v1.HeadscaleServiceClient, *g
|
|||||||
Msgf("Unable to read/write to headscale socket, do you have the correct permissions?")
|
Msgf("Unable to read/write to headscale socket, do you have the correct permissions?")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
socket.Close()
|
socket.Close()
|
||||||
|
|
||||||
grpcOptions = append(
|
grpcOptions = append(
|
||||||
@@ -92,8 +94,9 @@ func newHeadscaleCLIWithConfig() (context.Context, v1.HeadscaleServiceClient, *g
|
|||||||
// If we are not connecting to a local server, require an API key for authentication
|
// If we are not connecting to a local server, require an API key for authentication
|
||||||
apiKey := cfg.CLI.APIKey
|
apiKey := cfg.CLI.APIKey
|
||||||
if apiKey == "" {
|
if apiKey == "" {
|
||||||
log.Fatal().Caller().Msgf("HEADSCALE_CLI_API_KEY environment variable needs to be set.")
|
log.Fatal().Caller().Msgf("HEADSCALE_CLI_API_KEY environment variable needs to be set")
|
||||||
}
|
}
|
||||||
|
|
||||||
grpcOptions = append(grpcOptions,
|
grpcOptions = append(grpcOptions,
|
||||||
grpc.WithPerRPCCredentials(tokenAuth{
|
grpc.WithPerRPCCredentials(tokenAuth{
|
||||||
token: apiKey,
|
token: apiKey,
|
||||||
@@ -118,10 +121,11 @@ func newHeadscaleCLIWithConfig() (context.Context, v1.HeadscaleServiceClient, *g
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Trace().Caller().Str("address", address).Msg("Connecting via gRPC")
|
log.Trace().Caller().Str(zf.Address, address).Msg("connecting via gRPC")
|
||||||
conn, err := grpc.DialContext(ctx, address, grpcOptions...)
|
|
||||||
|
conn, err := grpc.DialContext(ctx, address, grpcOptions...) //nolint:staticcheck // SA1019: deprecated but supported in 1.x
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal().Caller().Err(err).Msgf("Could not connect: %v", err)
|
log.Fatal().Caller().Err(err).Msgf("could not connect: %v", err)
|
||||||
os.Exit(-1) // we get here if logging is suppressed (i.e., json output)
|
os.Exit(-1) // we get here if logging is suppressed (i.e., json output)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -131,23 +135,26 @@ func newHeadscaleCLIWithConfig() (context.Context, v1.HeadscaleServiceClient, *g
|
|||||||
}
|
}
|
||||||
|
|
||||||
func output(result any, override string, outputFormat string) string {
|
func output(result any, override string, outputFormat string) string {
|
||||||
var jsonBytes []byte
|
var (
|
||||||
var err error
|
jsonBytes []byte
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
switch outputFormat {
|
switch outputFormat {
|
||||||
case "json":
|
case "json":
|
||||||
jsonBytes, err = json.MarshalIndent(result, "", "\t")
|
jsonBytes, err = json.MarshalIndent(result, "", "\t")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal().Err(err).Msg("failed to unmarshal output")
|
log.Fatal().Err(err).Msg("unmarshalling output")
|
||||||
}
|
}
|
||||||
case "json-line":
|
case "json-line":
|
||||||
jsonBytes, err = json.Marshal(result)
|
jsonBytes, err = json.Marshal(result)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal().Err(err).Msg("failed to unmarshal output")
|
log.Fatal().Err(err).Msg("unmarshalling output")
|
||||||
}
|
}
|
||||||
case "yaml":
|
case "yaml":
|
||||||
jsonBytes, err = yaml.Marshal(result)
|
jsonBytes, err = yaml.Marshal(result)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal().Err(err).Msg("failed to unmarshal output")
|
log.Fatal().Err(err).Msg("unmarshalling output")
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
// nolint
|
// nolint
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
var colors bool
|
var colors bool
|
||||||
|
|
||||||
switch l := termcolor.SupportLevel(os.Stderr); l {
|
switch l := termcolor.SupportLevel(os.Stderr); l {
|
||||||
case termcolor.Level16M:
|
case termcolor.Level16M:
|
||||||
colors = true
|
colors = true
|
||||||
|
|||||||
@@ -14,9 +14,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestConfigFileLoading(t *testing.T) {
|
func TestConfigFileLoading(t *testing.T) {
|
||||||
tmpDir, err := os.MkdirTemp("", "headscale")
|
tmpDir := t.TempDir()
|
||||||
require.NoError(t, err)
|
|
||||||
defer os.RemoveAll(tmpDir)
|
|
||||||
|
|
||||||
path, err := os.Getwd()
|
path, err := os.Getwd()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -48,9 +46,7 @@ func TestConfigFileLoading(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestConfigLoading(t *testing.T) {
|
func TestConfigLoading(t *testing.T) {
|
||||||
tmpDir, err := os.MkdirTemp("", "headscale")
|
tmpDir := t.TempDir()
|
||||||
require.NoError(t, err)
|
|
||||||
defer os.RemoveAll(tmpDir)
|
|
||||||
|
|
||||||
path, err := os.Getwd()
|
path, err := os.Getwd()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|||||||
@@ -22,11 +22,11 @@ import (
|
|||||||
func cleanupBeforeTest(ctx context.Context) error {
|
func cleanupBeforeTest(ctx context.Context) error {
|
||||||
err := cleanupStaleTestContainers(ctx)
|
err := cleanupStaleTestContainers(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to clean stale test containers: %w", err)
|
return fmt.Errorf("cleaning stale test containers: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := pruneDockerNetworks(ctx); err != nil {
|
if err := pruneDockerNetworks(ctx); err != nil { //nolint:noinlineerr
|
||||||
return fmt.Errorf("failed to prune networks: %w", err)
|
return fmt.Errorf("pruning networks: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -39,14 +39,14 @@ func cleanupAfterTest(ctx context.Context, cli *client.Client, containerID, runI
|
|||||||
Force: true,
|
Force: true,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to remove test container: %w", err)
|
return fmt.Errorf("removing test container: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Clean up integration test containers for this run only
|
// Clean up integration test containers for this run only
|
||||||
if runID != "" {
|
if runID != "" {
|
||||||
err := killTestContainersByRunID(ctx, runID)
|
err := killTestContainersByRunID(ctx, runID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to clean up containers for run %s: %w", runID, err)
|
return fmt.Errorf("cleaning up containers for run %s: %w", runID, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -55,9 +55,9 @@ func cleanupAfterTest(ctx context.Context, cli *client.Client, containerID, runI
|
|||||||
|
|
||||||
// killTestContainers terminates and removes all test containers.
|
// killTestContainers terminates and removes all test containers.
|
||||||
func killTestContainers(ctx context.Context) error {
|
func killTestContainers(ctx context.Context) error {
|
||||||
cli, err := createDockerClient()
|
cli, err := createDockerClient(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create Docker client: %w", err)
|
return fmt.Errorf("creating Docker client: %w", err)
|
||||||
}
|
}
|
||||||
defer cli.Close()
|
defer cli.Close()
|
||||||
|
|
||||||
@@ -65,12 +65,14 @@ func killTestContainers(ctx context.Context) error {
|
|||||||
All: true,
|
All: true,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to list containers: %w", err)
|
return fmt.Errorf("listing containers: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
removed := 0
|
removed := 0
|
||||||
|
|
||||||
for _, cont := range containers {
|
for _, cont := range containers {
|
||||||
shouldRemove := false
|
shouldRemove := false
|
||||||
|
|
||||||
for _, name := range cont.Names {
|
for _, name := range cont.Names {
|
||||||
if strings.Contains(name, "headscale-test-suite") ||
|
if strings.Contains(name, "headscale-test-suite") ||
|
||||||
strings.Contains(name, "hs-") ||
|
strings.Contains(name, "hs-") ||
|
||||||
@@ -107,9 +109,9 @@ func killTestContainers(ctx context.Context) error {
|
|||||||
// This function filters containers by the hi.run-id label to only affect containers
|
// This function filters containers by the hi.run-id label to only affect containers
|
||||||
// belonging to the specified test run, leaving other concurrent test runs untouched.
|
// belonging to the specified test run, leaving other concurrent test runs untouched.
|
||||||
func killTestContainersByRunID(ctx context.Context, runID string) error {
|
func killTestContainersByRunID(ctx context.Context, runID string) error {
|
||||||
cli, err := createDockerClient()
|
cli, err := createDockerClient(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create Docker client: %w", err)
|
return fmt.Errorf("creating Docker client: %w", err)
|
||||||
}
|
}
|
||||||
defer cli.Close()
|
defer cli.Close()
|
||||||
|
|
||||||
@@ -121,7 +123,7 @@ func killTestContainersByRunID(ctx context.Context, runID string) error {
|
|||||||
),
|
),
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to list containers for run %s: %w", runID, err)
|
return fmt.Errorf("listing containers for run %s: %w", runID, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
removed := 0
|
removed := 0
|
||||||
@@ -149,9 +151,9 @@ func killTestContainersByRunID(ctx context.Context, runID string) error {
|
|||||||
// This is useful for cleaning up leftover containers from previous crashed or interrupted test runs
|
// This is useful for cleaning up leftover containers from previous crashed or interrupted test runs
|
||||||
// without interfering with currently running concurrent tests.
|
// without interfering with currently running concurrent tests.
|
||||||
func cleanupStaleTestContainers(ctx context.Context) error {
|
func cleanupStaleTestContainers(ctx context.Context) error {
|
||||||
cli, err := createDockerClient()
|
cli, err := createDockerClient(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create Docker client: %w", err)
|
return fmt.Errorf("creating Docker client: %w", err)
|
||||||
}
|
}
|
||||||
defer cli.Close()
|
defer cli.Close()
|
||||||
|
|
||||||
@@ -164,7 +166,7 @@ func cleanupStaleTestContainers(ctx context.Context) error {
|
|||||||
),
|
),
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to list stopped containers: %w", err)
|
return fmt.Errorf("listing stopped containers: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
removed := 0
|
removed := 0
|
||||||
@@ -223,15 +225,15 @@ func removeContainerWithRetry(ctx context.Context, cli *client.Client, container
|
|||||||
|
|
||||||
// pruneDockerNetworks removes unused Docker networks.
|
// pruneDockerNetworks removes unused Docker networks.
|
||||||
func pruneDockerNetworks(ctx context.Context) error {
|
func pruneDockerNetworks(ctx context.Context) error {
|
||||||
cli, err := createDockerClient()
|
cli, err := createDockerClient(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create Docker client: %w", err)
|
return fmt.Errorf("creating Docker client: %w", err)
|
||||||
}
|
}
|
||||||
defer cli.Close()
|
defer cli.Close()
|
||||||
|
|
||||||
report, err := cli.NetworksPrune(ctx, filters.Args{})
|
report, err := cli.NetworksPrune(ctx, filters.Args{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to prune networks: %w", err)
|
return fmt.Errorf("pruning networks: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(report.NetworksDeleted) > 0 {
|
if len(report.NetworksDeleted) > 0 {
|
||||||
@@ -245,9 +247,9 @@ func pruneDockerNetworks(ctx context.Context) error {
|
|||||||
|
|
||||||
// cleanOldImages removes test-related and old dangling Docker images.
|
// cleanOldImages removes test-related and old dangling Docker images.
|
||||||
func cleanOldImages(ctx context.Context) error {
|
func cleanOldImages(ctx context.Context) error {
|
||||||
cli, err := createDockerClient()
|
cli, err := createDockerClient(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create Docker client: %w", err)
|
return fmt.Errorf("creating Docker client: %w", err)
|
||||||
}
|
}
|
||||||
defer cli.Close()
|
defer cli.Close()
|
||||||
|
|
||||||
@@ -255,12 +257,14 @@ func cleanOldImages(ctx context.Context) error {
|
|||||||
All: true,
|
All: true,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to list images: %w", err)
|
return fmt.Errorf("listing images: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
removed := 0
|
removed := 0
|
||||||
|
|
||||||
for _, img := range images {
|
for _, img := range images {
|
||||||
shouldRemove := false
|
shouldRemove := false
|
||||||
|
|
||||||
for _, tag := range img.RepoTags {
|
for _, tag := range img.RepoTags {
|
||||||
if strings.Contains(tag, "hs-") ||
|
if strings.Contains(tag, "hs-") ||
|
||||||
strings.Contains(tag, "headscale-integration") ||
|
strings.Contains(tag, "headscale-integration") ||
|
||||||
@@ -295,18 +299,19 @@ func cleanOldImages(ctx context.Context) error {
|
|||||||
|
|
||||||
// cleanCacheVolume removes the Docker volume used for Go module cache.
|
// cleanCacheVolume removes the Docker volume used for Go module cache.
|
||||||
func cleanCacheVolume(ctx context.Context) error {
|
func cleanCacheVolume(ctx context.Context) error {
|
||||||
cli, err := createDockerClient()
|
cli, err := createDockerClient(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create Docker client: %w", err)
|
return fmt.Errorf("creating Docker client: %w", err)
|
||||||
}
|
}
|
||||||
defer cli.Close()
|
defer cli.Close()
|
||||||
|
|
||||||
volumeName := "hs-integration-go-cache"
|
volumeName := "hs-integration-go-cache"
|
||||||
|
|
||||||
err = cli.VolumeRemove(ctx, volumeName, true)
|
err = cli.VolumeRemove(ctx, volumeName, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errdefs.IsNotFound(err) {
|
if errdefs.IsNotFound(err) { //nolint:staticcheck // SA1019: deprecated but functional
|
||||||
fmt.Printf("Go module cache volume not found: %s\n", volumeName)
|
fmt.Printf("Go module cache volume not found: %s\n", volumeName)
|
||||||
} else if errdefs.IsConflict(err) {
|
} else if errdefs.IsConflict(err) { //nolint:staticcheck // SA1019: deprecated but functional
|
||||||
fmt.Printf("Go module cache volume is in use and cannot be removed: %s\n", volumeName)
|
fmt.Printf("Go module cache volume is in use and cannot be removed: %s\n", volumeName)
|
||||||
} else {
|
} else {
|
||||||
fmt.Printf("Failed to remove Go module cache volume %s: %v\n", volumeName, err)
|
fmt.Printf("Failed to remove Go module cache volume %s: %v\n", volumeName, err)
|
||||||
@@ -330,7 +335,7 @@ func cleanCacheVolume(ctx context.Context) error {
|
|||||||
func cleanupSuccessfulTestArtifacts(logsDir string, verbose bool) error {
|
func cleanupSuccessfulTestArtifacts(logsDir string, verbose bool) error {
|
||||||
entries, err := os.ReadDir(logsDir)
|
entries, err := os.ReadDir(logsDir)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to read logs directory: %w", err)
|
return fmt.Errorf("reading logs directory: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
|||||||
148
cmd/hi/docker.go
148
cmd/hi/docker.go
@@ -22,17 +22,22 @@ import (
|
|||||||
"github.com/juanfont/headscale/integration/dockertestutil"
|
"github.com/juanfont/headscale/integration/dockertestutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const defaultDirPerm = 0o755
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrTestFailed = errors.New("test failed")
|
ErrTestFailed = errors.New("test failed")
|
||||||
ErrUnexpectedContainerWait = errors.New("unexpected end of container wait")
|
ErrUnexpectedContainerWait = errors.New("unexpected end of container wait")
|
||||||
ErrNoDockerContext = errors.New("no docker context found")
|
ErrNoDockerContext = errors.New("no docker context found")
|
||||||
|
ErrMemoryLimitViolations = errors.New("container(s) exceeded memory limits")
|
||||||
)
|
)
|
||||||
|
|
||||||
// runTestContainer executes integration tests in a Docker container.
|
// runTestContainer executes integration tests in a Docker container.
|
||||||
|
//
|
||||||
|
//nolint:gocyclo // complex test orchestration function
|
||||||
func runTestContainer(ctx context.Context, config *RunConfig) error {
|
func runTestContainer(ctx context.Context, config *RunConfig) error {
|
||||||
cli, err := createDockerClient()
|
cli, err := createDockerClient(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create Docker client: %w", err)
|
return fmt.Errorf("creating Docker client: %w", err)
|
||||||
}
|
}
|
||||||
defer cli.Close()
|
defer cli.Close()
|
||||||
|
|
||||||
@@ -48,19 +53,21 @@ func runTestContainer(ctx context.Context, config *RunConfig) error {
|
|||||||
|
|
||||||
absLogsDir, err := filepath.Abs(logsDir)
|
absLogsDir, err := filepath.Abs(logsDir)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get absolute path for logs directory: %w", err)
|
return fmt.Errorf("getting absolute path for logs directory: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
const dirPerm = 0o755
|
const dirPerm = 0o755
|
||||||
if err := os.MkdirAll(absLogsDir, dirPerm); err != nil {
|
if err := os.MkdirAll(absLogsDir, dirPerm); err != nil { //nolint:noinlineerr
|
||||||
return fmt.Errorf("failed to create logs directory: %w", err)
|
return fmt.Errorf("creating logs directory: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.CleanBefore {
|
if config.CleanBefore {
|
||||||
if config.Verbose {
|
if config.Verbose {
|
||||||
log.Printf("Running pre-test cleanup...")
|
log.Printf("Running pre-test cleanup...")
|
||||||
}
|
}
|
||||||
if err := cleanupBeforeTest(ctx); err != nil && config.Verbose {
|
|
||||||
|
err := cleanupBeforeTest(ctx)
|
||||||
|
if err != nil && config.Verbose {
|
||||||
log.Printf("Warning: pre-test cleanup failed: %v", err)
|
log.Printf("Warning: pre-test cleanup failed: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -71,21 +78,21 @@ func runTestContainer(ctx context.Context, config *RunConfig) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
imageName := "golang:" + config.GoVersion
|
imageName := "golang:" + config.GoVersion
|
||||||
if err := ensureImageAvailable(ctx, cli, imageName, config.Verbose); err != nil {
|
if err := ensureImageAvailable(ctx, cli, imageName, config.Verbose); err != nil { //nolint:noinlineerr
|
||||||
return fmt.Errorf("failed to ensure image availability: %w", err)
|
return fmt.Errorf("ensuring image availability: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := createGoTestContainer(ctx, cli, config, containerName, absLogsDir, goTestCmd)
|
resp, err := createGoTestContainer(ctx, cli, config, containerName, absLogsDir, goTestCmd)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create container: %w", err)
|
return fmt.Errorf("creating container: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.Verbose {
|
if config.Verbose {
|
||||||
log.Printf("Created container: %s", resp.ID)
|
log.Printf("Created container: %s", resp.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := cli.ContainerStart(ctx, resp.ID, container.StartOptions{}); err != nil {
|
if err := cli.ContainerStart(ctx, resp.ID, container.StartOptions{}); err != nil { //nolint:noinlineerr
|
||||||
return fmt.Errorf("failed to start container: %w", err)
|
return fmt.Errorf("starting container: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("Starting test: %s", config.TestPattern)
|
log.Printf("Starting test: %s", config.TestPattern)
|
||||||
@@ -95,13 +102,16 @@ func runTestContainer(ctx context.Context, config *RunConfig) error {
|
|||||||
|
|
||||||
// Start stats collection for container resource monitoring (if enabled)
|
// Start stats collection for container resource monitoring (if enabled)
|
||||||
var statsCollector *StatsCollector
|
var statsCollector *StatsCollector
|
||||||
|
|
||||||
if config.Stats {
|
if config.Stats {
|
||||||
var err error
|
var err error
|
||||||
statsCollector, err = NewStatsCollector()
|
|
||||||
|
statsCollector, err = NewStatsCollector(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if config.Verbose {
|
if config.Verbose {
|
||||||
log.Printf("Warning: failed to create stats collector: %v", err)
|
log.Printf("Warning: failed to create stats collector: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
statsCollector = nil
|
statsCollector = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -110,7 +120,8 @@ func runTestContainer(ctx context.Context, config *RunConfig) error {
|
|||||||
|
|
||||||
// Start stats collection immediately - no need for complex retry logic
|
// Start stats collection immediately - no need for complex retry logic
|
||||||
// The new implementation monitors Docker events and will catch containers as they start
|
// The new implementation monitors Docker events and will catch containers as they start
|
||||||
if err := statsCollector.StartCollection(ctx, runID, config.Verbose); err != nil {
|
err := statsCollector.StartCollection(ctx, runID, config.Verbose)
|
||||||
|
if err != nil {
|
||||||
if config.Verbose {
|
if config.Verbose {
|
||||||
log.Printf("Warning: failed to start stats collection: %v", err)
|
log.Printf("Warning: failed to start stats collection: %v", err)
|
||||||
}
|
}
|
||||||
@@ -122,12 +133,13 @@ func runTestContainer(ctx context.Context, config *RunConfig) error {
|
|||||||
exitCode, err := streamAndWait(ctx, cli, resp.ID)
|
exitCode, err := streamAndWait(ctx, cli, resp.ID)
|
||||||
|
|
||||||
// Ensure all containers have finished and logs are flushed before extracting artifacts
|
// Ensure all containers have finished and logs are flushed before extracting artifacts
|
||||||
if waitErr := waitForContainerFinalization(ctx, cli, resp.ID, config.Verbose); waitErr != nil && config.Verbose {
|
waitErr := waitForContainerFinalization(ctx, cli, resp.ID, config.Verbose)
|
||||||
|
if waitErr != nil && config.Verbose {
|
||||||
log.Printf("Warning: failed to wait for container finalization: %v", waitErr)
|
log.Printf("Warning: failed to wait for container finalization: %v", waitErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract artifacts from test containers before cleanup
|
// Extract artifacts from test containers before cleanup
|
||||||
if err := extractArtifactsFromContainers(ctx, resp.ID, logsDir, config.Verbose); err != nil && config.Verbose {
|
if err := extractArtifactsFromContainers(ctx, resp.ID, logsDir, config.Verbose); err != nil && config.Verbose { //nolint:noinlineerr
|
||||||
log.Printf("Warning: failed to extract artifacts from containers: %v", err)
|
log.Printf("Warning: failed to extract artifacts from containers: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -140,12 +152,13 @@ func runTestContainer(ctx context.Context, config *RunConfig) error {
|
|||||||
if len(violations) > 0 {
|
if len(violations) > 0 {
|
||||||
log.Printf("MEMORY LIMIT VIOLATIONS DETECTED:")
|
log.Printf("MEMORY LIMIT VIOLATIONS DETECTED:")
|
||||||
log.Printf("=================================")
|
log.Printf("=================================")
|
||||||
|
|
||||||
for _, violation := range violations {
|
for _, violation := range violations {
|
||||||
log.Printf("Container %s exceeded memory limit: %.1f MB > %.1f MB",
|
log.Printf("Container %s exceeded memory limit: %.1f MB > %.1f MB",
|
||||||
violation.ContainerName, violation.MaxMemoryMB, violation.LimitMB)
|
violation.ContainerName, violation.MaxMemoryMB, violation.LimitMB)
|
||||||
}
|
}
|
||||||
|
|
||||||
return fmt.Errorf("test failed: %d container(s) exceeded memory limits", len(violations))
|
return fmt.Errorf("test failed: %d %w", len(violations), ErrMemoryLimitViolations)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -176,7 +189,7 @@ func runTestContainer(ctx context.Context, config *RunConfig) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("test execution failed: %w", err)
|
return fmt.Errorf("executing test: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if exitCode != 0 {
|
if exitCode != 0 {
|
||||||
@@ -210,7 +223,7 @@ func buildGoTestCommand(config *RunConfig) []string {
|
|||||||
func createGoTestContainer(ctx context.Context, cli *client.Client, config *RunConfig, containerName, logsDir string, goTestCmd []string) (container.CreateResponse, error) {
|
func createGoTestContainer(ctx context.Context, cli *client.Client, config *RunConfig, containerName, logsDir string, goTestCmd []string) (container.CreateResponse, error) {
|
||||||
pwd, err := os.Getwd()
|
pwd, err := os.Getwd()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return container.CreateResponse{}, fmt.Errorf("failed to get working directory: %w", err)
|
return container.CreateResponse{}, fmt.Errorf("getting working directory: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
projectRoot := findProjectRoot(pwd)
|
projectRoot := findProjectRoot(pwd)
|
||||||
@@ -312,7 +325,7 @@ func streamAndWait(ctx context.Context, cli *client.Client, containerID string)
|
|||||||
Follow: true,
|
Follow: true,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return -1, fmt.Errorf("failed to get container logs: %w", err)
|
return -1, fmt.Errorf("getting container logs: %w", err)
|
||||||
}
|
}
|
||||||
defer out.Close()
|
defer out.Close()
|
||||||
|
|
||||||
@@ -324,7 +337,7 @@ func streamAndWait(ctx context.Context, cli *client.Client, containerID string)
|
|||||||
select {
|
select {
|
||||||
case err := <-errCh:
|
case err := <-errCh:
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return -1, fmt.Errorf("error waiting for container: %w", err)
|
return -1, fmt.Errorf("waiting for container: %w", err)
|
||||||
}
|
}
|
||||||
case status := <-statusCh:
|
case status := <-statusCh:
|
||||||
return int(status.StatusCode), nil
|
return int(status.StatusCode), nil
|
||||||
@@ -338,7 +351,7 @@ func waitForContainerFinalization(ctx context.Context, cli *client.Client, testC
|
|||||||
// First, get all related test containers
|
// First, get all related test containers
|
||||||
containers, err := cli.ContainerList(ctx, container.ListOptions{All: true})
|
containers, err := cli.ContainerList(ctx, container.ListOptions{All: true})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to list containers: %w", err)
|
return fmt.Errorf("listing containers: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
testContainers := getCurrentTestContainers(containers, testContainerID, verbose)
|
testContainers := getCurrentTestContainers(containers, testContainerID, verbose)
|
||||||
@@ -347,6 +360,7 @@ func waitForContainerFinalization(ctx context.Context, cli *client.Client, testC
|
|||||||
maxWaitTime := 10 * time.Second
|
maxWaitTime := 10 * time.Second
|
||||||
checkInterval := 500 * time.Millisecond
|
checkInterval := 500 * time.Millisecond
|
||||||
timeout := time.After(maxWaitTime)
|
timeout := time.After(maxWaitTime)
|
||||||
|
|
||||||
ticker := time.NewTicker(checkInterval)
|
ticker := time.NewTicker(checkInterval)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|
||||||
@@ -356,6 +370,7 @@ func waitForContainerFinalization(ctx context.Context, cli *client.Client, testC
|
|||||||
if verbose {
|
if verbose {
|
||||||
log.Printf("Timeout waiting for container finalization, proceeding with artifact extraction")
|
log.Printf("Timeout waiting for container finalization, proceeding with artifact extraction")
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
allFinalized := true
|
allFinalized := true
|
||||||
@@ -366,12 +381,14 @@ func waitForContainerFinalization(ctx context.Context, cli *client.Client, testC
|
|||||||
if verbose {
|
if verbose {
|
||||||
log.Printf("Warning: failed to inspect container %s: %v", testCont.name, err)
|
log.Printf("Warning: failed to inspect container %s: %v", testCont.name, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if container is in a final state
|
// Check if container is in a final state
|
||||||
if !isContainerFinalized(inspect.State) {
|
if !isContainerFinalized(inspect.State) {
|
||||||
allFinalized = false
|
allFinalized = false
|
||||||
|
|
||||||
if verbose {
|
if verbose {
|
||||||
log.Printf("Container %s still finalizing (state: %s)", testCont.name, inspect.State.Status)
|
log.Printf("Container %s still finalizing (state: %s)", testCont.name, inspect.State.Status)
|
||||||
}
|
}
|
||||||
@@ -384,6 +401,7 @@ func waitForContainerFinalization(ctx context.Context, cli *client.Client, testC
|
|||||||
if verbose {
|
if verbose {
|
||||||
log.Printf("All test containers finalized, ready for artifact extraction")
|
log.Printf("All test containers finalized, ready for artifact extraction")
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -400,13 +418,15 @@ func isContainerFinalized(state *container.State) bool {
|
|||||||
func findProjectRoot(startPath string) string {
|
func findProjectRoot(startPath string) string {
|
||||||
current := startPath
|
current := startPath
|
||||||
for {
|
for {
|
||||||
if _, err := os.Stat(filepath.Join(current, "go.mod")); err == nil {
|
if _, err := os.Stat(filepath.Join(current, "go.mod")); err == nil { //nolint:noinlineerr
|
||||||
return current
|
return current
|
||||||
}
|
}
|
||||||
|
|
||||||
parent := filepath.Dir(current)
|
parent := filepath.Dir(current)
|
||||||
if parent == current {
|
if parent == current {
|
||||||
return startPath
|
return startPath
|
||||||
}
|
}
|
||||||
|
|
||||||
current = parent
|
current = parent
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -416,6 +436,7 @@ func boolToInt(b bool) int {
|
|||||||
if b {
|
if b {
|
||||||
return 1
|
return 1
|
||||||
}
|
}
|
||||||
|
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -428,13 +449,14 @@ type DockerContext struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// createDockerClient creates a Docker client with context detection.
|
// createDockerClient creates a Docker client with context detection.
|
||||||
func createDockerClient() (*client.Client, error) {
|
func createDockerClient(ctx context.Context) (*client.Client, error) {
|
||||||
contextInfo, err := getCurrentDockerContext()
|
contextInfo, err := getCurrentDockerContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return client.NewClientWithOpts(client.FromEnv, client.WithAPIVersionNegotiation())
|
return client.NewClientWithOpts(client.FromEnv, client.WithAPIVersionNegotiation())
|
||||||
}
|
}
|
||||||
|
|
||||||
var clientOpts []client.Opt
|
var clientOpts []client.Opt
|
||||||
|
|
||||||
clientOpts = append(clientOpts, client.WithAPIVersionNegotiation())
|
clientOpts = append(clientOpts, client.WithAPIVersionNegotiation())
|
||||||
|
|
||||||
if contextInfo != nil {
|
if contextInfo != nil {
|
||||||
@@ -444,6 +466,7 @@ func createDockerClient() (*client.Client, error) {
|
|||||||
if runConfig.Verbose {
|
if runConfig.Verbose {
|
||||||
log.Printf("Using Docker host from context '%s': %s", contextInfo.Name, host)
|
log.Printf("Using Docker host from context '%s': %s", contextInfo.Name, host)
|
||||||
}
|
}
|
||||||
|
|
||||||
clientOpts = append(clientOpts, client.WithHost(host))
|
clientOpts = append(clientOpts, client.WithHost(host))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -458,16 +481,17 @@ func createDockerClient() (*client.Client, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// getCurrentDockerContext retrieves the current Docker context information.
|
// getCurrentDockerContext retrieves the current Docker context information.
|
||||||
func getCurrentDockerContext() (*DockerContext, error) {
|
func getCurrentDockerContext(ctx context.Context) (*DockerContext, error) {
|
||||||
cmd := exec.Command("docker", "context", "inspect")
|
cmd := exec.CommandContext(ctx, "docker", "context", "inspect")
|
||||||
|
|
||||||
output, err := cmd.Output()
|
output, err := cmd.Output()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get docker context: %w", err)
|
return nil, fmt.Errorf("getting docker context: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var contexts []DockerContext
|
var contexts []DockerContext
|
||||||
if err := json.Unmarshal(output, &contexts); err != nil {
|
if err := json.Unmarshal(output, &contexts); err != nil { //nolint:noinlineerr
|
||||||
return nil, fmt.Errorf("failed to parse docker context: %w", err)
|
return nil, fmt.Errorf("parsing docker context: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(contexts) > 0 {
|
if len(contexts) > 0 {
|
||||||
@@ -486,12 +510,13 @@ func getDockerSocketPath() string {
|
|||||||
|
|
||||||
// checkImageAvailableLocally checks if the specified Docker image is available locally.
|
// checkImageAvailableLocally checks if the specified Docker image is available locally.
|
||||||
func checkImageAvailableLocally(ctx context.Context, cli *client.Client, imageName string) (bool, error) {
|
func checkImageAvailableLocally(ctx context.Context, cli *client.Client, imageName string) (bool, error) {
|
||||||
_, _, err := cli.ImageInspectWithRaw(ctx, imageName)
|
_, _, err := cli.ImageInspectWithRaw(ctx, imageName) //nolint:staticcheck // SA1019: deprecated but functional
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if client.IsErrNotFound(err) {
|
if client.IsErrNotFound(err) { //nolint:staticcheck // SA1019: deprecated but functional
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
return false, fmt.Errorf("failed to inspect image %s: %w", imageName, err)
|
|
||||||
|
return false, fmt.Errorf("inspecting image %s: %w", imageName, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return true, nil
|
return true, nil
|
||||||
@@ -502,13 +527,14 @@ func ensureImageAvailable(ctx context.Context, cli *client.Client, imageName str
|
|||||||
// First check if image is available locally
|
// First check if image is available locally
|
||||||
available, err := checkImageAvailableLocally(ctx, cli, imageName)
|
available, err := checkImageAvailableLocally(ctx, cli, imageName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to check local image availability: %w", err)
|
return fmt.Errorf("checking local image availability: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if available {
|
if available {
|
||||||
if verbose {
|
if verbose {
|
||||||
log.Printf("Image %s is available locally", imageName)
|
log.Printf("Image %s is available locally", imageName)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -519,20 +545,21 @@ func ensureImageAvailable(ctx context.Context, cli *client.Client, imageName str
|
|||||||
|
|
||||||
reader, err := cli.ImagePull(ctx, imageName, image.PullOptions{})
|
reader, err := cli.ImagePull(ctx, imageName, image.PullOptions{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to pull image %s: %w", imageName, err)
|
return fmt.Errorf("pulling image %s: %w", imageName, err)
|
||||||
}
|
}
|
||||||
defer reader.Close()
|
defer reader.Close()
|
||||||
|
|
||||||
if verbose {
|
if verbose {
|
||||||
_, err = io.Copy(os.Stdout, reader)
|
_, err = io.Copy(os.Stdout, reader)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to read pull output: %w", err)
|
return fmt.Errorf("reading pull output: %w", err)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
_, err = io.Copy(io.Discard, reader)
|
_, err = io.Copy(io.Discard, reader)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to read pull output: %w", err)
|
return fmt.Errorf("reading pull output: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("Image %s pulled successfully", imageName)
|
log.Printf("Image %s pulled successfully", imageName)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -547,9 +574,11 @@ func listControlFiles(logsDir string) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var logFiles []string
|
var (
|
||||||
var dataFiles []string
|
logFiles []string
|
||||||
var dataDirs []string
|
dataFiles []string
|
||||||
|
dataDirs []string
|
||||||
|
)
|
||||||
|
|
||||||
for _, entry := range entries {
|
for _, entry := range entries {
|
||||||
name := entry.Name()
|
name := entry.Name()
|
||||||
@@ -578,6 +607,7 @@ func listControlFiles(logsDir string) {
|
|||||||
|
|
||||||
if len(logFiles) > 0 {
|
if len(logFiles) > 0 {
|
||||||
log.Printf("Headscale logs:")
|
log.Printf("Headscale logs:")
|
||||||
|
|
||||||
for _, file := range logFiles {
|
for _, file := range logFiles {
|
||||||
log.Printf(" %s", file)
|
log.Printf(" %s", file)
|
||||||
}
|
}
|
||||||
@@ -585,9 +615,11 @@ func listControlFiles(logsDir string) {
|
|||||||
|
|
||||||
if len(dataFiles) > 0 || len(dataDirs) > 0 {
|
if len(dataFiles) > 0 || len(dataDirs) > 0 {
|
||||||
log.Printf("Headscale data:")
|
log.Printf("Headscale data:")
|
||||||
|
|
||||||
for _, file := range dataFiles {
|
for _, file := range dataFiles {
|
||||||
log.Printf(" %s", file)
|
log.Printf(" %s", file)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, dir := range dataDirs {
|
for _, dir := range dataDirs {
|
||||||
log.Printf(" %s/", dir)
|
log.Printf(" %s/", dir)
|
||||||
}
|
}
|
||||||
@@ -596,25 +628,27 @@ func listControlFiles(logsDir string) {
|
|||||||
|
|
||||||
// extractArtifactsFromContainers collects container logs and files from the specific test run.
|
// extractArtifactsFromContainers collects container logs and files from the specific test run.
|
||||||
func extractArtifactsFromContainers(ctx context.Context, testContainerID, logsDir string, verbose bool) error {
|
func extractArtifactsFromContainers(ctx context.Context, testContainerID, logsDir string, verbose bool) error {
|
||||||
cli, err := createDockerClient()
|
cli, err := createDockerClient(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create Docker client: %w", err)
|
return fmt.Errorf("creating Docker client: %w", err)
|
||||||
}
|
}
|
||||||
defer cli.Close()
|
defer cli.Close()
|
||||||
|
|
||||||
// List all containers
|
// List all containers
|
||||||
containers, err := cli.ContainerList(ctx, container.ListOptions{All: true})
|
containers, err := cli.ContainerList(ctx, container.ListOptions{All: true})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to list containers: %w", err)
|
return fmt.Errorf("listing containers: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get containers from the specific test run
|
// Get containers from the specific test run
|
||||||
currentTestContainers := getCurrentTestContainers(containers, testContainerID, verbose)
|
currentTestContainers := getCurrentTestContainers(containers, testContainerID, verbose)
|
||||||
|
|
||||||
extractedCount := 0
|
extractedCount := 0
|
||||||
|
|
||||||
for _, cont := range currentTestContainers {
|
for _, cont := range currentTestContainers {
|
||||||
// Extract container logs and tar files
|
// Extract container logs and tar files
|
||||||
if err := extractContainerArtifacts(ctx, cli, cont.ID, cont.name, logsDir, verbose); err != nil {
|
err := extractContainerArtifacts(ctx, cli, cont.ID, cont.name, logsDir, verbose)
|
||||||
|
if err != nil {
|
||||||
if verbose {
|
if verbose {
|
||||||
log.Printf("Warning: failed to extract artifacts from container %s (%s): %v", cont.name, cont.ID[:12], err)
|
log.Printf("Warning: failed to extract artifacts from container %s (%s): %v", cont.name, cont.ID[:12], err)
|
||||||
}
|
}
|
||||||
@@ -622,6 +656,7 @@ func extractArtifactsFromContainers(ctx context.Context, testContainerID, logsDi
|
|||||||
if verbose {
|
if verbose {
|
||||||
log.Printf("Extracted artifacts from container %s (%s)", cont.name, cont.ID[:12])
|
log.Printf("Extracted artifacts from container %s (%s)", cont.name, cont.ID[:12])
|
||||||
}
|
}
|
||||||
|
|
||||||
extractedCount++
|
extractedCount++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -645,11 +680,13 @@ func getCurrentTestContainers(containers []container.Summary, testContainerID st
|
|||||||
|
|
||||||
// Find the test container to get its run ID label
|
// Find the test container to get its run ID label
|
||||||
var runID string
|
var runID string
|
||||||
|
|
||||||
for _, cont := range containers {
|
for _, cont := range containers {
|
||||||
if cont.ID == testContainerID {
|
if cont.ID == testContainerID {
|
||||||
if cont.Labels != nil {
|
if cont.Labels != nil {
|
||||||
runID = cont.Labels["hi.run-id"]
|
runID = cont.Labels["hi.run-id"]
|
||||||
}
|
}
|
||||||
|
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -690,18 +727,21 @@ func getCurrentTestContainers(containers []container.Summary, testContainerID st
|
|||||||
// extractContainerArtifacts saves logs and tar files from a container.
|
// extractContainerArtifacts saves logs and tar files from a container.
|
||||||
func extractContainerArtifacts(ctx context.Context, cli *client.Client, containerID, containerName, logsDir string, verbose bool) error {
|
func extractContainerArtifacts(ctx context.Context, cli *client.Client, containerID, containerName, logsDir string, verbose bool) error {
|
||||||
// Ensure the logs directory exists
|
// Ensure the logs directory exists
|
||||||
if err := os.MkdirAll(logsDir, 0o755); err != nil {
|
err := os.MkdirAll(logsDir, defaultDirPerm)
|
||||||
return fmt.Errorf("failed to create logs directory: %w", err)
|
if err != nil {
|
||||||
|
return fmt.Errorf("creating logs directory: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract container logs
|
// Extract container logs
|
||||||
if err := extractContainerLogs(ctx, cli, containerID, containerName, logsDir, verbose); err != nil {
|
err = extractContainerLogs(ctx, cli, containerID, containerName, logsDir, verbose)
|
||||||
return fmt.Errorf("failed to extract logs: %w", err)
|
if err != nil {
|
||||||
|
return fmt.Errorf("extracting logs: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract tar files for headscale containers only
|
// Extract tar files for headscale containers only
|
||||||
if strings.HasPrefix(containerName, "hs-") {
|
if strings.HasPrefix(containerName, "hs-") {
|
||||||
if err := extractContainerFiles(ctx, cli, containerID, containerName, logsDir, verbose); err != nil {
|
err := extractContainerFiles(ctx, cli, containerID, containerName, logsDir, verbose)
|
||||||
|
if err != nil {
|
||||||
if verbose {
|
if verbose {
|
||||||
log.Printf("Warning: failed to extract files from %s: %v", containerName, err)
|
log.Printf("Warning: failed to extract files from %s: %v", containerName, err)
|
||||||
}
|
}
|
||||||
@@ -723,7 +763,7 @@ func extractContainerLogs(ctx context.Context, cli *client.Client, containerID,
|
|||||||
Tail: "all",
|
Tail: "all",
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get container logs: %w", err)
|
return fmt.Errorf("getting container logs: %w", err)
|
||||||
}
|
}
|
||||||
defer logReader.Close()
|
defer logReader.Close()
|
||||||
|
|
||||||
@@ -737,17 +777,17 @@ func extractContainerLogs(ctx context.Context, cli *client.Client, containerID,
|
|||||||
// Demultiplex the Docker logs stream to separate stdout and stderr
|
// Demultiplex the Docker logs stream to separate stdout and stderr
|
||||||
_, err = stdcopy.StdCopy(&stdoutBuf, &stderrBuf, logReader)
|
_, err = stdcopy.StdCopy(&stdoutBuf, &stderrBuf, logReader)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to demultiplex container logs: %w", err)
|
return fmt.Errorf("demultiplexing container logs: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write stdout logs
|
// Write stdout logs
|
||||||
if err := os.WriteFile(stdoutPath, stdoutBuf.Bytes(), 0o644); err != nil {
|
if err := os.WriteFile(stdoutPath, stdoutBuf.Bytes(), 0o644); err != nil { //nolint:gosec,noinlineerr // log files should be readable
|
||||||
return fmt.Errorf("failed to write stdout log: %w", err)
|
return fmt.Errorf("writing stdout log: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write stderr logs
|
// Write stderr logs
|
||||||
if err := os.WriteFile(stderrPath, stderrBuf.Bytes(), 0o644); err != nil {
|
if err := os.WriteFile(stderrPath, stderrBuf.Bytes(), 0o644); err != nil { //nolint:gosec,noinlineerr // log files should be readable
|
||||||
return fmt.Errorf("failed to write stderr log: %w", err)
|
return fmt.Errorf("writing stderr log: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if verbose {
|
if verbose {
|
||||||
|
|||||||
@@ -38,13 +38,13 @@ func runDoctorCheck(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check 3: Go installation
|
// Check 3: Go installation
|
||||||
results = append(results, checkGoInstallation())
|
results = append(results, checkGoInstallation(ctx))
|
||||||
|
|
||||||
// Check 4: Git repository
|
// Check 4: Git repository
|
||||||
results = append(results, checkGitRepository())
|
results = append(results, checkGitRepository(ctx))
|
||||||
|
|
||||||
// Check 5: Required files
|
// Check 5: Required files
|
||||||
results = append(results, checkRequiredFiles())
|
results = append(results, checkRequiredFiles(ctx))
|
||||||
|
|
||||||
// Display results
|
// Display results
|
||||||
displayDoctorResults(results)
|
displayDoctorResults(results)
|
||||||
@@ -86,7 +86,7 @@ func checkDockerBinary() DoctorResult {
|
|||||||
|
|
||||||
// checkDockerDaemon verifies Docker daemon is running and accessible.
|
// checkDockerDaemon verifies Docker daemon is running and accessible.
|
||||||
func checkDockerDaemon(ctx context.Context) DoctorResult {
|
func checkDockerDaemon(ctx context.Context) DoctorResult {
|
||||||
cli, err := createDockerClient()
|
cli, err := createDockerClient(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return DoctorResult{
|
return DoctorResult{
|
||||||
Name: "Docker Daemon",
|
Name: "Docker Daemon",
|
||||||
@@ -124,8 +124,8 @@ func checkDockerDaemon(ctx context.Context) DoctorResult {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// checkDockerContext verifies Docker context configuration.
|
// checkDockerContext verifies Docker context configuration.
|
||||||
func checkDockerContext(_ context.Context) DoctorResult {
|
func checkDockerContext(ctx context.Context) DoctorResult {
|
||||||
contextInfo, err := getCurrentDockerContext()
|
contextInfo, err := getCurrentDockerContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return DoctorResult{
|
return DoctorResult{
|
||||||
Name: "Docker Context",
|
Name: "Docker Context",
|
||||||
@@ -155,7 +155,7 @@ func checkDockerContext(_ context.Context) DoctorResult {
|
|||||||
|
|
||||||
// checkDockerSocket verifies Docker socket accessibility.
|
// checkDockerSocket verifies Docker socket accessibility.
|
||||||
func checkDockerSocket(ctx context.Context) DoctorResult {
|
func checkDockerSocket(ctx context.Context) DoctorResult {
|
||||||
cli, err := createDockerClient()
|
cli, err := createDockerClient(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return DoctorResult{
|
return DoctorResult{
|
||||||
Name: "Docker Socket",
|
Name: "Docker Socket",
|
||||||
@@ -192,7 +192,7 @@ func checkDockerSocket(ctx context.Context) DoctorResult {
|
|||||||
|
|
||||||
// checkGolangImage verifies the golang Docker image is available locally or can be pulled.
|
// checkGolangImage verifies the golang Docker image is available locally or can be pulled.
|
||||||
func checkGolangImage(ctx context.Context) DoctorResult {
|
func checkGolangImage(ctx context.Context) DoctorResult {
|
||||||
cli, err := createDockerClient()
|
cli, err := createDockerClient(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return DoctorResult{
|
return DoctorResult{
|
||||||
Name: "Golang Image",
|
Name: "Golang Image",
|
||||||
@@ -251,7 +251,7 @@ func checkGolangImage(ctx context.Context) DoctorResult {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// checkGoInstallation verifies Go is installed and working.
|
// checkGoInstallation verifies Go is installed and working.
|
||||||
func checkGoInstallation() DoctorResult {
|
func checkGoInstallation(ctx context.Context) DoctorResult {
|
||||||
_, err := exec.LookPath("go")
|
_, err := exec.LookPath("go")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return DoctorResult{
|
return DoctorResult{
|
||||||
@@ -265,7 +265,8 @@ func checkGoInstallation() DoctorResult {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd := exec.Command("go", "version")
|
cmd := exec.CommandContext(ctx, "go", "version")
|
||||||
|
|
||||||
output, err := cmd.Output()
|
output, err := cmd.Output()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return DoctorResult{
|
return DoctorResult{
|
||||||
@@ -285,8 +286,9 @@ func checkGoInstallation() DoctorResult {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// checkGitRepository verifies we're in a git repository.
|
// checkGitRepository verifies we're in a git repository.
|
||||||
func checkGitRepository() DoctorResult {
|
func checkGitRepository(ctx context.Context) DoctorResult {
|
||||||
cmd := exec.Command("git", "rev-parse", "--git-dir")
|
cmd := exec.CommandContext(ctx, "git", "rev-parse", "--git-dir")
|
||||||
|
|
||||||
err := cmd.Run()
|
err := cmd.Run()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return DoctorResult{
|
return DoctorResult{
|
||||||
@@ -308,7 +310,7 @@ func checkGitRepository() DoctorResult {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// checkRequiredFiles verifies required files exist.
|
// checkRequiredFiles verifies required files exist.
|
||||||
func checkRequiredFiles() DoctorResult {
|
func checkRequiredFiles(ctx context.Context) DoctorResult {
|
||||||
requiredFiles := []string{
|
requiredFiles := []string{
|
||||||
"go.mod",
|
"go.mod",
|
||||||
"integration/",
|
"integration/",
|
||||||
@@ -316,9 +318,12 @@ func checkRequiredFiles() DoctorResult {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var missingFiles []string
|
var missingFiles []string
|
||||||
|
|
||||||
for _, file := range requiredFiles {
|
for _, file := range requiredFiles {
|
||||||
cmd := exec.Command("test", "-e", file)
|
cmd := exec.CommandContext(ctx, "test", "-e", file)
|
||||||
if err := cmd.Run(); err != nil {
|
|
||||||
|
err := cmd.Run()
|
||||||
|
if err != nil {
|
||||||
missingFiles = append(missingFiles, file)
|
missingFiles = append(missingFiles, file)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -350,6 +355,7 @@ func displayDoctorResults(results []DoctorResult) {
|
|||||||
|
|
||||||
for _, result := range results {
|
for _, result := range results {
|
||||||
var icon string
|
var icon string
|
||||||
|
|
||||||
switch result.Status {
|
switch result.Status {
|
||||||
case "PASS":
|
case "PASS":
|
||||||
icon = "✅"
|
icon = "✅"
|
||||||
|
|||||||
@@ -79,13 +79,18 @@ func main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func cleanAll(ctx context.Context) error {
|
func cleanAll(ctx context.Context) error {
|
||||||
if err := killTestContainers(ctx); err != nil {
|
err := killTestContainers(ctx)
|
||||||
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := pruneDockerNetworks(ctx); err != nil {
|
|
||||||
|
err = pruneDockerNetworks(ctx)
|
||||||
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := cleanOldImages(ctx); err != nil {
|
|
||||||
|
err = cleanOldImages(ctx)
|
||||||
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -48,7 +48,9 @@ func runIntegrationTest(env *command.Env) error {
|
|||||||
if runConfig.Verbose {
|
if runConfig.Verbose {
|
||||||
log.Printf("Running pre-flight system checks...")
|
log.Printf("Running pre-flight system checks...")
|
||||||
}
|
}
|
||||||
if err := runDoctorCheck(env.Context()); err != nil {
|
|
||||||
|
err := runDoctorCheck(env.Context())
|
||||||
|
if err != nil {
|
||||||
return fmt.Errorf("pre-flight checks failed: %w", err)
|
return fmt.Errorf("pre-flight checks failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -66,15 +68,15 @@ func runIntegrationTest(env *command.Env) error {
|
|||||||
func detectGoVersion() string {
|
func detectGoVersion() string {
|
||||||
goModPath := filepath.Join("..", "..", "go.mod")
|
goModPath := filepath.Join("..", "..", "go.mod")
|
||||||
|
|
||||||
if _, err := os.Stat("go.mod"); err == nil {
|
if _, err := os.Stat("go.mod"); err == nil { //nolint:noinlineerr
|
||||||
goModPath = "go.mod"
|
goModPath = "go.mod"
|
||||||
} else if _, err := os.Stat("../../go.mod"); err == nil {
|
} else if _, err := os.Stat("../../go.mod"); err == nil { //nolint:noinlineerr
|
||||||
goModPath = "../../go.mod"
|
goModPath = "../../go.mod"
|
||||||
}
|
}
|
||||||
|
|
||||||
content, err := os.ReadFile(goModPath)
|
content, err := os.ReadFile(goModPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "1.25"
|
return "1.26.0"
|
||||||
}
|
}
|
||||||
|
|
||||||
lines := splitLines(string(content))
|
lines := splitLines(string(content))
|
||||||
@@ -89,13 +91,15 @@ func detectGoVersion() string {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return "1.25"
|
return "1.26.0"
|
||||||
}
|
}
|
||||||
|
|
||||||
// splitLines splits a string into lines without using strings.Split.
|
// splitLines splits a string into lines without using strings.Split.
|
||||||
func splitLines(s string) []string {
|
func splitLines(s string) []string {
|
||||||
var lines []string
|
var (
|
||||||
var current string
|
lines []string
|
||||||
|
current string
|
||||||
|
)
|
||||||
|
|
||||||
for _, char := range s {
|
for _, char := range s {
|
||||||
if char == '\n' {
|
if char == '\n' {
|
||||||
|
|||||||
@@ -18,6 +18,9 @@ import (
|
|||||||
"github.com/docker/docker/client"
|
"github.com/docker/docker/client"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ErrStatsCollectionAlreadyStarted is returned when trying to start stats collection that is already running.
|
||||||
|
var ErrStatsCollectionAlreadyStarted = errors.New("stats collection already started")
|
||||||
|
|
||||||
// ContainerStats represents statistics for a single container.
|
// ContainerStats represents statistics for a single container.
|
||||||
type ContainerStats struct {
|
type ContainerStats struct {
|
||||||
ContainerID string
|
ContainerID string
|
||||||
@@ -44,10 +47,10 @@ type StatsCollector struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewStatsCollector creates a new stats collector instance.
|
// NewStatsCollector creates a new stats collector instance.
|
||||||
func NewStatsCollector() (*StatsCollector, error) {
|
func NewStatsCollector(ctx context.Context) (*StatsCollector, error) {
|
||||||
cli, err := createDockerClient()
|
cli, err := createDockerClient(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create Docker client: %w", err)
|
return nil, fmt.Errorf("creating Docker client: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &StatsCollector{
|
return &StatsCollector{
|
||||||
@@ -63,17 +66,19 @@ func (sc *StatsCollector) StartCollection(ctx context.Context, runID string, ver
|
|||||||
defer sc.mutex.Unlock()
|
defer sc.mutex.Unlock()
|
||||||
|
|
||||||
if sc.collectionStarted {
|
if sc.collectionStarted {
|
||||||
return errors.New("stats collection already started")
|
return ErrStatsCollectionAlreadyStarted
|
||||||
}
|
}
|
||||||
|
|
||||||
sc.collectionStarted = true
|
sc.collectionStarted = true
|
||||||
|
|
||||||
// Start monitoring existing containers
|
// Start monitoring existing containers
|
||||||
sc.wg.Add(1)
|
sc.wg.Add(1)
|
||||||
|
|
||||||
go sc.monitorExistingContainers(ctx, runID, verbose)
|
go sc.monitorExistingContainers(ctx, runID, verbose)
|
||||||
|
|
||||||
// Start Docker events monitoring for new containers
|
// Start Docker events monitoring for new containers
|
||||||
sc.wg.Add(1)
|
sc.wg.Add(1)
|
||||||
|
|
||||||
go sc.monitorDockerEvents(ctx, runID, verbose)
|
go sc.monitorDockerEvents(ctx, runID, verbose)
|
||||||
|
|
||||||
if verbose {
|
if verbose {
|
||||||
@@ -87,10 +92,12 @@ func (sc *StatsCollector) StartCollection(ctx context.Context, runID string, ver
|
|||||||
func (sc *StatsCollector) StopCollection() {
|
func (sc *StatsCollector) StopCollection() {
|
||||||
// Check if already stopped without holding lock
|
// Check if already stopped without holding lock
|
||||||
sc.mutex.RLock()
|
sc.mutex.RLock()
|
||||||
|
|
||||||
if !sc.collectionStarted {
|
if !sc.collectionStarted {
|
||||||
sc.mutex.RUnlock()
|
sc.mutex.RUnlock()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
sc.mutex.RUnlock()
|
sc.mutex.RUnlock()
|
||||||
|
|
||||||
// Signal stop to all goroutines
|
// Signal stop to all goroutines
|
||||||
@@ -114,6 +121,7 @@ func (sc *StatsCollector) monitorExistingContainers(ctx context.Context, runID s
|
|||||||
if verbose {
|
if verbose {
|
||||||
log.Printf("Failed to list existing containers: %v", err)
|
log.Printf("Failed to list existing containers: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -147,13 +155,13 @@ func (sc *StatsCollector) monitorDockerEvents(ctx context.Context, runID string,
|
|||||||
case event := <-events:
|
case event := <-events:
|
||||||
if event.Type == "container" && event.Action == "start" {
|
if event.Type == "container" && event.Action == "start" {
|
||||||
// Get container details
|
// Get container details
|
||||||
containerInfo, err := sc.client.ContainerInspect(ctx, event.ID)
|
containerInfo, err := sc.client.ContainerInspect(ctx, event.ID) //nolint:staticcheck // SA1019: use Actor.ID
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convert to types.Container format for consistency
|
// Convert to types.Container format for consistency
|
||||||
cont := types.Container{
|
cont := types.Container{ //nolint:staticcheck // SA1019: use container.Summary
|
||||||
ID: containerInfo.ID,
|
ID: containerInfo.ID,
|
||||||
Names: []string{containerInfo.Name},
|
Names: []string{containerInfo.Name},
|
||||||
Labels: containerInfo.Config.Labels,
|
Labels: containerInfo.Config.Labels,
|
||||||
@@ -167,13 +175,14 @@ func (sc *StatsCollector) monitorDockerEvents(ctx context.Context, runID string,
|
|||||||
if verbose {
|
if verbose {
|
||||||
log.Printf("Error in Docker events stream: %v", err)
|
log.Printf("Error in Docker events stream: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// shouldMonitorContainer determines if a container should be monitored.
|
// shouldMonitorContainer determines if a container should be monitored.
|
||||||
func (sc *StatsCollector) shouldMonitorContainer(cont types.Container, runID string) bool {
|
func (sc *StatsCollector) shouldMonitorContainer(cont types.Container, runID string) bool { //nolint:staticcheck // SA1019: use container.Summary
|
||||||
// Check if it has the correct run ID label
|
// Check if it has the correct run ID label
|
||||||
if cont.Labels == nil || cont.Labels["hi.run-id"] != runID {
|
if cont.Labels == nil || cont.Labels["hi.run-id"] != runID {
|
||||||
return false
|
return false
|
||||||
@@ -213,6 +222,7 @@ func (sc *StatsCollector) startStatsForContainer(ctx context.Context, containerI
|
|||||||
}
|
}
|
||||||
|
|
||||||
sc.wg.Add(1)
|
sc.wg.Add(1)
|
||||||
|
|
||||||
go sc.collectStatsForContainer(ctx, containerID, verbose)
|
go sc.collectStatsForContainer(ctx, containerID, verbose)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -226,12 +236,14 @@ func (sc *StatsCollector) collectStatsForContainer(ctx context.Context, containe
|
|||||||
if verbose {
|
if verbose {
|
||||||
log.Printf("Failed to get stats stream for container %s: %v", containerID[:12], err)
|
log.Printf("Failed to get stats stream for container %s: %v", containerID[:12], err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer statsResponse.Body.Close()
|
defer statsResponse.Body.Close()
|
||||||
|
|
||||||
decoder := json.NewDecoder(statsResponse.Body)
|
decoder := json.NewDecoder(statsResponse.Body)
|
||||||
var prevStats *container.Stats
|
|
||||||
|
var prevStats *container.Stats //nolint:staticcheck // SA1019: use StatsResponse
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
@@ -240,12 +252,15 @@ func (sc *StatsCollector) collectStatsForContainer(ctx context.Context, containe
|
|||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
var stats container.Stats
|
var stats container.Stats //nolint:staticcheck // SA1019: use StatsResponse
|
||||||
if err := decoder.Decode(&stats); err != nil {
|
|
||||||
|
err := decoder.Decode(&stats)
|
||||||
|
if err != nil {
|
||||||
// EOF is expected when container stops or stream ends
|
// EOF is expected when container stops or stream ends
|
||||||
if err.Error() != "EOF" && verbose {
|
if err.Error() != "EOF" && verbose {
|
||||||
log.Printf("Failed to decode stats for container %s: %v", containerID[:12], err)
|
log.Printf("Failed to decode stats for container %s: %v", containerID[:12], err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -261,8 +276,10 @@ func (sc *StatsCollector) collectStatsForContainer(ctx context.Context, containe
|
|||||||
// Store the sample (skip first sample since CPU calculation needs previous stats)
|
// Store the sample (skip first sample since CPU calculation needs previous stats)
|
||||||
if prevStats != nil {
|
if prevStats != nil {
|
||||||
// Get container stats reference without holding the main mutex
|
// Get container stats reference without holding the main mutex
|
||||||
var containerStats *ContainerStats
|
var (
|
||||||
var exists bool
|
containerStats *ContainerStats
|
||||||
|
exists bool
|
||||||
|
)
|
||||||
|
|
||||||
sc.mutex.RLock()
|
sc.mutex.RLock()
|
||||||
containerStats, exists = sc.containers[containerID]
|
containerStats, exists = sc.containers[containerID]
|
||||||
@@ -286,7 +303,7 @@ func (sc *StatsCollector) collectStatsForContainer(ctx context.Context, containe
|
|||||||
}
|
}
|
||||||
|
|
||||||
// calculateCPUPercent calculates CPU usage percentage from Docker stats.
|
// calculateCPUPercent calculates CPU usage percentage from Docker stats.
|
||||||
func calculateCPUPercent(prevStats, stats *container.Stats) float64 {
|
func calculateCPUPercent(prevStats, stats *container.Stats) float64 { //nolint:staticcheck // SA1019: use StatsResponse
|
||||||
// CPU calculation based on Docker's implementation
|
// CPU calculation based on Docker's implementation
|
||||||
cpuDelta := float64(stats.CPUStats.CPUUsage.TotalUsage) - float64(prevStats.CPUStats.CPUUsage.TotalUsage)
|
cpuDelta := float64(stats.CPUStats.CPUUsage.TotalUsage) - float64(prevStats.CPUStats.CPUUsage.TotalUsage)
|
||||||
systemDelta := float64(stats.CPUStats.SystemUsage) - float64(prevStats.CPUStats.SystemUsage)
|
systemDelta := float64(stats.CPUStats.SystemUsage) - float64(prevStats.CPUStats.SystemUsage)
|
||||||
@@ -331,10 +348,12 @@ type StatsSummary struct {
|
|||||||
func (sc *StatsCollector) GetSummary() []ContainerStatsSummary {
|
func (sc *StatsCollector) GetSummary() []ContainerStatsSummary {
|
||||||
// Take snapshot of container references without holding main lock long
|
// Take snapshot of container references without holding main lock long
|
||||||
sc.mutex.RLock()
|
sc.mutex.RLock()
|
||||||
|
|
||||||
containerRefs := make([]*ContainerStats, 0, len(sc.containers))
|
containerRefs := make([]*ContainerStats, 0, len(sc.containers))
|
||||||
for _, containerStats := range sc.containers {
|
for _, containerStats := range sc.containers {
|
||||||
containerRefs = append(containerRefs, containerStats)
|
containerRefs = append(containerRefs, containerStats)
|
||||||
}
|
}
|
||||||
|
|
||||||
sc.mutex.RUnlock()
|
sc.mutex.RUnlock()
|
||||||
|
|
||||||
summaries := make([]ContainerStatsSummary, 0, len(containerRefs))
|
summaries := make([]ContainerStatsSummary, 0, len(containerRefs))
|
||||||
@@ -384,23 +403,25 @@ func calculateStatsSummary(values []float64) StatsSummary {
|
|||||||
return StatsSummary{}
|
return StatsSummary{}
|
||||||
}
|
}
|
||||||
|
|
||||||
min := values[0]
|
minVal := values[0]
|
||||||
max := values[0]
|
maxVal := values[0]
|
||||||
sum := 0.0
|
sum := 0.0
|
||||||
|
|
||||||
for _, value := range values {
|
for _, value := range values {
|
||||||
if value < min {
|
if value < minVal {
|
||||||
min = value
|
minVal = value
|
||||||
}
|
}
|
||||||
if value > max {
|
|
||||||
max = value
|
if value > maxVal {
|
||||||
|
maxVal = value
|
||||||
}
|
}
|
||||||
|
|
||||||
sum += value
|
sum += value
|
||||||
}
|
}
|
||||||
|
|
||||||
return StatsSummary{
|
return StatsSummary{
|
||||||
Min: min,
|
Min: minVal,
|
||||||
Max: max,
|
Max: maxVal,
|
||||||
Average: sum / float64(len(values)),
|
Average: sum / float64(len(values)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -434,6 +455,7 @@ func (sc *StatsCollector) CheckMemoryLimits(hsLimitMB, tsLimitMB float64) []Memo
|
|||||||
}
|
}
|
||||||
|
|
||||||
summaries := sc.GetSummary()
|
summaries := sc.GetSummary()
|
||||||
|
|
||||||
var violations []MemoryViolation
|
var violations []MemoryViolation
|
||||||
|
|
||||||
for _, summary := range summaries {
|
for _, summary := range summaries {
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
@@ -15,7 +16,10 @@ type MapConfig struct {
|
|||||||
Directory string `flag:"directory,Directory to read map responses from"`
|
Directory string `flag:"directory,Directory to read map responses from"`
|
||||||
}
|
}
|
||||||
|
|
||||||
var mapConfig MapConfig
|
var (
|
||||||
|
mapConfig MapConfig
|
||||||
|
errDirectoryRequired = errors.New("directory is required")
|
||||||
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
root := command.C{
|
root := command.C{
|
||||||
@@ -40,7 +44,7 @@ func main() {
|
|||||||
// runIntegrationTest executes the integration test workflow.
|
// runIntegrationTest executes the integration test workflow.
|
||||||
func runOnline(env *command.Env) error {
|
func runOnline(env *command.Env) error {
|
||||||
if mapConfig.Directory == "" {
|
if mapConfig.Directory == "" {
|
||||||
return fmt.Errorf("directory is required")
|
return errDirectoryRequired
|
||||||
}
|
}
|
||||||
|
|
||||||
resps, err := mapper.ReadMapResponsesFromDirectory(mapConfig.Directory)
|
resps, err := mapper.ReadMapResponsesFromDirectory(mapConfig.Directory)
|
||||||
@@ -57,5 +61,6 @@ func runOnline(env *command.Env) error {
|
|||||||
|
|
||||||
os.Stderr.Write(out)
|
os.Stderr.Write(out)
|
||||||
os.Stderr.Write([]byte("\n"))
|
os.Stderr.Write([]byte("\n"))
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -24,9 +24,12 @@ We are more than happy to exchange emails, or to have dedicated calls before a P
|
|||||||
|
|
||||||
## When/Why is Feature X going to be implemented?
|
## When/Why is Feature X going to be implemented?
|
||||||
|
|
||||||
We don't know. We might be working on it. If you're interested in contributing, please post a feature request about it.
|
We use [GitHub Milestones to plan for upcoming Headscale releases](https://github.com/juanfont/headscale/milestones).
|
||||||
|
Have a look at [our current plan](https://github.com/juanfont/headscale/milestones) to get an idea when a specific
|
||||||
|
feature is about to be implemented. The release plan is subject to change at any time.
|
||||||
|
|
||||||
Please be aware that there are a number of reasons why we might not accept specific contributions:
|
If you're interested in contributing, please post a feature request about it. Please be aware that there are a number of
|
||||||
|
reasons why we might not accept specific contributions:
|
||||||
|
|
||||||
- It is not possible to implement the feature in a way that makes sense in a self-hosted environment.
|
- It is not possible to implement the feature in a way that makes sense in a self-hosted environment.
|
||||||
- Given that we are reverse-engineering Tailscale to satisfy our own curiosity, we might be interested in implementing the feature ourselves.
|
- Given that we are reverse-engineering Tailscale to satisfy our own curiosity, we might be interested in implementing the feature ourselves.
|
||||||
@@ -47,7 +50,7 @@ we have a "docker-issues" channel where you can ask for Docker-specific help to
|
|||||||
## What is the recommended update path? Can I skip multiple versions while updating?
|
## What is the recommended update path? Can I skip multiple versions while updating?
|
||||||
|
|
||||||
Please follow the steps outlined in the [upgrade guide](../setup/upgrade.md) to update your existing Headscale
|
Please follow the steps outlined in the [upgrade guide](../setup/upgrade.md) to update your existing Headscale
|
||||||
installation. Its best to update from one stable version to the next (e.g. 0.24.0 → 0.25.1 → 0.26.1) in case
|
installation. Its best to update from one stable version to the next (e.g. 0.26.0 → 0.27.1 → 0.28.0) in case
|
||||||
you are multiple releases behind. You should always pick the latest available patch release.
|
you are multiple releases behind. You should always pick the latest available patch release.
|
||||||
|
|
||||||
Be sure to check the [changelog](https://github.com/juanfont/headscale/blob/main/CHANGELOG.md) for version specific
|
Be sure to check the [changelog](https://github.com/juanfont/headscale/blob/main/CHANGELOG.md) for version specific
|
||||||
|
|||||||
@@ -245,7 +245,6 @@ Includes all devices that [have at least one tag](registration.md/#identity-mode
|
|||||||
```
|
```
|
||||||
|
|
||||||
### `autogroup:self`
|
### `autogroup:self`
|
||||||
**(EXPERIMENTAL)**
|
|
||||||
|
|
||||||
!!! warning "The current implementation of `autogroup:self` is inefficient"
|
!!! warning "The current implementation of `autogroup:self` is inefficient"
|
||||||
|
|
||||||
|
|||||||
@@ -20,5 +20,7 @@ Headscale doesn't provide a built-in web interface but users may pick one from t
|
|||||||
- [headscale-console](https://github.com/rickli-cloud/headscale-console) - WebAssembly-based client supporting SSH, VNC
|
- [headscale-console](https://github.com/rickli-cloud/headscale-console) - WebAssembly-based client supporting SSH, VNC
|
||||||
and RDP with optional self-service capabilities
|
and RDP with optional self-service capabilities
|
||||||
- [headscale-piying](https://github.com/wszgrcy/headscale-piying) - headscale web ui,support visual ACL configuration
|
- [headscale-piying](https://github.com/wszgrcy/headscale-piying) - headscale web ui,support visual ACL configuration
|
||||||
|
- [HeadControl](https://github.com/ahmadzip/HeadControl) - Minimal Headscale admin dashboard, built with Go and HTMX
|
||||||
|
- [Headscale Manager](https://github.com/hkdone/headscalemanager) - Headscale UI for Android
|
||||||
|
|
||||||
You can ask for support on our [Discord server](https://discord.gg/c84AZQhmpx) in the "web-interfaces" channel.
|
You can ask for support on our [Discord server](https://discord.gg/c84AZQhmpx) in the "web-interfaces" channel.
|
||||||
|
|||||||
@@ -185,7 +185,8 @@ You may refer to users in the Headscale policy via:
|
|||||||
|
|
||||||
- Email address
|
- Email address
|
||||||
- Username
|
- Username
|
||||||
- Provider identifier (only available in the database or from your identity provider)
|
- Provider identifier (this value is currently only available from the [API](api.md), database or directly from your
|
||||||
|
identity provider)
|
||||||
|
|
||||||
!!! note "A user identifier in the policy must contain a single `@`"
|
!!! note "A user identifier in the policy must contain a single `@`"
|
||||||
|
|
||||||
@@ -200,6 +201,34 @@ You may refer to users in the Headscale policy via:
|
|||||||
consequences for Headscale where a policy might no longer work or a user might obtain more access by hijacking an
|
consequences for Headscale where a policy might no longer work or a user might obtain more access by hijacking an
|
||||||
existing username or email address.
|
existing username or email address.
|
||||||
|
|
||||||
|
!!! tip "Howto use the provider identifier in the policy"
|
||||||
|
|
||||||
|
The provider identifier uniquely identifies an OIDC user and a well-behaving identity provider guarantees that this
|
||||||
|
value never changes for a particular user. It is usually an opaque and long string and its value is currently only
|
||||||
|
available from the [API](api.md), database or directly from your identity provider).
|
||||||
|
|
||||||
|
Use the [API](api.md) with the `/api/v1/user` endpoint to fetch the provider identifier (`providerId`). The value
|
||||||
|
(be sure to append an `@` in case the provider identifier doesn't already contain an `@` somewhere) can be used
|
||||||
|
directly to reference a user in the policy. To improve readability of the policy, one may use the `groups` section
|
||||||
|
as an alias:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"groups": {
|
||||||
|
"group:alice": [
|
||||||
|
"https://soo.example.com/oauth2/openid/59ac9125-c31b-46c5-814e-06242908cf57@"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"acls": [
|
||||||
|
{
|
||||||
|
"action": "accept",
|
||||||
|
"src": ["group:alice"],
|
||||||
|
"dst": ["*:*"]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
## Supported OIDC claims
|
## Supported OIDC claims
|
||||||
|
|
||||||
Headscale uses [the standard OIDC claims](https://openid.net/specs/openid-connect-core-1_0.html#StandardClaims) to
|
Headscale uses [the standard OIDC claims](https://openid.net/specs/openid-connect-core-1_0.html#StandardClaims) to
|
||||||
@@ -289,6 +318,14 @@ Console.
|
|||||||
- Kanidm is fully supported by Headscale.
|
- Kanidm is fully supported by Headscale.
|
||||||
- Groups for the [allowed groups filter](#authorize-users-with-filters) need to be specified with their full SPN, for
|
- Groups for the [allowed groups filter](#authorize-users-with-filters) need to be specified with their full SPN, for
|
||||||
example: `headscale_users@sso.example.com`.
|
example: `headscale_users@sso.example.com`.
|
||||||
|
- Kanidm sends the full SPN (`alice@sso.example.com`) as `preferred_username` by default. Headscale stores this value as
|
||||||
|
username which might be confusing as the username and email fields now contain values that look like an email address.
|
||||||
|
[Kanidm can be configured to send the short username as `preferred_username` attribute
|
||||||
|
instead](https://kanidm.github.io/kanidm/stable/integrations/oauth2.html#short-names):
|
||||||
|
```console
|
||||||
|
kanidm system oauth2 prefer-short-username <client name>
|
||||||
|
```
|
||||||
|
Once configured, the short username in Headscale will be `alice` and can be referred to as `alice@` in the policy.
|
||||||
|
|
||||||
### Keycloak
|
### Keycloak
|
||||||
|
|
||||||
|
|||||||
6
flake.lock
generated
6
flake.lock
generated
@@ -20,11 +20,11 @@
|
|||||||
},
|
},
|
||||||
"nixpkgs": {
|
"nixpkgs": {
|
||||||
"locked": {
|
"locked": {
|
||||||
"lastModified": 1770141374,
|
"lastModified": 1771177547,
|
||||||
"narHash": "sha256-yD4K/vRHPwXbJf5CK3JkptBA6nFWUKNX/jlFp2eKEQc=",
|
"narHash": "sha256-trTtk3WTOHz7hSw89xIIvahkgoFJYQ0G43IlqprFoMA=",
|
||||||
"owner": "NixOS",
|
"owner": "NixOS",
|
||||||
"repo": "nixpkgs",
|
"repo": "nixpkgs",
|
||||||
"rev": "41965737c1797c1d83cfb0b644ed0840a6220bd1",
|
"rev": "ac055f38c798b0d87695240c7b761b82fc7e5bc2",
|
||||||
"type": "github"
|
"type": "github"
|
||||||
},
|
},
|
||||||
"original": {
|
"original": {
|
||||||
|
|||||||
54
flake.nix
54
flake.nix
@@ -26,8 +26,8 @@
|
|||||||
overlays.default = _: prev:
|
overlays.default = _: prev:
|
||||||
let
|
let
|
||||||
pkgs = nixpkgs.legacyPackages.${prev.stdenv.hostPlatform.system};
|
pkgs = nixpkgs.legacyPackages.${prev.stdenv.hostPlatform.system};
|
||||||
buildGo = pkgs.buildGo125Module;
|
buildGo = pkgs.buildGo126Module;
|
||||||
vendorHash = "sha256-jkeB9XUTEGt58fPOMpE4/e3+JQoMQTgf0RlthVBmfG0=";
|
vendorHash = "sha256-9BvphYDAxzwooyVokI3l+q1wRuRsWn/qM+NpWUgqJH0=";
|
||||||
in
|
in
|
||||||
{
|
{
|
||||||
headscale = buildGo {
|
headscale = buildGo {
|
||||||
@@ -94,14 +94,46 @@
|
|||||||
subPackages = [ "." ];
|
subPackages = [ "." ];
|
||||||
};
|
};
|
||||||
|
|
||||||
# Upstream does not override buildGoModule properly,
|
# Build golangci-lint with Go 1.26 (upstream uses hardcoded Go version)
|
||||||
# importing a specific module, so comment out for now.
|
golangci-lint = buildGo rec {
|
||||||
# golangci-lint = prev.golangci-lint.override {
|
pname = "golangci-lint";
|
||||||
# buildGoModule = buildGo;
|
version = "2.9.0";
|
||||||
# };
|
|
||||||
# golangci-lint-langserver = prev.golangci-lint.override {
|
src = pkgs.fetchFromGitHub {
|
||||||
# buildGoModule = buildGo;
|
owner = "golangci";
|
||||||
# };
|
repo = "golangci-lint";
|
||||||
|
rev = "v${version}";
|
||||||
|
hash = "sha256-8LEtm1v0slKwdLBtS41OilKJLXytSxcI9fUlZbj5Gfw=";
|
||||||
|
};
|
||||||
|
|
||||||
|
vendorHash = "sha256-w8JfF6n1ylrU652HEv/cYdsOdDZz9J2uRQDqxObyhkY=";
|
||||||
|
|
||||||
|
subPackages = [ "cmd/golangci-lint" ];
|
||||||
|
|
||||||
|
nativeBuildInputs = [ pkgs.installShellFiles ];
|
||||||
|
|
||||||
|
ldflags = [
|
||||||
|
"-s"
|
||||||
|
"-w"
|
||||||
|
"-X main.version=${version}"
|
||||||
|
"-X main.commit=v${version}"
|
||||||
|
"-X main.date=1970-01-01T00:00:00Z"
|
||||||
|
];
|
||||||
|
|
||||||
|
postInstall = ''
|
||||||
|
for shell in bash zsh fish; do
|
||||||
|
HOME=$TMPDIR $out/bin/golangci-lint completion $shell > golangci-lint.$shell
|
||||||
|
installShellCompletion golangci-lint.$shell
|
||||||
|
done
|
||||||
|
'';
|
||||||
|
|
||||||
|
meta = {
|
||||||
|
description = "Fast linters runner for Go";
|
||||||
|
homepage = "https://golangci-lint.run/";
|
||||||
|
changelog = "https://github.com/golangci/golangci-lint/blob/v${version}/CHANGELOG.md";
|
||||||
|
mainProgram = "golangci-lint";
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
# The package uses buildGo125Module, not the convention.
|
# The package uses buildGo125Module, not the convention.
|
||||||
# goreleaser = prev.goreleaser.override {
|
# goreleaser = prev.goreleaser.override {
|
||||||
@@ -132,7 +164,7 @@
|
|||||||
overlays = [ self.overlays.default ];
|
overlays = [ self.overlays.default ];
|
||||||
inherit system;
|
inherit system;
|
||||||
};
|
};
|
||||||
buildDeps = with pkgs; [ git go_1_25 gnumake ];
|
buildDeps = with pkgs; [ git go_1_26 gnumake ];
|
||||||
devDeps = with pkgs;
|
devDeps = with pkgs;
|
||||||
buildDeps
|
buildDeps
|
||||||
++ [
|
++ [
|
||||||
|
|||||||
2
go.mod
2
go.mod
@@ -1,6 +1,6 @@
|
|||||||
module github.com/juanfont/headscale
|
module github.com/juanfont/headscale
|
||||||
|
|
||||||
go 1.25.5
|
go 1.26.0
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/arl/statsviz v0.8.0
|
github.com/arl/statsviz v0.8.0
|
||||||
|
|||||||
137
hscontrol/app.go
137
hscontrol/app.go
@@ -115,13 +115,14 @@ var (
|
|||||||
|
|
||||||
func NewHeadscale(cfg *types.Config) (*Headscale, error) {
|
func NewHeadscale(cfg *types.Config) (*Headscale, error) {
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
if profilingEnabled {
|
if profilingEnabled {
|
||||||
runtime.SetBlockProfileRate(1)
|
runtime.SetBlockProfileRate(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
noisePrivateKey, err := readOrCreatePrivateKey(cfg.NoisePrivateKeyPath)
|
noisePrivateKey, err := readOrCreatePrivateKey(cfg.NoisePrivateKeyPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to read or create Noise protocol private key: %w", err)
|
return nil, fmt.Errorf("reading or creating Noise protocol private key: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
s, err := state.NewState(cfg)
|
s, err := state.NewState(cfg)
|
||||||
@@ -140,27 +141,30 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
|
|||||||
ephemeralGC := db.NewEphemeralGarbageCollector(func(ni types.NodeID) {
|
ephemeralGC := db.NewEphemeralGarbageCollector(func(ni types.NodeID) {
|
||||||
node, ok := app.state.GetNodeByID(ni)
|
node, ok := app.state.GetNodeByID(ni)
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Error().Uint64("node.id", ni.Uint64()).Msg("Ephemeral node deletion failed")
|
log.Error().Uint64("node.id", ni.Uint64()).Msg("ephemeral node deletion failed")
|
||||||
log.Debug().Caller().Uint64("node.id", ni.Uint64()).Msg("Ephemeral node deletion failed because node not found in NodeStore")
|
log.Debug().Caller().Uint64("node.id", ni.Uint64()).Msg("ephemeral node deletion failed because node not found in NodeStore")
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
policyChanged, err := app.state.DeleteNode(node)
|
policyChanged, err := app.state.DeleteNode(node)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Uint64("node.id", ni.Uint64()).Str("node.name", node.Hostname()).Msg("Ephemeral node deletion failed")
|
log.Error().Err(err).EmbedObject(node).Msg("ephemeral node deletion failed")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
app.Change(policyChanged)
|
app.Change(policyChanged)
|
||||||
log.Debug().Caller().Uint64("node.id", ni.Uint64()).Str("node.name", node.Hostname()).Msg("Ephemeral node deleted because garbage collection timeout reached")
|
log.Debug().Caller().EmbedObject(node).Msg("ephemeral node deleted because garbage collection timeout reached")
|
||||||
})
|
})
|
||||||
app.ephemeralGC = ephemeralGC
|
app.ephemeralGC = ephemeralGC
|
||||||
|
|
||||||
var authProvider AuthProvider
|
var authProvider AuthProvider
|
||||||
|
|
||||||
authProvider = NewAuthProviderWeb(cfg.ServerURL)
|
authProvider = NewAuthProviderWeb(cfg.ServerURL)
|
||||||
if cfg.OIDC.Issuer != "" {
|
if cfg.OIDC.Issuer != "" {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
oidcProvider, err := NewAuthProviderOIDC(
|
oidcProvider, err := NewAuthProviderOIDC(
|
||||||
ctx,
|
ctx,
|
||||||
&app,
|
&app,
|
||||||
@@ -177,17 +181,18 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
|
|||||||
authProvider = oidcProvider
|
authProvider = oidcProvider
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
app.authProvider = authProvider
|
app.authProvider = authProvider
|
||||||
|
|
||||||
if app.cfg.TailcfgDNSConfig != nil && app.cfg.TailcfgDNSConfig.Proxied { // if MagicDNS
|
if app.cfg.TailcfgDNSConfig != nil && app.cfg.TailcfgDNSConfig.Proxied { // if MagicDNS
|
||||||
// TODO(kradalby): revisit why this takes a list.
|
// TODO(kradalby): revisit why this takes a list.
|
||||||
|
|
||||||
var magicDNSDomains []dnsname.FQDN
|
var magicDNSDomains []dnsname.FQDN
|
||||||
if cfg.PrefixV4 != nil {
|
if cfg.PrefixV4 != nil {
|
||||||
magicDNSDomains = append(
|
magicDNSDomains = append(
|
||||||
magicDNSDomains,
|
magicDNSDomains,
|
||||||
util.GenerateIPv4DNSRootDomain(*cfg.PrefixV4)...)
|
util.GenerateIPv4DNSRootDomain(*cfg.PrefixV4)...)
|
||||||
}
|
}
|
||||||
|
|
||||||
if cfg.PrefixV6 != nil {
|
if cfg.PrefixV6 != nil {
|
||||||
magicDNSDomains = append(
|
magicDNSDomains = append(
|
||||||
magicDNSDomains,
|
magicDNSDomains,
|
||||||
@@ -198,6 +203,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
|
|||||||
if app.cfg.TailcfgDNSConfig.Routes == nil {
|
if app.cfg.TailcfgDNSConfig.Routes == nil {
|
||||||
app.cfg.TailcfgDNSConfig.Routes = make(map[string][]*dnstype.Resolver)
|
app.cfg.TailcfgDNSConfig.Routes = make(map[string][]*dnstype.Resolver)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, d := range magicDNSDomains {
|
for _, d := range magicDNSDomains {
|
||||||
app.cfg.TailcfgDNSConfig.Routes[d.WithoutTrailingDot()] = nil
|
app.cfg.TailcfgDNSConfig.Routes[d.WithoutTrailingDot()] = nil
|
||||||
}
|
}
|
||||||
@@ -206,7 +212,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
|
|||||||
if cfg.DERP.ServerEnabled {
|
if cfg.DERP.ServerEnabled {
|
||||||
derpServerKey, err := readOrCreatePrivateKey(cfg.DERP.ServerPrivateKeyPath)
|
derpServerKey, err := readOrCreatePrivateKey(cfg.DERP.ServerPrivateKeyPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to read or create DERP server private key: %w", err)
|
return nil, fmt.Errorf("reading or creating DERP server private key: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if derpServerKey.Equal(*noisePrivateKey) {
|
if derpServerKey.Equal(*noisePrivateKey) {
|
||||||
@@ -232,6 +238,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
app.DERPServer = embeddedDERPServer
|
app.DERPServer = embeddedDERPServer
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -251,9 +258,11 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
|
|||||||
lastExpiryCheck := time.Unix(0, 0)
|
lastExpiryCheck := time.Unix(0, 0)
|
||||||
|
|
||||||
derpTickerChan := make(<-chan time.Time)
|
derpTickerChan := make(<-chan time.Time)
|
||||||
|
|
||||||
if h.cfg.DERP.AutoUpdate && h.cfg.DERP.UpdateFrequency != 0 {
|
if h.cfg.DERP.AutoUpdate && h.cfg.DERP.UpdateFrequency != 0 {
|
||||||
derpTicker := time.NewTicker(h.cfg.DERP.UpdateFrequency)
|
derpTicker := time.NewTicker(h.cfg.DERP.UpdateFrequency)
|
||||||
defer derpTicker.Stop()
|
defer derpTicker.Stop()
|
||||||
|
|
||||||
derpTickerChan = derpTicker.C
|
derpTickerChan = derpTicker.C
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -271,8 +280,10 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
|
|||||||
return
|
return
|
||||||
|
|
||||||
case <-expireTicker.C:
|
case <-expireTicker.C:
|
||||||
var expiredNodeChanges []change.Change
|
var (
|
||||||
var changed bool
|
expiredNodeChanges []change.Change
|
||||||
|
changed bool
|
||||||
|
)
|
||||||
|
|
||||||
lastExpiryCheck, expiredNodeChanges, changed = h.state.ExpireExpiredNodes(lastExpiryCheck)
|
lastExpiryCheck, expiredNodeChanges, changed = h.state.ExpireExpiredNodes(lastExpiryCheck)
|
||||||
|
|
||||||
@@ -286,12 +297,14 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
case <-derpTickerChan:
|
case <-derpTickerChan:
|
||||||
log.Info().Msg("Fetching DERPMap updates")
|
log.Info().Msg("fetching DERPMap updates")
|
||||||
derpMap, err := backoff.Retry(ctx, func() (*tailcfg.DERPMap, error) {
|
|
||||||
|
derpMap, err := backoff.Retry(ctx, func() (*tailcfg.DERPMap, error) { //nolint:contextcheck
|
||||||
derpMap, err := derp.GetDERPMap(h.cfg.DERP)
|
derpMap, err := derp.GetDERPMap(h.cfg.DERP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if h.cfg.DERP.ServerEnabled && h.cfg.DERP.AutomaticallyAddEmbeddedDerpRegion {
|
if h.cfg.DERP.ServerEnabled && h.cfg.DERP.AutomaticallyAddEmbeddedDerpRegion {
|
||||||
region, _ := h.DERPServer.GenerateRegion()
|
region, _ := h.DERPServer.GenerateRegion()
|
||||||
derpMap.Regions[region.RegionID] = ®ion
|
derpMap.Regions[region.RegionID] = ®ion
|
||||||
@@ -303,6 +316,7 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
|
|||||||
log.Error().Err(err).Msg("failed to build new DERPMap, retrying later")
|
log.Error().Err(err).Msg("failed to build new DERPMap, retrying later")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
h.state.SetDERPMap(derpMap)
|
h.state.SetDERPMap(derpMap)
|
||||||
|
|
||||||
h.Change(change.DERPMap())
|
h.Change(change.DERPMap())
|
||||||
@@ -311,6 +325,7 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
|
|||||||
if !ok {
|
if !ok {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
h.cfg.TailcfgDNSConfig.ExtraRecords = records
|
h.cfg.TailcfgDNSConfig.ExtraRecords = records
|
||||||
|
|
||||||
h.Change(change.ExtraRecords())
|
h.Change(change.ExtraRecords())
|
||||||
@@ -339,7 +354,7 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
|
|||||||
if !ok {
|
if !ok {
|
||||||
return ctx, status.Errorf(
|
return ctx, status.Errorf(
|
||||||
codes.InvalidArgument,
|
codes.InvalidArgument,
|
||||||
"Retrieving metadata is failed",
|
"retrieving metadata",
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -347,7 +362,7 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
|
|||||||
if !ok {
|
if !ok {
|
||||||
return ctx, status.Errorf(
|
return ctx, status.Errorf(
|
||||||
codes.Unauthenticated,
|
codes.Unauthenticated,
|
||||||
"Authorization token is not supplied",
|
"authorization token not supplied",
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -362,7 +377,7 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
|
|||||||
|
|
||||||
valid, err := h.state.ValidateAPIKey(strings.TrimPrefix(token, AuthPrefix))
|
valid, err := h.state.ValidateAPIKey(strings.TrimPrefix(token, AuthPrefix))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ctx, status.Error(codes.Internal, "failed to validate token")
|
return ctx, status.Error(codes.Internal, "validating token")
|
||||||
}
|
}
|
||||||
|
|
||||||
if !valid {
|
if !valid {
|
||||||
@@ -390,7 +405,8 @@ func (h *Headscale) httpAuthenticationMiddleware(next http.Handler) http.Handler
|
|||||||
|
|
||||||
writeUnauthorized := func(statusCode int) {
|
writeUnauthorized := func(statusCode int) {
|
||||||
writer.WriteHeader(statusCode)
|
writer.WriteHeader(statusCode)
|
||||||
if _, err := writer.Write([]byte("Unauthorized")); err != nil {
|
|
||||||
|
if _, err := writer.Write([]byte("Unauthorized")); err != nil { //nolint:noinlineerr
|
||||||
log.Error().Err(err).Msg("writing HTTP response failed")
|
log.Error().Err(err).Msg("writing HTTP response failed")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -401,6 +417,7 @@ func (h *Headscale) httpAuthenticationMiddleware(next http.Handler) http.Handler
|
|||||||
Str("client_address", req.RemoteAddr).
|
Str("client_address", req.RemoteAddr).
|
||||||
Msg(`missing "Bearer " prefix in "Authorization" header`)
|
Msg(`missing "Bearer " prefix in "Authorization" header`)
|
||||||
writeUnauthorized(http.StatusUnauthorized)
|
writeUnauthorized(http.StatusUnauthorized)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -412,6 +429,7 @@ func (h *Headscale) httpAuthenticationMiddleware(next http.Handler) http.Handler
|
|||||||
Str("client_address", req.RemoteAddr).
|
Str("client_address", req.RemoteAddr).
|
||||||
Msg("failed to validate token")
|
Msg("failed to validate token")
|
||||||
writeUnauthorized(http.StatusUnauthorized)
|
writeUnauthorized(http.StatusUnauthorized)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -420,6 +438,7 @@ func (h *Headscale) httpAuthenticationMiddleware(next http.Handler) http.Handler
|
|||||||
Str("client_address", req.RemoteAddr).
|
Str("client_address", req.RemoteAddr).
|
||||||
Msg("invalid token")
|
Msg("invalid token")
|
||||||
writeUnauthorized(http.StatusUnauthorized)
|
writeUnauthorized(http.StatusUnauthorized)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -431,7 +450,7 @@ func (h *Headscale) httpAuthenticationMiddleware(next http.Handler) http.Handler
|
|||||||
// and will remove it if it is not.
|
// and will remove it if it is not.
|
||||||
func (h *Headscale) ensureUnixSocketIsAbsent() error {
|
func (h *Headscale) ensureUnixSocketIsAbsent() error {
|
||||||
// File does not exist, all fine
|
// File does not exist, all fine
|
||||||
if _, err := os.Stat(h.cfg.UnixSocket); errors.Is(err, os.ErrNotExist) {
|
if _, err := os.Stat(h.cfg.UnixSocket); errors.Is(err, os.ErrNotExist) { //nolint:noinlineerr
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -455,6 +474,7 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {
|
|||||||
if provider, ok := h.authProvider.(*AuthProviderOIDC); ok {
|
if provider, ok := h.authProvider.(*AuthProviderOIDC); ok {
|
||||||
router.HandleFunc("/oidc/callback", provider.OIDCCallbackHandler).Methods(http.MethodGet)
|
router.HandleFunc("/oidc/callback", provider.OIDCCallbackHandler).Methods(http.MethodGet)
|
||||||
}
|
}
|
||||||
|
|
||||||
router.HandleFunc("/apple", h.AppleConfigMessage).Methods(http.MethodGet)
|
router.HandleFunc("/apple", h.AppleConfigMessage).Methods(http.MethodGet)
|
||||||
router.HandleFunc("/apple/{platform}", h.ApplePlatformConfig).
|
router.HandleFunc("/apple/{platform}", h.ApplePlatformConfig).
|
||||||
Methods(http.MethodGet)
|
Methods(http.MethodGet)
|
||||||
@@ -484,8 +504,11 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Serve launches the HTTP and gRPC server service Headscale and the API.
|
// Serve launches the HTTP and gRPC server service Headscale and the API.
|
||||||
|
//
|
||||||
|
//nolint:gocyclo // complex server startup function
|
||||||
func (h *Headscale) Serve() error {
|
func (h *Headscale) Serve() error {
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
capver.CanOldCodeBeCleanedUp()
|
capver.CanOldCodeBeCleanedUp()
|
||||||
|
|
||||||
if profilingEnabled {
|
if profilingEnabled {
|
||||||
@@ -506,12 +529,13 @@ func (h *Headscale) Serve() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
versionInfo := types.GetVersionInfo()
|
versionInfo := types.GetVersionInfo()
|
||||||
log.Info().Str("version", versionInfo.Version).Str("commit", versionInfo.Commit).Msg("Starting Headscale")
|
log.Info().Str("version", versionInfo.Version).Str("commit", versionInfo.Commit).Msg("starting headscale")
|
||||||
log.Info().
|
log.Info().
|
||||||
Str("minimum_version", capver.TailscaleVersion(capver.MinSupportedCapabilityVersion)).
|
Str("minimum_version", capver.TailscaleVersion(capver.MinSupportedCapabilityVersion)).
|
||||||
Msg("Clients with a lower minimum version will be rejected")
|
Msg("Clients with a lower minimum version will be rejected")
|
||||||
|
|
||||||
h.mapBatcher = mapper.NewBatcherAndMapper(h.cfg, h.state)
|
h.mapBatcher = mapper.NewBatcherAndMapper(h.cfg, h.state)
|
||||||
|
|
||||||
h.mapBatcher.Start()
|
h.mapBatcher.Start()
|
||||||
defer h.mapBatcher.Close()
|
defer h.mapBatcher.Close()
|
||||||
|
|
||||||
@@ -526,7 +550,7 @@ func (h *Headscale) Serve() error {
|
|||||||
|
|
||||||
derpMap, err := derp.GetDERPMap(h.cfg.DERP)
|
derpMap, err := derp.GetDERPMap(h.cfg.DERP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get DERPMap: %w", err)
|
return fmt.Errorf("getting DERPMap: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if h.cfg.DERP.ServerEnabled && h.cfg.DERP.AutomaticallyAddEmbeddedDerpRegion {
|
if h.cfg.DERP.ServerEnabled && h.cfg.DERP.AutomaticallyAddEmbeddedDerpRegion {
|
||||||
@@ -545,6 +569,7 @@ func (h *Headscale) Serve() error {
|
|||||||
// around between restarts, they will reconnect and the GC will
|
// around between restarts, they will reconnect and the GC will
|
||||||
// be cancelled.
|
// be cancelled.
|
||||||
go h.ephemeralGC.Start()
|
go h.ephemeralGC.Start()
|
||||||
|
|
||||||
ephmNodes := h.state.ListEphemeralNodes()
|
ephmNodes := h.state.ListEphemeralNodes()
|
||||||
for _, node := range ephmNodes.All() {
|
for _, node := range ephmNodes.All() {
|
||||||
h.ephemeralGC.Schedule(node.ID(), h.cfg.EphemeralNodeInactivityTimeout)
|
h.ephemeralGC.Schedule(node.ID(), h.cfg.EphemeralNodeInactivityTimeout)
|
||||||
@@ -555,7 +580,9 @@ func (h *Headscale) Serve() error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("setting up extrarecord manager: %w", err)
|
return fmt.Errorf("setting up extrarecord manager: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
h.cfg.TailcfgDNSConfig.ExtraRecords = h.extraRecordMan.Records()
|
h.cfg.TailcfgDNSConfig.ExtraRecords = h.extraRecordMan.Records()
|
||||||
|
|
||||||
go h.extraRecordMan.Run()
|
go h.extraRecordMan.Run()
|
||||||
defer h.extraRecordMan.Close()
|
defer h.extraRecordMan.Close()
|
||||||
}
|
}
|
||||||
@@ -564,6 +591,7 @@ func (h *Headscale) Serve() error {
|
|||||||
// records updates
|
// records updates
|
||||||
scheduleCtx, scheduleCancel := context.WithCancel(context.Background())
|
scheduleCtx, scheduleCancel := context.WithCancel(context.Background())
|
||||||
defer scheduleCancel()
|
defer scheduleCancel()
|
||||||
|
|
||||||
go h.scheduledTasks(scheduleCtx)
|
go h.scheduledTasks(scheduleCtx)
|
||||||
|
|
||||||
if zl.GlobalLevel() == zl.TraceLevel {
|
if zl.GlobalLevel() == zl.TraceLevel {
|
||||||
@@ -576,6 +604,7 @@ func (h *Headscale) Serve() error {
|
|||||||
errorGroup := new(errgroup.Group)
|
errorGroup := new(errgroup.Group)
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
@@ -586,29 +615,30 @@ func (h *Headscale) Serve() error {
|
|||||||
|
|
||||||
err = h.ensureUnixSocketIsAbsent()
|
err = h.ensureUnixSocketIsAbsent()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unable to remove old socket file: %w", err)
|
return fmt.Errorf("removing old socket file: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
socketDir := filepath.Dir(h.cfg.UnixSocket)
|
socketDir := filepath.Dir(h.cfg.UnixSocket)
|
||||||
|
|
||||||
err = util.EnsureDir(socketDir)
|
err = util.EnsureDir(socketDir)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("setting up unix socket: %w", err)
|
return fmt.Errorf("setting up unix socket: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
socketListener, err := net.Listen("unix", h.cfg.UnixSocket)
|
socketListener, err := new(net.ListenConfig).Listen(context.Background(), "unix", h.cfg.UnixSocket)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to set up gRPC socket: %w", err)
|
return fmt.Errorf("setting up gRPC socket: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Change socket permissions
|
// Change socket permissions
|
||||||
if err := os.Chmod(h.cfg.UnixSocket, h.cfg.UnixSocketPermission); err != nil {
|
if err := os.Chmod(h.cfg.UnixSocket, h.cfg.UnixSocketPermission); err != nil { //nolint:noinlineerr
|
||||||
return fmt.Errorf("failed change permission of gRPC socket: %w", err)
|
return fmt.Errorf("changing gRPC socket permission: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
grpcGatewayMux := grpcRuntime.NewServeMux()
|
grpcGatewayMux := grpcRuntime.NewServeMux()
|
||||||
|
|
||||||
// Make the grpc-gateway connect to grpc over socket
|
// Make the grpc-gateway connect to grpc over socket
|
||||||
grpcGatewayConn, err := grpc.Dial(
|
grpcGatewayConn, err := grpc.Dial( //nolint:staticcheck // SA1019: deprecated but supported in 1.x
|
||||||
h.cfg.UnixSocket,
|
h.cfg.UnixSocket,
|
||||||
[]grpc.DialOption{
|
[]grpc.DialOption{
|
||||||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||||
@@ -659,10 +689,13 @@ func (h *Headscale) Serve() error {
|
|||||||
// https://github.com/soheilhy/cmux/issues/68
|
// https://github.com/soheilhy/cmux/issues/68
|
||||||
// https://github.com/soheilhy/cmux/issues/91
|
// https://github.com/soheilhy/cmux/issues/91
|
||||||
|
|
||||||
var grpcServer *grpc.Server
|
var (
|
||||||
var grpcListener net.Listener
|
grpcServer *grpc.Server
|
||||||
|
grpcListener net.Listener
|
||||||
|
)
|
||||||
|
|
||||||
if tlsConfig != nil || h.cfg.GRPCAllowInsecure {
|
if tlsConfig != nil || h.cfg.GRPCAllowInsecure {
|
||||||
log.Info().Msgf("Enabling remote gRPC at %s", h.cfg.GRPCAddr)
|
log.Info().Msgf("enabling remote gRPC at %s", h.cfg.GRPCAddr)
|
||||||
|
|
||||||
grpcOptions := []grpc.ServerOption{
|
grpcOptions := []grpc.ServerOption{
|
||||||
grpc.ChainUnaryInterceptor(
|
grpc.ChainUnaryInterceptor(
|
||||||
@@ -685,9 +718,9 @@ func (h *Headscale) Serve() error {
|
|||||||
v1.RegisterHeadscaleServiceServer(grpcServer, newHeadscaleV1APIServer(h))
|
v1.RegisterHeadscaleServiceServer(grpcServer, newHeadscaleV1APIServer(h))
|
||||||
reflection.Register(grpcServer)
|
reflection.Register(grpcServer)
|
||||||
|
|
||||||
grpcListener, err = net.Listen("tcp", h.cfg.GRPCAddr)
|
grpcListener, err = new(net.ListenConfig).Listen(context.Background(), "tcp", h.cfg.GRPCAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to bind to TCP address: %w", err)
|
return fmt.Errorf("binding to TCP address: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
errorGroup.Go(func() error { return grpcServer.Serve(grpcListener) })
|
errorGroup.Go(func() error { return grpcServer.Serve(grpcListener) })
|
||||||
@@ -715,14 +748,16 @@ func (h *Headscale) Serve() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var httpListener net.Listener
|
var httpListener net.Listener
|
||||||
|
|
||||||
if tlsConfig != nil {
|
if tlsConfig != nil {
|
||||||
httpServer.TLSConfig = tlsConfig
|
httpServer.TLSConfig = tlsConfig
|
||||||
httpListener, err = tls.Listen("tcp", h.cfg.Addr, tlsConfig)
|
httpListener, err = tls.Listen("tcp", h.cfg.Addr, tlsConfig)
|
||||||
} else {
|
} else {
|
||||||
httpListener, err = net.Listen("tcp", h.cfg.Addr)
|
httpListener, err = new(net.ListenConfig).Listen(context.Background(), "tcp", h.cfg.Addr)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to bind to TCP address: %w", err)
|
return fmt.Errorf("binding to TCP address: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
errorGroup.Go(func() error { return httpServer.Serve(httpListener) })
|
errorGroup.Go(func() error { return httpServer.Serve(httpListener) })
|
||||||
@@ -738,7 +773,7 @@ func (h *Headscale) Serve() error {
|
|||||||
if h.cfg.MetricsAddr != "" {
|
if h.cfg.MetricsAddr != "" {
|
||||||
debugHTTPListener, err = (&net.ListenConfig{}).Listen(ctx, "tcp", h.cfg.MetricsAddr)
|
debugHTTPListener, err = (&net.ListenConfig{}).Listen(ctx, "tcp", h.cfg.MetricsAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to bind to TCP address: %w", err)
|
return fmt.Errorf("binding to TCP address: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
debugHTTPServer = h.debugHTTPServer()
|
debugHTTPServer = h.debugHTTPServer()
|
||||||
@@ -751,19 +786,24 @@ func (h *Headscale) Serve() error {
|
|||||||
log.Info().Msg("metrics server disabled (metrics_listen_addr is empty)")
|
log.Info().Msg("metrics server disabled (metrics_listen_addr is empty)")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
var tailsqlContext context.Context
|
var tailsqlContext context.Context
|
||||||
|
|
||||||
if tailsqlEnabled {
|
if tailsqlEnabled {
|
||||||
if h.cfg.Database.Type != types.DatabaseSqlite {
|
if h.cfg.Database.Type != types.DatabaseSqlite {
|
||||||
|
//nolint:gocritic // exitAfterDefer: Fatal exits during initialization before servers start
|
||||||
log.Fatal().
|
log.Fatal().
|
||||||
Str("type", h.cfg.Database.Type).
|
Str("type", h.cfg.Database.Type).
|
||||||
Msgf("tailsql only support %q", types.DatabaseSqlite)
|
Msgf("tailsql only support %q", types.DatabaseSqlite)
|
||||||
}
|
}
|
||||||
|
|
||||||
if tailsqlTSKey == "" {
|
if tailsqlTSKey == "" {
|
||||||
|
//nolint:gocritic // exitAfterDefer: Fatal exits during initialization before servers start
|
||||||
log.Fatal().Msg("tailsql requires TS_AUTHKEY to be set")
|
log.Fatal().Msg("tailsql requires TS_AUTHKEY to be set")
|
||||||
}
|
}
|
||||||
|
|
||||||
tailsqlContext = context.Background()
|
tailsqlContext = context.Background()
|
||||||
go runTailSQLService(ctx, util.TSLogfWrapper(), tailsqlStateDir, h.cfg.Database.Sqlite.Path)
|
|
||||||
|
go runTailSQLService(ctx, util.TSLogfWrapper(), tailsqlStateDir, h.cfg.Database.Sqlite.Path) //nolint:errcheck
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle common process-killing signals so we can gracefully shut down:
|
// Handle common process-killing signals so we can gracefully shut down:
|
||||||
@@ -774,6 +814,7 @@ func (h *Headscale) Serve() error {
|
|||||||
syscall.SIGTERM,
|
syscall.SIGTERM,
|
||||||
syscall.SIGQUIT,
|
syscall.SIGQUIT,
|
||||||
syscall.SIGHUP)
|
syscall.SIGHUP)
|
||||||
|
|
||||||
sigFunc := func(c chan os.Signal) {
|
sigFunc := func(c chan os.Signal) {
|
||||||
// Wait for a SIGINT or SIGKILL:
|
// Wait for a SIGINT or SIGKILL:
|
||||||
for {
|
for {
|
||||||
@@ -798,6 +839,7 @@ func (h *Headscale) Serve() error {
|
|||||||
|
|
||||||
default:
|
default:
|
||||||
info := func(msg string) { log.Info().Msg(msg) }
|
info := func(msg string) { log.Info().Msg(msg) }
|
||||||
|
|
||||||
log.Info().
|
log.Info().
|
||||||
Str("signal", sig.String()).
|
Str("signal", sig.String()).
|
||||||
Msg("Received signal to stop, shutting down gracefully")
|
Msg("Received signal to stop, shutting down gracefully")
|
||||||
@@ -854,6 +896,7 @@ func (h *Headscale) Serve() error {
|
|||||||
if debugHTTPListener != nil {
|
if debugHTTPListener != nil {
|
||||||
debugHTTPListener.Close()
|
debugHTTPListener.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
httpListener.Close()
|
httpListener.Close()
|
||||||
grpcGatewayConn.Close()
|
grpcGatewayConn.Close()
|
||||||
|
|
||||||
@@ -863,6 +906,7 @@ func (h *Headscale) Serve() error {
|
|||||||
|
|
||||||
// Close state connections
|
// Close state connections
|
||||||
info("closing state and database")
|
info("closing state and database")
|
||||||
|
|
||||||
err = h.state.Close()
|
err = h.state.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Msg("failed to close state")
|
log.Error().Err(err).Msg("failed to close state")
|
||||||
@@ -875,6 +919,7 @@ func (h *Headscale) Serve() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
errorGroup.Go(func() error {
|
errorGroup.Go(func() error {
|
||||||
sigFunc(sigc)
|
sigFunc(sigc)
|
||||||
|
|
||||||
@@ -886,6 +931,7 @@ func (h *Headscale) Serve() error {
|
|||||||
|
|
||||||
func (h *Headscale) getTLSSettings() (*tls.Config, error) {
|
func (h *Headscale) getTLSSettings() (*tls.Config, error) {
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
if h.cfg.TLS.LetsEncrypt.Hostname != "" {
|
if h.cfg.TLS.LetsEncrypt.Hostname != "" {
|
||||||
if !strings.HasPrefix(h.cfg.ServerURL, "https://") {
|
if !strings.HasPrefix(h.cfg.ServerURL, "https://") {
|
||||||
log.Warn().
|
log.Warn().
|
||||||
@@ -918,7 +964,6 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) {
|
|||||||
// Configuration via autocert with HTTP-01. This requires listening on
|
// Configuration via autocert with HTTP-01. This requires listening on
|
||||||
// port 80 for the certificate validation in addition to the headscale
|
// port 80 for the certificate validation in addition to the headscale
|
||||||
// service, which can be configured to run on any other port.
|
// service, which can be configured to run on any other port.
|
||||||
|
|
||||||
server := &http.Server{
|
server := &http.Server{
|
||||||
Addr: h.cfg.TLS.LetsEncrypt.Listen,
|
Addr: h.cfg.TLS.LetsEncrypt.Listen,
|
||||||
Handler: certManager.HTTPHandler(http.HandlerFunc(h.redirect)),
|
Handler: certManager.HTTPHandler(http.HandlerFunc(h.redirect)),
|
||||||
@@ -940,13 +985,13 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) {
|
|||||||
}
|
}
|
||||||
} else if h.cfg.TLS.CertPath == "" {
|
} else if h.cfg.TLS.CertPath == "" {
|
||||||
if !strings.HasPrefix(h.cfg.ServerURL, "http://") {
|
if !strings.HasPrefix(h.cfg.ServerURL, "http://") {
|
||||||
log.Warn().Msg("Listening without TLS but ServerURL does not start with http://")
|
log.Warn().Msg("listening without TLS but ServerURL does not start with http://")
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, err
|
return nil, err
|
||||||
} else {
|
} else {
|
||||||
if !strings.HasPrefix(h.cfg.ServerURL, "https://") {
|
if !strings.HasPrefix(h.cfg.ServerURL, "https://") {
|
||||||
log.Warn().Msg("Listening with TLS but ServerURL does not start with https://")
|
log.Warn().Msg("listening with TLS but ServerURL does not start with https://")
|
||||||
}
|
}
|
||||||
|
|
||||||
tlsConfig := &tls.Config{
|
tlsConfig := &tls.Config{
|
||||||
@@ -963,6 +1008,7 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) {
|
|||||||
|
|
||||||
func readOrCreatePrivateKey(path string) (*key.MachinePrivate, error) {
|
func readOrCreatePrivateKey(path string) (*key.MachinePrivate, error) {
|
||||||
dir := filepath.Dir(path)
|
dir := filepath.Dir(path)
|
||||||
|
|
||||||
err := util.EnsureDir(dir)
|
err := util.EnsureDir(dir)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("ensuring private key directory: %w", err)
|
return nil, fmt.Errorf("ensuring private key directory: %w", err)
|
||||||
@@ -970,21 +1016,22 @@ func readOrCreatePrivateKey(path string) (*key.MachinePrivate, error) {
|
|||||||
|
|
||||||
privateKey, err := os.ReadFile(path)
|
privateKey, err := os.ReadFile(path)
|
||||||
if errors.Is(err, os.ErrNotExist) {
|
if errors.Is(err, os.ErrNotExist) {
|
||||||
log.Info().Str("path", path).Msg("No private key file at path, creating...")
|
log.Info().Str("path", path).Msg("no private key file at path, creating...")
|
||||||
|
|
||||||
machineKey := key.NewMachine()
|
machineKey := key.NewMachine()
|
||||||
|
|
||||||
machineKeyStr, err := machineKey.MarshalText()
|
machineKeyStr, err := machineKey.MarshalText()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf(
|
return nil, fmt.Errorf(
|
||||||
"failed to convert private key to string for saving: %w",
|
"converting private key to string for saving: %w",
|
||||||
err,
|
err,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = os.WriteFile(path, machineKeyStr, privateKeyFileMode)
|
err = os.WriteFile(path, machineKeyStr, privateKeyFileMode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf(
|
return nil, fmt.Errorf(
|
||||||
"failed to save private key to disk at path %q: %w",
|
"saving private key to disk at path %q: %w",
|
||||||
path,
|
path,
|
||||||
err,
|
err,
|
||||||
)
|
)
|
||||||
@@ -992,14 +1039,14 @@ func readOrCreatePrivateKey(path string) (*key.MachinePrivate, error) {
|
|||||||
|
|
||||||
return &machineKey, nil
|
return &machineKey, nil
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
return nil, fmt.Errorf("failed to read private key file: %w", err)
|
return nil, fmt.Errorf("reading private key file: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
trimmedPrivateKey := strings.TrimSpace(string(privateKey))
|
trimmedPrivateKey := strings.TrimSpace(string(privateKey))
|
||||||
|
|
||||||
var machineKey key.MachinePrivate
|
var machineKey key.MachinePrivate
|
||||||
if err = machineKey.UnmarshalText([]byte(trimmedPrivateKey)); err != nil {
|
if err = machineKey.UnmarshalText([]byte(trimmedPrivateKey)); err != nil { //nolint:noinlineerr
|
||||||
return nil, fmt.Errorf("failed to parse private key: %w", err)
|
return nil, fmt.Errorf("parsing private key: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &machineKey, nil
|
return &machineKey, nil
|
||||||
@@ -1023,7 +1070,7 @@ type acmeLogger struct {
|
|||||||
func (l *acmeLogger) RoundTrip(req *http.Request) (*http.Response, error) {
|
func (l *acmeLogger) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
resp, err := l.rt.RoundTrip(req)
|
resp, err := l.rt.RoundTrip(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Str("url", req.URL.String()).Msg("ACME request failed")
|
log.Error().Err(err).Str("url", req.URL.String()).Msg("acme request failed")
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1031,7 +1078,7 @@ func (l *acmeLogger) RoundTrip(req *http.Request) (*http.Response, error) {
|
|||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
body, _ := io.ReadAll(resp.Body)
|
body, _ := io.ReadAll(resp.Body)
|
||||||
log.Error().Int("status_code", resp.StatusCode).Str("url", req.URL.String()).Bytes("body", body).Msg("ACME request returned error")
|
log.Error().Int("status_code", resp.StatusCode).Str("url", req.URL.String()).Bytes("body", body).Msg("acme request returned error")
|
||||||
}
|
}
|
||||||
|
|
||||||
return resp, nil
|
return resp, nil
|
||||||
|
|||||||
@@ -16,12 +16,11 @@ import (
|
|||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
"tailscale.com/types/key"
|
"tailscale.com/types/key"
|
||||||
"tailscale.com/types/ptr"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type AuthProvider interface {
|
type AuthProvider interface {
|
||||||
RegisterHandler(http.ResponseWriter, *http.Request)
|
RegisterHandler(w http.ResponseWriter, r *http.Request)
|
||||||
AuthURL(types.RegistrationID) string
|
AuthURL(regID types.RegistrationID) string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) handleRegister(
|
func (h *Headscale) handleRegister(
|
||||||
@@ -42,8 +41,7 @@ func (h *Headscale) handleRegister(
|
|||||||
// This is a logout attempt (expiry in the past)
|
// This is a logout attempt (expiry in the past)
|
||||||
if node, ok := h.state.GetNodeByNodeKey(req.NodeKey); ok {
|
if node, ok := h.state.GetNodeByNodeKey(req.NodeKey); ok {
|
||||||
log.Debug().
|
log.Debug().
|
||||||
Uint64("node.id", node.ID().Uint64()).
|
EmbedObject(node).
|
||||||
Str("node.name", node.Hostname()).
|
|
||||||
Bool("is_ephemeral", node.IsEphemeral()).
|
Bool("is_ephemeral", node.IsEphemeral()).
|
||||||
Bool("has_authkey", node.AuthKey().Valid()).
|
Bool("has_authkey", node.AuthKey().Valid()).
|
||||||
Msg("Found existing node for logout, calling handleLogout")
|
Msg("Found existing node for logout, calling handleLogout")
|
||||||
@@ -52,6 +50,7 @@ func (h *Headscale) handleRegister(
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("handling logout: %w", err)
|
return nil, fmt.Errorf("handling logout: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp != nil {
|
if resp != nil {
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
@@ -113,8 +112,7 @@ func (h *Headscale) handleRegister(
|
|||||||
resp, err := h.handleRegisterWithAuthKey(req, machineKey)
|
resp, err := h.handleRegisterWithAuthKey(req, machineKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Preserve HTTPError types so they can be handled properly by the HTTP layer
|
// Preserve HTTPError types so they can be handled properly by the HTTP layer
|
||||||
var httpErr HTTPError
|
if httpErr, ok := errors.AsType[HTTPError](err); ok {
|
||||||
if errors.As(err, &httpErr) {
|
|
||||||
return nil, httpErr
|
return nil, httpErr
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -133,7 +131,7 @@ func (h *Headscale) handleRegister(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// handleLogout checks if the [tailcfg.RegisterRequest] is a
|
// handleLogout checks if the [tailcfg.RegisterRequest] is a
|
||||||
// logout attempt from a node. If the node is not attempting to
|
// logout attempt from a node. If the node is not attempting to.
|
||||||
func (h *Headscale) handleLogout(
|
func (h *Headscale) handleLogout(
|
||||||
node types.NodeView,
|
node types.NodeView,
|
||||||
req tailcfg.RegisterRequest,
|
req tailcfg.RegisterRequest,
|
||||||
@@ -155,11 +153,12 @@ func (h *Headscale) handleLogout(
|
|||||||
// force the client to re-authenticate.
|
// force the client to re-authenticate.
|
||||||
// TODO(kradalby): I wonder if this is a path we ever hit?
|
// TODO(kradalby): I wonder if this is a path we ever hit?
|
||||||
if node.IsExpired() {
|
if node.IsExpired() {
|
||||||
log.Trace().Str("node.name", node.Hostname()).
|
log.Trace().
|
||||||
Uint64("node.id", node.ID().Uint64()).
|
EmbedObject(node).
|
||||||
Interface("reg.req", req).
|
Interface("reg.req", req).
|
||||||
Bool("unexpected", true).
|
Bool("unexpected", true).
|
||||||
Msg("Node key expired, forcing re-authentication")
|
Msg("Node key expired, forcing re-authentication")
|
||||||
|
|
||||||
return &tailcfg.RegisterResponse{
|
return &tailcfg.RegisterResponse{
|
||||||
NodeKeyExpired: true,
|
NodeKeyExpired: true,
|
||||||
MachineAuthorized: false,
|
MachineAuthorized: false,
|
||||||
@@ -182,8 +181,7 @@ func (h *Headscale) handleLogout(
|
|||||||
// Zero expiry is handled in handleRegister() before calling this function.
|
// Zero expiry is handled in handleRegister() before calling this function.
|
||||||
if req.Expiry.Before(time.Now()) {
|
if req.Expiry.Before(time.Now()) {
|
||||||
log.Debug().
|
log.Debug().
|
||||||
Uint64("node.id", node.ID().Uint64()).
|
EmbedObject(node).
|
||||||
Str("node.name", node.Hostname()).
|
|
||||||
Bool("is_ephemeral", node.IsEphemeral()).
|
Bool("is_ephemeral", node.IsEphemeral()).
|
||||||
Bool("has_authkey", node.AuthKey().Valid()).
|
Bool("has_authkey", node.AuthKey().Valid()).
|
||||||
Time("req.expiry", req.Expiry).
|
Time("req.expiry", req.Expiry).
|
||||||
@@ -191,8 +189,7 @@ func (h *Headscale) handleLogout(
|
|||||||
|
|
||||||
if node.IsEphemeral() {
|
if node.IsEphemeral() {
|
||||||
log.Info().
|
log.Info().
|
||||||
Uint64("node.id", node.ID().Uint64()).
|
EmbedObject(node).
|
||||||
Str("node.name", node.Hostname()).
|
|
||||||
Msg("Deleting ephemeral node during logout")
|
Msg("Deleting ephemeral node during logout")
|
||||||
|
|
||||||
c, err := h.state.DeleteNode(node)
|
c, err := h.state.DeleteNode(node)
|
||||||
@@ -209,8 +206,7 @@ func (h *Headscale) handleLogout(
|
|||||||
}
|
}
|
||||||
|
|
||||||
log.Debug().
|
log.Debug().
|
||||||
Uint64("node.id", node.ID().Uint64()).
|
EmbedObject(node).
|
||||||
Str("node.name", node.Hostname()).
|
|
||||||
Msg("Node is not ephemeral, setting expiry instead of deleting")
|
Msg("Node is not ephemeral, setting expiry instead of deleting")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -279,6 +275,7 @@ func (h *Headscale) waitForFollowup(
|
|||||||
// registration is expired in the cache, instruct the client to try a new registration
|
// registration is expired in the cache, instruct the client to try a new registration
|
||||||
return h.reqToNewRegisterResponse(req, machineKey)
|
return h.reqToNewRegisterResponse(req, machineKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nodeToRegisterResponse(node.View()), nil
|
return nodeToRegisterResponse(node.View()), nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -316,7 +313,7 @@ func (h *Headscale) reqToNewRegisterResponse(
|
|||||||
MachineKey: machineKey,
|
MachineKey: machineKey,
|
||||||
NodeKey: req.NodeKey,
|
NodeKey: req.NodeKey,
|
||||||
Hostinfo: hostinfo,
|
Hostinfo: hostinfo,
|
||||||
LastSeen: ptr.To(time.Now()),
|
LastSeen: new(time.Now()),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -324,7 +321,7 @@ func (h *Headscale) reqToNewRegisterResponse(
|
|||||||
nodeToRegister.Node.Expiry = &req.Expiry
|
nodeToRegister.Node.Expiry = &req.Expiry
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Info().Msgf("New followup node registration using key: %s", newRegID)
|
log.Info().Msgf("new followup node registration using key: %s", newRegID)
|
||||||
h.state.SetRegistrationCacheEntry(newRegID, nodeToRegister)
|
h.state.SetRegistrationCacheEntry(newRegID, nodeToRegister)
|
||||||
|
|
||||||
return &tailcfg.RegisterResponse{
|
return &tailcfg.RegisterResponse{
|
||||||
@@ -344,8 +341,8 @@ func (h *Headscale) handleRegisterWithAuthKey(
|
|||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
return nil, NewHTTPError(http.StatusUnauthorized, "invalid pre auth key", nil)
|
return nil, NewHTTPError(http.StatusUnauthorized, "invalid pre auth key", nil)
|
||||||
}
|
}
|
||||||
var perr types.PAKError
|
|
||||||
if errors.As(err, &perr) {
|
if perr, ok := errors.AsType[types.PAKError](err); ok {
|
||||||
return nil, NewHTTPError(http.StatusUnauthorized, perr.Error(), nil)
|
return nil, NewHTTPError(http.StatusUnauthorized, perr.Error(), nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -355,7 +352,7 @@ func (h *Headscale) handleRegisterWithAuthKey(
|
|||||||
// If node is not valid, it means an ephemeral node was deleted during logout
|
// If node is not valid, it means an ephemeral node was deleted during logout
|
||||||
if !node.Valid() {
|
if !node.Valid() {
|
||||||
h.Change(changed)
|
h.Change(changed)
|
||||||
return nil, nil
|
return nil, nil //nolint:nilnil // intentional: no node to return when ephemeral deleted
|
||||||
}
|
}
|
||||||
|
|
||||||
// This is a bit of a back and forth, but we have a bit of a chicken and egg
|
// This is a bit of a back and forth, but we have a bit of a chicken and egg
|
||||||
@@ -397,8 +394,7 @@ func (h *Headscale) handleRegisterWithAuthKey(
|
|||||||
Caller().
|
Caller().
|
||||||
Interface("reg.resp", resp).
|
Interface("reg.resp", resp).
|
||||||
Interface("reg.req", req).
|
Interface("reg.req", req).
|
||||||
Str("node.name", node.Hostname()).
|
EmbedObject(node).
|
||||||
Uint64("node.id", node.ID().Uint64()).
|
|
||||||
Msg("RegisterResponse")
|
Msg("RegisterResponse")
|
||||||
|
|
||||||
return resp, nil
|
return resp, nil
|
||||||
@@ -435,6 +431,7 @@ func (h *Headscale) handleRegisterInteractive(
|
|||||||
Str("generated.hostname", hostname).
|
Str("generated.hostname", hostname).
|
||||||
Msg("Received registration request with empty hostname, generated default")
|
Msg("Received registration request with empty hostname, generated default")
|
||||||
}
|
}
|
||||||
|
|
||||||
hostinfo.Hostname = hostname
|
hostinfo.Hostname = hostname
|
||||||
|
|
||||||
nodeToRegister := types.NewRegisterNode(
|
nodeToRegister := types.NewRegisterNode(
|
||||||
@@ -443,7 +440,7 @@ func (h *Headscale) handleRegisterInteractive(
|
|||||||
MachineKey: machineKey,
|
MachineKey: machineKey,
|
||||||
NodeKey: req.NodeKey,
|
NodeKey: req.NodeKey,
|
||||||
Hostinfo: hostinfo,
|
Hostinfo: hostinfo,
|
||||||
LastSeen: ptr.To(time.Now()),
|
LastSeen: new(time.Now()),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -456,7 +453,7 @@ func (h *Headscale) handleRegisterInteractive(
|
|||||||
nodeToRegister,
|
nodeToRegister,
|
||||||
)
|
)
|
||||||
|
|
||||||
log.Info().Msgf("Starting node registration using key: %s", registrationId)
|
log.Info().Msgf("starting node registration using key: %s", registrationId)
|
||||||
|
|
||||||
return &tailcfg.RegisterResponse{
|
return &tailcfg.RegisterResponse{
|
||||||
AuthURL: h.authProvider.AuthURL(registrationId),
|
AuthURL: h.authProvider.AuthURL(registrationId),
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -40,6 +40,7 @@ var tailscaleToCapVer = map[string]tailcfg.CapabilityVersion{
|
|||||||
"v1.88": 125,
|
"v1.88": 125,
|
||||||
"v1.90": 130,
|
"v1.90": 130,
|
||||||
"v1.92": 131,
|
"v1.92": 131,
|
||||||
|
"v1.94": 131,
|
||||||
}
|
}
|
||||||
|
|
||||||
var capVerToTailscaleVer = map[tailcfg.CapabilityVersion]string{
|
var capVerToTailscaleVer = map[tailcfg.CapabilityVersion]string{
|
||||||
|
|||||||
@@ -9,10 +9,9 @@ var tailscaleLatestMajorMinorTests = []struct {
|
|||||||
stripV bool
|
stripV bool
|
||||||
expected []string
|
expected []string
|
||||||
}{
|
}{
|
||||||
{3, false, []string{"v1.88", "v1.90", "v1.92"}},
|
{3, false, []string{"v1.90", "v1.92", "v1.94"}},
|
||||||
{2, true, []string{"1.90", "1.92"}},
|
{2, true, []string{"1.92", "1.94"}},
|
||||||
{10, true, []string{
|
{10, true, []string{
|
||||||
"1.74",
|
|
||||||
"1.76",
|
"1.76",
|
||||||
"1.78",
|
"1.78",
|
||||||
"1.80",
|
"1.80",
|
||||||
@@ -22,6 +21,7 @@ var tailscaleLatestMajorMinorTests = []struct {
|
|||||||
"1.88",
|
"1.88",
|
||||||
"1.90",
|
"1.90",
|
||||||
"1.92",
|
"1.92",
|
||||||
|
"1.94",
|
||||||
}},
|
}},
|
||||||
{0, false, nil},
|
{0, false, nil},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -77,8 +77,8 @@ func (hsdb *HSDatabase) CreateAPIKey(
|
|||||||
Expiration: expiration,
|
Expiration: expiration,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := hsdb.DB.Save(&key).Error; err != nil {
|
if err := hsdb.DB.Save(&key).Error; err != nil { //nolint:noinlineerr
|
||||||
return "", nil, fmt.Errorf("failed to save API key to database: %w", err)
|
return "", nil, fmt.Errorf("saving API key to database: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return keyStr, &key, nil
|
return keyStr, &key, nil
|
||||||
@@ -87,7 +87,9 @@ func (hsdb *HSDatabase) CreateAPIKey(
|
|||||||
// ListAPIKeys returns the list of ApiKeys for a user.
|
// ListAPIKeys returns the list of ApiKeys for a user.
|
||||||
func (hsdb *HSDatabase) ListAPIKeys() ([]types.APIKey, error) {
|
func (hsdb *HSDatabase) ListAPIKeys() ([]types.APIKey, error) {
|
||||||
keys := []types.APIKey{}
|
keys := []types.APIKey{}
|
||||||
if err := hsdb.DB.Find(&keys).Error; err != nil {
|
|
||||||
|
err := hsdb.DB.Find(&keys).Error
|
||||||
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -126,7 +128,8 @@ func (hsdb *HSDatabase) DestroyAPIKey(key types.APIKey) error {
|
|||||||
|
|
||||||
// ExpireAPIKey marks a ApiKey as expired.
|
// ExpireAPIKey marks a ApiKey as expired.
|
||||||
func (hsdb *HSDatabase) ExpireAPIKey(key *types.APIKey) error {
|
func (hsdb *HSDatabase) ExpireAPIKey(key *types.APIKey) error {
|
||||||
if err := hsdb.DB.Model(&key).Update("Expiration", time.Now()).Error; err != nil {
|
err := hsdb.DB.Model(&key).Update("Expiration", time.Now()).Error
|
||||||
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -24,7 +24,6 @@ import (
|
|||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"gorm.io/gorm/logger"
|
"gorm.io/gorm/logger"
|
||||||
"gorm.io/gorm/schema"
|
"gorm.io/gorm/schema"
|
||||||
"tailscale.com/net/tsaddr"
|
|
||||||
"zgo.at/zcache/v2"
|
"zgo.at/zcache/v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -53,6 +52,8 @@ type HSDatabase struct {
|
|||||||
|
|
||||||
// NewHeadscaleDatabase creates a new database connection and runs migrations.
|
// NewHeadscaleDatabase creates a new database connection and runs migrations.
|
||||||
// It accepts the full configuration to allow migrations access to policy settings.
|
// It accepts the full configuration to allow migrations access to policy settings.
|
||||||
|
//
|
||||||
|
//nolint:gocyclo // complex database initialization with many migrations
|
||||||
func NewHeadscaleDatabase(
|
func NewHeadscaleDatabase(
|
||||||
cfg *types.Config,
|
cfg *types.Config,
|
||||||
regCache *zcache.Cache[types.RegistrationID, types.RegisterNode],
|
regCache *zcache.Cache[types.RegistrationID, types.RegisterNode],
|
||||||
@@ -62,6 +63,11 @@ func NewHeadscaleDatabase(
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = checkVersionUpgradePath(dbConn)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("version check: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
migrations := gormigrate.New(
|
migrations := gormigrate.New(
|
||||||
dbConn,
|
dbConn,
|
||||||
gormigrate.DefaultOptions,
|
gormigrate.DefaultOptions,
|
||||||
@@ -76,7 +82,7 @@ func NewHeadscaleDatabase(
|
|||||||
ID: "202501221827",
|
ID: "202501221827",
|
||||||
Migrate: func(tx *gorm.DB) error {
|
Migrate: func(tx *gorm.DB) error {
|
||||||
// Remove any invalid routes associated with a node that does not exist.
|
// Remove any invalid routes associated with a node that does not exist.
|
||||||
if tx.Migrator().HasTable(&types.Route{}) && tx.Migrator().HasTable(&types.Node{}) {
|
if tx.Migrator().HasTable(&types.Route{}) && tx.Migrator().HasTable(&types.Node{}) { //nolint:staticcheck // SA1019: Route kept for migrations
|
||||||
err := tx.Exec("delete from routes where node_id not in (select id from nodes)").Error
|
err := tx.Exec("delete from routes where node_id not in (select id from nodes)").Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -84,14 +90,14 @@ func NewHeadscaleDatabase(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Remove any invalid routes without a node_id.
|
// Remove any invalid routes without a node_id.
|
||||||
if tx.Migrator().HasTable(&types.Route{}) {
|
if tx.Migrator().HasTable(&types.Route{}) { //nolint:staticcheck // SA1019: Route kept for migrations
|
||||||
err := tx.Exec("delete from routes where node_id is null").Error
|
err := tx.Exec("delete from routes where node_id is null").Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err := tx.AutoMigrate(&types.Route{})
|
err := tx.AutoMigrate(&types.Route{}) //nolint:staticcheck // SA1019: Route kept for migrations
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("automigrating types.Route: %w", err)
|
return fmt.Errorf("automigrating types.Route: %w", err)
|
||||||
}
|
}
|
||||||
@@ -109,6 +115,7 @@ func NewHeadscaleDatabase(
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("automigrating types.PreAuthKey: %w", err)
|
return fmt.Errorf("automigrating types.PreAuthKey: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = tx.AutoMigrate(&types.Node{})
|
err = tx.AutoMigrate(&types.Node{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("automigrating types.Node: %w", err)
|
return fmt.Errorf("automigrating types.Node: %w", err)
|
||||||
@@ -155,7 +162,8 @@ AND auth_key_id NOT IN (
|
|||||||
|
|
||||||
nodeRoutes := map[uint64][]netip.Prefix{}
|
nodeRoutes := map[uint64][]netip.Prefix{}
|
||||||
|
|
||||||
var routes []types.Route
|
var routes []types.Route //nolint:staticcheck // SA1019: Route kept for migrations
|
||||||
|
|
||||||
err = tx.Find(&routes).Error
|
err = tx.Find(&routes).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("fetching routes: %w", err)
|
return fmt.Errorf("fetching routes: %w", err)
|
||||||
@@ -168,10 +176,10 @@ AND auth_key_id NOT IN (
|
|||||||
}
|
}
|
||||||
|
|
||||||
for nodeID, routes := range nodeRoutes {
|
for nodeID, routes := range nodeRoutes {
|
||||||
tsaddr.SortPrefixes(routes)
|
slices.SortFunc(routes, netip.Prefix.Compare)
|
||||||
routes = slices.Compact(routes)
|
routes = slices.Compact(routes)
|
||||||
|
|
||||||
data, err := json.Marshal(routes)
|
data, _ := json.Marshal(routes)
|
||||||
|
|
||||||
err = tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("approved_routes", data).Error
|
err = tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("approved_routes", data).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -180,7 +188,7 @@ AND auth_key_id NOT IN (
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Drop the old table.
|
// Drop the old table.
|
||||||
_ = tx.Migrator().DropTable(&types.Route{})
|
_ = tx.Migrator().DropTable(&types.Route{}) //nolint:staticcheck // SA1019: Route kept for migrations
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
@@ -245,21 +253,24 @@ AND auth_key_id NOT IN (
|
|||||||
Migrate: func(tx *gorm.DB) error {
|
Migrate: func(tx *gorm.DB) error {
|
||||||
// Only run on SQLite
|
// Only run on SQLite
|
||||||
if cfg.Database.Type != types.DatabaseSqlite {
|
if cfg.Database.Type != types.DatabaseSqlite {
|
||||||
log.Info().Msg("Skipping schema migration on non-SQLite database")
|
log.Info().Msg("skipping schema migration on non-SQLite database")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Info().Msg("Starting schema recreation with table renaming")
|
log.Info().Msg("starting schema recreation with table renaming")
|
||||||
|
|
||||||
// Rename existing tables to _old versions
|
// Rename existing tables to _old versions
|
||||||
tablesToRename := []string{"users", "pre_auth_keys", "api_keys", "nodes", "policies"}
|
tablesToRename := []string{"users", "pre_auth_keys", "api_keys", "nodes", "policies"}
|
||||||
|
|
||||||
// Check if routes table exists and drop it (should have been migrated already)
|
// Check if routes table exists and drop it (should have been migrated already)
|
||||||
var routesExists bool
|
var routesExists bool
|
||||||
|
|
||||||
err := tx.Raw("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='routes'").Row().Scan(&routesExists)
|
err := tx.Raw("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='routes'").Row().Scan(&routesExists)
|
||||||
if err == nil && routesExists {
|
if err == nil && routesExists {
|
||||||
log.Info().Msg("Dropping leftover routes table")
|
log.Info().Msg("dropping leftover routes table")
|
||||||
if err := tx.Exec("DROP TABLE routes").Error; err != nil {
|
|
||||||
|
err := tx.Exec("DROP TABLE routes").Error
|
||||||
|
if err != nil {
|
||||||
return fmt.Errorf("dropping routes table: %w", err)
|
return fmt.Errorf("dropping routes table: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -281,6 +292,7 @@ AND auth_key_id NOT IN (
|
|||||||
for _, table := range tablesToRename {
|
for _, table := range tablesToRename {
|
||||||
// Check if table exists before renaming
|
// Check if table exists before renaming
|
||||||
var exists bool
|
var exists bool
|
||||||
|
|
||||||
err := tx.Raw("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name=?", table).Row().Scan(&exists)
|
err := tx.Raw("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name=?", table).Row().Scan(&exists)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("checking if table %s exists: %w", table, err)
|
return fmt.Errorf("checking if table %s exists: %w", table, err)
|
||||||
@@ -291,7 +303,8 @@ AND auth_key_id NOT IN (
|
|||||||
_ = tx.Exec("DROP TABLE IF EXISTS " + table + "_old").Error
|
_ = tx.Exec("DROP TABLE IF EXISTS " + table + "_old").Error
|
||||||
|
|
||||||
// Rename current table to _old
|
// Rename current table to _old
|
||||||
if err := tx.Exec("ALTER TABLE " + table + " RENAME TO " + table + "_old").Error; err != nil {
|
err := tx.Exec("ALTER TABLE " + table + " RENAME TO " + table + "_old").Error
|
||||||
|
if err != nil {
|
||||||
return fmt.Errorf("renaming table %s to %s_old: %w", table, table, err)
|
return fmt.Errorf("renaming table %s to %s_old: %w", table, table, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -365,7 +378,8 @@ AND auth_key_id NOT IN (
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, createSQL := range tableCreationSQL {
|
for _, createSQL := range tableCreationSQL {
|
||||||
if err := tx.Exec(createSQL).Error; err != nil {
|
err := tx.Exec(createSQL).Error
|
||||||
|
if err != nil {
|
||||||
return fmt.Errorf("creating new table: %w", err)
|
return fmt.Errorf("creating new table: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -394,7 +408,8 @@ AND auth_key_id NOT IN (
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, copySQL := range dataCopySQL {
|
for _, copySQL := range dataCopySQL {
|
||||||
if err := tx.Exec(copySQL).Error; err != nil {
|
err := tx.Exec(copySQL).Error
|
||||||
|
if err != nil {
|
||||||
return fmt.Errorf("copying data: %w", err)
|
return fmt.Errorf("copying data: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -417,19 +432,21 @@ AND auth_key_id NOT IN (
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, indexSQL := range indexes {
|
for _, indexSQL := range indexes {
|
||||||
if err := tx.Exec(indexSQL).Error; err != nil {
|
err := tx.Exec(indexSQL).Error
|
||||||
|
if err != nil {
|
||||||
return fmt.Errorf("creating index: %w", err)
|
return fmt.Errorf("creating index: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Drop old tables only after everything succeeds
|
// Drop old tables only after everything succeeds
|
||||||
for _, table := range tablesToRename {
|
for _, table := range tablesToRename {
|
||||||
if err := tx.Exec("DROP TABLE IF EXISTS " + table + "_old").Error; err != nil {
|
err := tx.Exec("DROP TABLE IF EXISTS " + table + "_old").Error
|
||||||
log.Warn().Str("table", table+"_old").Err(err).Msg("Failed to drop old table, but migration succeeded")
|
if err != nil {
|
||||||
|
log.Warn().Str("table", table+"_old").Err(err).Msg("failed to drop old table, but migration succeeded")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Info().Msg("Schema recreation completed successfully")
|
log.Info().Msg("schema recreation completed successfully")
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
@@ -595,12 +612,12 @@ AND auth_key_id NOT IN (
|
|||||||
// 1. Load policy from file or database based on configuration
|
// 1. Load policy from file or database based on configuration
|
||||||
policyData, err := PolicyBytes(tx, cfg)
|
policyData, err := PolicyBytes(tx, cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warn().Err(err).Msg("Failed to load policy, skipping RequestTags migration (tags will be validated on node reconnect)")
|
log.Warn().Err(err).Msg("failed to load policy, skipping RequestTags migration (tags will be validated on node reconnect)")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(policyData) == 0 {
|
if len(policyData) == 0 {
|
||||||
log.Info().Msg("No policy found, skipping RequestTags migration (tags will be validated on node reconnect)")
|
log.Info().Msg("no policy found, skipping RequestTags migration (tags will be validated on node reconnect)")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -618,7 +635,7 @@ AND auth_key_id NOT IN (
|
|||||||
// 3. Create PolicyManager (handles HuJSON parsing, groups, nested tags, etc.)
|
// 3. Create PolicyManager (handles HuJSON parsing, groups, nested tags, etc.)
|
||||||
polMan, err := policy.NewPolicyManager(policyData, users, nodes.ViewSlice())
|
polMan, err := policy.NewPolicyManager(policyData, users, nodes.ViewSlice())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warn().Err(err).Msg("Failed to parse policy, skipping RequestTags migration (tags will be validated on node reconnect)")
|
log.Warn().Err(err).Msg("failed to parse policy, skipping RequestTags migration (tags will be validated on node reconnect)")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -652,8 +669,7 @@ AND auth_key_id NOT IN (
|
|||||||
if len(validatedTags) == 0 {
|
if len(validatedTags) == 0 {
|
||||||
if len(rejectedTags) > 0 {
|
if len(rejectedTags) > 0 {
|
||||||
log.Debug().
|
log.Debug().
|
||||||
Uint64("node.id", uint64(node.ID)).
|
EmbedObject(node).
|
||||||
Str("node.name", node.Hostname).
|
|
||||||
Strs("rejected_tags", rejectedTags).
|
Strs("rejected_tags", rejectedTags).
|
||||||
Msg("RequestTags rejected during migration (not authorized)")
|
Msg("RequestTags rejected during migration (not authorized)")
|
||||||
}
|
}
|
||||||
@@ -676,8 +692,7 @@ AND auth_key_id NOT IN (
|
|||||||
}
|
}
|
||||||
|
|
||||||
log.Info().
|
log.Info().
|
||||||
Uint64("node.id", uint64(node.ID)).
|
EmbedObject(node).
|
||||||
Str("node.name", node.Hostname).
|
|
||||||
Strs("validated_tags", validatedTags).
|
Strs("validated_tags", validatedTags).
|
||||||
Strs("rejected_tags", rejectedTags).
|
Strs("rejected_tags", rejectedTags).
|
||||||
Strs("existing_tags", existingTags).
|
Strs("existing_tags", existingTags).
|
||||||
@@ -750,6 +765,20 @@ AND auth_key_id NOT IN (
|
|||||||
return nil, fmt.Errorf("migration failed: %w", err)
|
return nil, fmt.Errorf("migration failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Store the current version in the database after migrations succeed.
|
||||||
|
// Dev builds skip this to preserve the stored version for the next
|
||||||
|
// real versioned binary.
|
||||||
|
currentVersion := types.GetVersionInfo().Version
|
||||||
|
if !isDev(currentVersion) {
|
||||||
|
err = setDatabaseVersion(dbConn, currentVersion)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf(
|
||||||
|
"storing database version: %w",
|
||||||
|
err,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Validate that the schema ends up in the expected state.
|
// Validate that the schema ends up in the expected state.
|
||||||
// This is currently only done on sqlite as squibble does not
|
// This is currently only done on sqlite as squibble does not
|
||||||
// support Postgres and we use our sqlite schema as our source of
|
// support Postgres and we use our sqlite schema as our source of
|
||||||
@@ -762,6 +791,7 @@ AND auth_key_id NOT IN (
|
|||||||
|
|
||||||
// or else it blocks...
|
// or else it blocks...
|
||||||
sqlConn.SetMaxIdleConns(maxIdleConns)
|
sqlConn.SetMaxIdleConns(maxIdleConns)
|
||||||
|
|
||||||
sqlConn.SetMaxOpenConns(maxOpenConns)
|
sqlConn.SetMaxOpenConns(maxOpenConns)
|
||||||
defer sqlConn.SetMaxIdleConns(1)
|
defer sqlConn.SetMaxIdleConns(1)
|
||||||
defer sqlConn.SetMaxOpenConns(1)
|
defer sqlConn.SetMaxOpenConns(1)
|
||||||
@@ -779,7 +809,7 @@ AND auth_key_id NOT IN (
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := squibble.Validate(ctx, sqlConn, dbSchema, &opts); err != nil {
|
if err := squibble.Validate(ctx, sqlConn, dbSchema, &opts); err != nil { //nolint:noinlineerr
|
||||||
return nil, fmt.Errorf("validating schema: %w", err)
|
return nil, fmt.Errorf("validating schema: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -805,6 +835,7 @@ func openDB(cfg types.DatabaseConfig) (*gorm.DB, error) {
|
|||||||
switch cfg.Type {
|
switch cfg.Type {
|
||||||
case types.DatabaseSqlite:
|
case types.DatabaseSqlite:
|
||||||
dir := filepath.Dir(cfg.Sqlite.Path)
|
dir := filepath.Dir(cfg.Sqlite.Path)
|
||||||
|
|
||||||
err := util.EnsureDir(dir)
|
err := util.EnsureDir(dir)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("creating directory for sqlite: %w", err)
|
return nil, fmt.Errorf("creating directory for sqlite: %w", err)
|
||||||
@@ -858,7 +889,7 @@ func openDB(cfg types.DatabaseConfig) (*gorm.DB, error) {
|
|||||||
Str("path", dbString).
|
Str("path", dbString).
|
||||||
Msg("Opening database")
|
Msg("Opening database")
|
||||||
|
|
||||||
if sslEnabled, err := strconv.ParseBool(cfg.Postgres.Ssl); err == nil {
|
if sslEnabled, err := strconv.ParseBool(cfg.Postgres.Ssl); err == nil { //nolint:noinlineerr
|
||||||
if !sslEnabled {
|
if !sslEnabled {
|
||||||
dbString += " sslmode=disable"
|
dbString += " sslmode=disable"
|
||||||
}
|
}
|
||||||
@@ -913,7 +944,7 @@ func runMigrations(cfg types.DatabaseConfig, dbConn *gorm.DB, migrations *gormig
|
|||||||
|
|
||||||
// Get the current foreign key status
|
// Get the current foreign key status
|
||||||
var fkOriginallyEnabled int
|
var fkOriginallyEnabled int
|
||||||
if err := dbConn.Raw("PRAGMA foreign_keys").Scan(&fkOriginallyEnabled).Error; err != nil {
|
if err := dbConn.Raw("PRAGMA foreign_keys").Scan(&fkOriginallyEnabled).Error; err != nil { //nolint:noinlineerr
|
||||||
return fmt.Errorf("checking foreign key status: %w", err)
|
return fmt.Errorf("checking foreign key status: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -937,33 +968,36 @@ func runMigrations(cfg types.DatabaseConfig, dbConn *gorm.DB, migrations *gormig
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, migrationID := range migrationIDs {
|
for _, migrationID := range migrationIDs {
|
||||||
log.Trace().Caller().Str("migration_id", migrationID).Msg("Running migration")
|
log.Trace().Caller().Str("migration_id", migrationID).Msg("running migration")
|
||||||
needsFKDisabled := migrationsRequiringFKDisabled[migrationID]
|
needsFKDisabled := migrationsRequiringFKDisabled[migrationID]
|
||||||
|
|
||||||
if needsFKDisabled {
|
if needsFKDisabled {
|
||||||
// Disable foreign keys for this migration
|
// Disable foreign keys for this migration
|
||||||
if err := dbConn.Exec("PRAGMA foreign_keys = OFF").Error; err != nil {
|
err := dbConn.Exec("PRAGMA foreign_keys = OFF").Error
|
||||||
|
if err != nil {
|
||||||
return fmt.Errorf("disabling foreign keys for migration %s: %w", migrationID, err)
|
return fmt.Errorf("disabling foreign keys for migration %s: %w", migrationID, err)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Ensure foreign keys are enabled for this migration
|
// Ensure foreign keys are enabled for this migration
|
||||||
if err := dbConn.Exec("PRAGMA foreign_keys = ON").Error; err != nil {
|
err := dbConn.Exec("PRAGMA foreign_keys = ON").Error
|
||||||
|
if err != nil {
|
||||||
return fmt.Errorf("enabling foreign keys for migration %s: %w", migrationID, err)
|
return fmt.Errorf("enabling foreign keys for migration %s: %w", migrationID, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Run up to this specific migration (will only run the next pending migration)
|
// Run up to this specific migration (will only run the next pending migration)
|
||||||
if err := migrations.MigrateTo(migrationID); err != nil {
|
err := migrations.MigrateTo(migrationID)
|
||||||
|
if err != nil {
|
||||||
return fmt.Errorf("running migration %s: %w", migrationID, err)
|
return fmt.Errorf("running migration %s: %w", migrationID, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := dbConn.Exec("PRAGMA foreign_keys = ON").Error; err != nil {
|
if err := dbConn.Exec("PRAGMA foreign_keys = ON").Error; err != nil { //nolint:noinlineerr
|
||||||
return fmt.Errorf("restoring foreign keys: %w", err)
|
return fmt.Errorf("restoring foreign keys: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Run the rest of the migrations
|
// Run the rest of the migrations
|
||||||
if err := migrations.Migrate(); err != nil {
|
if err := migrations.Migrate(); err != nil { //nolint:noinlineerr
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -981,16 +1015,22 @@ func runMigrations(cfg types.DatabaseConfig, dbConn *gorm.DB, migrations *gormig
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var violation constraintViolation
|
var violation constraintViolation
|
||||||
if err := rows.Scan(&violation.Table, &violation.RowID, &violation.Parent, &violation.ConstraintIndex); err != nil {
|
|
||||||
|
err := rows.Scan(&violation.Table, &violation.RowID, &violation.Parent, &violation.ConstraintIndex)
|
||||||
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
violatedConstraints = append(violatedConstraints, violation)
|
violatedConstraints = append(violatedConstraints, violation)
|
||||||
}
|
}
|
||||||
_ = rows.Close()
|
|
||||||
|
if err := rows.Err(); err != nil { //nolint:noinlineerr
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
if len(violatedConstraints) > 0 {
|
if len(violatedConstraints) > 0 {
|
||||||
for _, violation := range violatedConstraints {
|
for _, violation := range violatedConstraints {
|
||||||
@@ -1005,7 +1045,8 @@ func runMigrations(cfg types.DatabaseConfig, dbConn *gorm.DB, migrations *gormig
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// PostgreSQL can run all migrations in one block - no foreign key issues
|
// PostgreSQL can run all migrations in one block - no foreign key issues
|
||||||
if err := migrations.Migrate(); err != nil {
|
err := migrations.Migrate()
|
||||||
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1016,6 +1057,7 @@ func runMigrations(cfg types.DatabaseConfig, dbConn *gorm.DB, migrations *gormig
|
|||||||
func (hsdb *HSDatabase) PingDB(ctx context.Context) error {
|
func (hsdb *HSDatabase) PingDB(ctx context.Context) error {
|
||||||
ctx, cancel := context.WithTimeout(ctx, time.Second)
|
ctx, cancel := context.WithTimeout(ctx, time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
sqlDB, err := hsdb.DB.DB()
|
sqlDB, err := hsdb.DB.DB()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -1031,7 +1073,7 @@ func (hsdb *HSDatabase) Close() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if hsdb.cfg.Database.Type == types.DatabaseSqlite && hsdb.cfg.Database.Sqlite.WriteAheadLog {
|
if hsdb.cfg.Database.Type == types.DatabaseSqlite && hsdb.cfg.Database.Sqlite.WriteAheadLog {
|
||||||
db.Exec("VACUUM")
|
db.Exec("VACUUM") //nolint:errcheck,noctx
|
||||||
}
|
}
|
||||||
|
|
||||||
return db.Close()
|
return db.Close()
|
||||||
@@ -1040,12 +1082,14 @@ func (hsdb *HSDatabase) Close() error {
|
|||||||
func (hsdb *HSDatabase) Read(fn func(rx *gorm.DB) error) error {
|
func (hsdb *HSDatabase) Read(fn func(rx *gorm.DB) error) error {
|
||||||
rx := hsdb.DB.Begin()
|
rx := hsdb.DB.Begin()
|
||||||
defer rx.Rollback()
|
defer rx.Rollback()
|
||||||
|
|
||||||
return fn(rx)
|
return fn(rx)
|
||||||
}
|
}
|
||||||
|
|
||||||
func Read[T any](db *gorm.DB, fn func(rx *gorm.DB) (T, error)) (T, error) {
|
func Read[T any](db *gorm.DB, fn func(rx *gorm.DB) (T, error)) (T, error) {
|
||||||
rx := db.Begin()
|
rx := db.Begin()
|
||||||
defer rx.Rollback()
|
defer rx.Rollback()
|
||||||
|
|
||||||
ret, err := fn(rx)
|
ret, err := fn(rx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
var no T
|
var no T
|
||||||
@@ -1058,7 +1102,9 @@ func Read[T any](db *gorm.DB, fn func(rx *gorm.DB) (T, error)) (T, error) {
|
|||||||
func (hsdb *HSDatabase) Write(fn func(tx *gorm.DB) error) error {
|
func (hsdb *HSDatabase) Write(fn func(tx *gorm.DB) error) error {
|
||||||
tx := hsdb.DB.Begin()
|
tx := hsdb.DB.Begin()
|
||||||
defer tx.Rollback()
|
defer tx.Rollback()
|
||||||
if err := fn(tx); err != nil {
|
|
||||||
|
err := fn(tx)
|
||||||
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1068,6 +1114,7 @@ func (hsdb *HSDatabase) Write(fn func(tx *gorm.DB) error) error {
|
|||||||
func Write[T any](db *gorm.DB, fn func(tx *gorm.DB) (T, error)) (T, error) {
|
func Write[T any](db *gorm.DB, fn func(tx *gorm.DB) (T, error)) (T, error) {
|
||||||
tx := db.Begin()
|
tx := db.Begin()
|
||||||
defer tx.Rollback()
|
defer tx.Rollback()
|
||||||
|
|
||||||
ret, err := fn(tx)
|
ret, err := fn(tx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
var no T
|
var no T
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package db
|
package db
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
@@ -44,6 +45,7 @@ func TestSQLiteMigrationAndDataValidation(t *testing.T) {
|
|||||||
|
|
||||||
// Verify api_keys data preservation
|
// Verify api_keys data preservation
|
||||||
var apiKeyCount int
|
var apiKeyCount int
|
||||||
|
|
||||||
err = hsdb.DB.Raw("SELECT COUNT(*) FROM api_keys").Scan(&apiKeyCount).Error
|
err = hsdb.DB.Raw("SELECT COUNT(*) FROM api_keys").Scan(&apiKeyCount).Error
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, 2, apiKeyCount, "should preserve all 2 api_keys from original schema")
|
assert.Equal(t, 2, apiKeyCount, "should preserve all 2 api_keys from original schema")
|
||||||
@@ -176,7 +178,7 @@ func createSQLiteFromSQLFile(sqlFilePath, dbPath string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = db.Exec(string(schemaContent))
|
_, err = db.ExecContext(context.Background(), string(schemaContent))
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -186,6 +188,7 @@ func createSQLiteFromSQLFile(sqlFilePath, dbPath string) error {
|
|||||||
func requireConstraintFailed(t *testing.T, err error) {
|
func requireConstraintFailed(t *testing.T, err error) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
|
|
||||||
if !strings.Contains(err.Error(), "UNIQUE constraint failed:") && !strings.Contains(err.Error(), "violates unique constraint") {
|
if !strings.Contains(err.Error(), "UNIQUE constraint failed:") && !strings.Contains(err.Error(), "violates unique constraint") {
|
||||||
require.Failf(t, "expected error to contain a constraint failure, got: %s", err.Error())
|
require.Failf(t, "expected error to contain a constraint failure, got: %s", err.Error())
|
||||||
}
|
}
|
||||||
@@ -198,7 +201,7 @@ func TestConstraints(t *testing.T) {
|
|||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "no-duplicate-username-if-no-oidc",
|
name: "no-duplicate-username-if-no-oidc",
|
||||||
run: func(t *testing.T, db *gorm.DB) {
|
run: func(t *testing.T, db *gorm.DB) { //nolint:thelper
|
||||||
_, err := CreateUser(db, types.User{Name: "user1"})
|
_, err := CreateUser(db, types.User{Name: "user1"})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
_, err = CreateUser(db, types.User{Name: "user1"})
|
_, err = CreateUser(db, types.User{Name: "user1"})
|
||||||
@@ -207,7 +210,7 @@ func TestConstraints(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "no-oidc-duplicate-username-and-id",
|
name: "no-oidc-duplicate-username-and-id",
|
||||||
run: func(t *testing.T, db *gorm.DB) {
|
run: func(t *testing.T, db *gorm.DB) { //nolint:thelper
|
||||||
user := types.User{
|
user := types.User{
|
||||||
Model: gorm.Model{ID: 1},
|
Model: gorm.Model{ID: 1},
|
||||||
Name: "user1",
|
Name: "user1",
|
||||||
@@ -229,7 +232,7 @@ func TestConstraints(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "no-oidc-duplicate-id",
|
name: "no-oidc-duplicate-id",
|
||||||
run: func(t *testing.T, db *gorm.DB) {
|
run: func(t *testing.T, db *gorm.DB) { //nolint:thelper
|
||||||
user := types.User{
|
user := types.User{
|
||||||
Model: gorm.Model{ID: 1},
|
Model: gorm.Model{ID: 1},
|
||||||
Name: "user1",
|
Name: "user1",
|
||||||
@@ -251,7 +254,7 @@ func TestConstraints(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "allow-duplicate-username-cli-then-oidc",
|
name: "allow-duplicate-username-cli-then-oidc",
|
||||||
run: func(t *testing.T, db *gorm.DB) {
|
run: func(t *testing.T, db *gorm.DB) { //nolint:thelper
|
||||||
_, err := CreateUser(db, types.User{Name: "user1"}) // Create CLI username
|
_, err := CreateUser(db, types.User{Name: "user1"}) // Create CLI username
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
@@ -266,7 +269,7 @@ func TestConstraints(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "allow-duplicate-username-oidc-then-cli",
|
name: "allow-duplicate-username-oidc-then-cli",
|
||||||
run: func(t *testing.T, db *gorm.DB) {
|
run: func(t *testing.T, db *gorm.DB) { //nolint:thelper
|
||||||
user := types.User{
|
user := types.User{
|
||||||
Name: "user1",
|
Name: "user1",
|
||||||
ProviderIdentifier: sql.NullString{String: "http://test.com/user1", Valid: true},
|
ProviderIdentifier: sql.NullString{String: "http://test.com/user1", Valid: true},
|
||||||
@@ -320,7 +323,7 @@ func TestPostgresMigrationAndDataValidation(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Construct the pg_restore command
|
// Construct the pg_restore command
|
||||||
cmd := exec.Command(pgRestorePath, "--verbose", "--if-exists", "--clean", "--no-owner", "--dbname", u.String(), tt.dbPath)
|
cmd := exec.CommandContext(context.Background(), pgRestorePath, "--verbose", "--if-exists", "--clean", "--no-owner", "--dbname", u.String(), tt.dbPath)
|
||||||
|
|
||||||
// Set the output streams
|
// Set the output streams
|
||||||
cmd.Stdout = os.Stdout
|
cmd.Stdout = os.Stdout
|
||||||
@@ -401,6 +404,7 @@ func dbForTestWithPath(t *testing.T, sqlFilePath string) *HSDatabase {
|
|||||||
// skip already-applied migrations and only run new ones.
|
// skip already-applied migrations and only run new ones.
|
||||||
func TestSQLiteAllTestdataMigrations(t *testing.T) {
|
func TestSQLiteAllTestdataMigrations(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
schemas, err := os.ReadDir("testdata/sqlite")
|
schemas, err := os.ReadDir("testdata/sqlite")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
|||||||
@@ -27,13 +27,17 @@ func TestEphemeralGarbageCollectorGoRoutineLeak(t *testing.T) {
|
|||||||
t.Logf("Initial number of goroutines: %d", initialGoroutines)
|
t.Logf("Initial number of goroutines: %d", initialGoroutines)
|
||||||
|
|
||||||
// Basic deletion tracking mechanism
|
// Basic deletion tracking mechanism
|
||||||
var deletedIDs []types.NodeID
|
var (
|
||||||
var deleteMutex sync.Mutex
|
deletedIDs []types.NodeID
|
||||||
var deletionWg sync.WaitGroup
|
deleteMutex sync.Mutex
|
||||||
|
deletionWg sync.WaitGroup
|
||||||
|
)
|
||||||
|
|
||||||
deleteFunc := func(nodeID types.NodeID) {
|
deleteFunc := func(nodeID types.NodeID) {
|
||||||
deleteMutex.Lock()
|
deleteMutex.Lock()
|
||||||
|
|
||||||
deletedIDs = append(deletedIDs, nodeID)
|
deletedIDs = append(deletedIDs, nodeID)
|
||||||
|
|
||||||
deleteMutex.Unlock()
|
deleteMutex.Unlock()
|
||||||
deletionWg.Done()
|
deletionWg.Done()
|
||||||
}
|
}
|
||||||
@@ -43,14 +47,17 @@ func TestEphemeralGarbageCollectorGoRoutineLeak(t *testing.T) {
|
|||||||
go gc.Start()
|
go gc.Start()
|
||||||
|
|
||||||
// Schedule several nodes for deletion with short expiry
|
// Schedule several nodes for deletion with short expiry
|
||||||
const expiry = fifty
|
const (
|
||||||
const numNodes = 100
|
expiry = fifty
|
||||||
|
numNodes = 100
|
||||||
|
)
|
||||||
|
|
||||||
// Set up wait group for expected deletions
|
// Set up wait group for expected deletions
|
||||||
|
|
||||||
deletionWg.Add(numNodes)
|
deletionWg.Add(numNodes)
|
||||||
|
|
||||||
for i := 1; i <= numNodes; i++ {
|
for i := 1; i <= numNodes; i++ {
|
||||||
gc.Schedule(types.NodeID(i), expiry)
|
gc.Schedule(types.NodeID(i), expiry) //nolint:gosec // safe conversion in test
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wait for all scheduled deletions to complete
|
// Wait for all scheduled deletions to complete
|
||||||
@@ -63,7 +70,7 @@ func TestEphemeralGarbageCollectorGoRoutineLeak(t *testing.T) {
|
|||||||
|
|
||||||
// Schedule and immediately cancel to test that part of the code
|
// Schedule and immediately cancel to test that part of the code
|
||||||
for i := numNodes + 1; i <= numNodes*2; i++ {
|
for i := numNodes + 1; i <= numNodes*2; i++ {
|
||||||
nodeID := types.NodeID(i)
|
nodeID := types.NodeID(i) //nolint:gosec // safe conversion in test
|
||||||
gc.Schedule(nodeID, time.Hour)
|
gc.Schedule(nodeID, time.Hour)
|
||||||
gc.Cancel(nodeID)
|
gc.Cancel(nodeID)
|
||||||
}
|
}
|
||||||
@@ -87,14 +94,18 @@ func TestEphemeralGarbageCollectorGoRoutineLeak(t *testing.T) {
|
|||||||
// and then reschedules it with a shorter expiry, and verifies that the node is deleted only once.
|
// and then reschedules it with a shorter expiry, and verifies that the node is deleted only once.
|
||||||
func TestEphemeralGarbageCollectorReschedule(t *testing.T) {
|
func TestEphemeralGarbageCollectorReschedule(t *testing.T) {
|
||||||
// Deletion tracking mechanism
|
// Deletion tracking mechanism
|
||||||
var deletedIDs []types.NodeID
|
var (
|
||||||
var deleteMutex sync.Mutex
|
deletedIDs []types.NodeID
|
||||||
|
deleteMutex sync.Mutex
|
||||||
|
)
|
||||||
|
|
||||||
deletionNotifier := make(chan types.NodeID, 1)
|
deletionNotifier := make(chan types.NodeID, 1)
|
||||||
|
|
||||||
deleteFunc := func(nodeID types.NodeID) {
|
deleteFunc := func(nodeID types.NodeID) {
|
||||||
deleteMutex.Lock()
|
deleteMutex.Lock()
|
||||||
|
|
||||||
deletedIDs = append(deletedIDs, nodeID)
|
deletedIDs = append(deletedIDs, nodeID)
|
||||||
|
|
||||||
deleteMutex.Unlock()
|
deleteMutex.Unlock()
|
||||||
|
|
||||||
deletionNotifier <- nodeID
|
deletionNotifier <- nodeID
|
||||||
@@ -102,11 +113,14 @@ func TestEphemeralGarbageCollectorReschedule(t *testing.T) {
|
|||||||
|
|
||||||
// Start GC
|
// Start GC
|
||||||
gc := NewEphemeralGarbageCollector(deleteFunc)
|
gc := NewEphemeralGarbageCollector(deleteFunc)
|
||||||
|
|
||||||
go gc.Start()
|
go gc.Start()
|
||||||
defer gc.Close()
|
defer gc.Close()
|
||||||
|
|
||||||
const shortExpiry = fifty
|
const (
|
||||||
const longExpiry = 1 * time.Hour
|
shortExpiry = fifty
|
||||||
|
longExpiry = 1 * time.Hour
|
||||||
|
)
|
||||||
|
|
||||||
nodeID := types.NodeID(1)
|
nodeID := types.NodeID(1)
|
||||||
|
|
||||||
@@ -136,23 +150,31 @@ func TestEphemeralGarbageCollectorReschedule(t *testing.T) {
|
|||||||
// and verifies that the node is deleted only once.
|
// and verifies that the node is deleted only once.
|
||||||
func TestEphemeralGarbageCollectorCancelAndReschedule(t *testing.T) {
|
func TestEphemeralGarbageCollectorCancelAndReschedule(t *testing.T) {
|
||||||
// Deletion tracking mechanism
|
// Deletion tracking mechanism
|
||||||
var deletedIDs []types.NodeID
|
var (
|
||||||
var deleteMutex sync.Mutex
|
deletedIDs []types.NodeID
|
||||||
|
deleteMutex sync.Mutex
|
||||||
|
)
|
||||||
|
|
||||||
deletionNotifier := make(chan types.NodeID, 1)
|
deletionNotifier := make(chan types.NodeID, 1)
|
||||||
|
|
||||||
deleteFunc := func(nodeID types.NodeID) {
|
deleteFunc := func(nodeID types.NodeID) {
|
||||||
deleteMutex.Lock()
|
deleteMutex.Lock()
|
||||||
|
|
||||||
deletedIDs = append(deletedIDs, nodeID)
|
deletedIDs = append(deletedIDs, nodeID)
|
||||||
|
|
||||||
deleteMutex.Unlock()
|
deleteMutex.Unlock()
|
||||||
|
|
||||||
deletionNotifier <- nodeID
|
deletionNotifier <- nodeID
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start the GC
|
// Start the GC
|
||||||
gc := NewEphemeralGarbageCollector(deleteFunc)
|
gc := NewEphemeralGarbageCollector(deleteFunc)
|
||||||
|
|
||||||
go gc.Start()
|
go gc.Start()
|
||||||
defer gc.Close()
|
defer gc.Close()
|
||||||
|
|
||||||
nodeID := types.NodeID(1)
|
nodeID := types.NodeID(1)
|
||||||
|
|
||||||
const expiry = fifty
|
const expiry = fifty
|
||||||
|
|
||||||
// Schedule node for deletion
|
// Schedule node for deletion
|
||||||
@@ -196,14 +218,18 @@ func TestEphemeralGarbageCollectorCancelAndReschedule(t *testing.T) {
|
|||||||
// It creates a new EphemeralGarbageCollector, schedules a node for deletion, closes the GC, and verifies that the node is not deleted.
|
// It creates a new EphemeralGarbageCollector, schedules a node for deletion, closes the GC, and verifies that the node is not deleted.
|
||||||
func TestEphemeralGarbageCollectorCloseBeforeTimerFires(t *testing.T) {
|
func TestEphemeralGarbageCollectorCloseBeforeTimerFires(t *testing.T) {
|
||||||
// Deletion tracking
|
// Deletion tracking
|
||||||
var deletedIDs []types.NodeID
|
var (
|
||||||
var deleteMutex sync.Mutex
|
deletedIDs []types.NodeID
|
||||||
|
deleteMutex sync.Mutex
|
||||||
|
)
|
||||||
|
|
||||||
deletionNotifier := make(chan types.NodeID, 1)
|
deletionNotifier := make(chan types.NodeID, 1)
|
||||||
|
|
||||||
deleteFunc := func(nodeID types.NodeID) {
|
deleteFunc := func(nodeID types.NodeID) {
|
||||||
deleteMutex.Lock()
|
deleteMutex.Lock()
|
||||||
|
|
||||||
deletedIDs = append(deletedIDs, nodeID)
|
deletedIDs = append(deletedIDs, nodeID)
|
||||||
|
|
||||||
deleteMutex.Unlock()
|
deleteMutex.Unlock()
|
||||||
|
|
||||||
deletionNotifier <- nodeID
|
deletionNotifier <- nodeID
|
||||||
@@ -246,13 +272,18 @@ func TestEphemeralGarbageCollectorScheduleAfterClose(t *testing.T) {
|
|||||||
t.Logf("Initial number of goroutines: %d", initialGoroutines)
|
t.Logf("Initial number of goroutines: %d", initialGoroutines)
|
||||||
|
|
||||||
// Deletion tracking
|
// Deletion tracking
|
||||||
var deletedIDs []types.NodeID
|
var (
|
||||||
var deleteMutex sync.Mutex
|
deletedIDs []types.NodeID
|
||||||
|
deleteMutex sync.Mutex
|
||||||
|
)
|
||||||
|
|
||||||
nodeDeleted := make(chan struct{})
|
nodeDeleted := make(chan struct{})
|
||||||
|
|
||||||
deleteFunc := func(nodeID types.NodeID) {
|
deleteFunc := func(nodeID types.NodeID) {
|
||||||
deleteMutex.Lock()
|
deleteMutex.Lock()
|
||||||
|
|
||||||
deletedIDs = append(deletedIDs, nodeID)
|
deletedIDs = append(deletedIDs, nodeID)
|
||||||
|
|
||||||
deleteMutex.Unlock()
|
deleteMutex.Unlock()
|
||||||
close(nodeDeleted) // Signal that deletion happened
|
close(nodeDeleted) // Signal that deletion happened
|
||||||
}
|
}
|
||||||
@@ -263,10 +294,12 @@ func TestEphemeralGarbageCollectorScheduleAfterClose(t *testing.T) {
|
|||||||
// Use a WaitGroup to ensure the GC has started
|
// Use a WaitGroup to ensure the GC has started
|
||||||
var startWg sync.WaitGroup
|
var startWg sync.WaitGroup
|
||||||
startWg.Add(1)
|
startWg.Add(1)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
startWg.Done() // Signal that the goroutine has started
|
startWg.Done() // Signal that the goroutine has started
|
||||||
gc.Start()
|
gc.Start()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
startWg.Wait() // Wait for the GC to start
|
startWg.Wait() // Wait for the GC to start
|
||||||
|
|
||||||
// Close GC right away
|
// Close GC right away
|
||||||
@@ -288,7 +321,9 @@ func TestEphemeralGarbageCollectorScheduleAfterClose(t *testing.T) {
|
|||||||
|
|
||||||
// Check no node was deleted
|
// Check no node was deleted
|
||||||
deleteMutex.Lock()
|
deleteMutex.Lock()
|
||||||
|
|
||||||
nodesDeleted := len(deletedIDs)
|
nodesDeleted := len(deletedIDs)
|
||||||
|
|
||||||
deleteMutex.Unlock()
|
deleteMutex.Unlock()
|
||||||
assert.Equal(t, 0, nodesDeleted, "No nodes should be deleted when Schedule is called after Close")
|
assert.Equal(t, 0, nodesDeleted, "No nodes should be deleted when Schedule is called after Close")
|
||||||
|
|
||||||
@@ -311,12 +346,16 @@ func TestEphemeralGarbageCollectorConcurrentScheduleAndClose(t *testing.T) {
|
|||||||
t.Logf("Initial number of goroutines: %d", initialGoroutines)
|
t.Logf("Initial number of goroutines: %d", initialGoroutines)
|
||||||
|
|
||||||
// Deletion tracking mechanism
|
// Deletion tracking mechanism
|
||||||
var deletedIDs []types.NodeID
|
var (
|
||||||
var deleteMutex sync.Mutex
|
deletedIDs []types.NodeID
|
||||||
|
deleteMutex sync.Mutex
|
||||||
|
)
|
||||||
|
|
||||||
deleteFunc := func(nodeID types.NodeID) {
|
deleteFunc := func(nodeID types.NodeID) {
|
||||||
deleteMutex.Lock()
|
deleteMutex.Lock()
|
||||||
|
|
||||||
deletedIDs = append(deletedIDs, nodeID)
|
deletedIDs = append(deletedIDs, nodeID)
|
||||||
|
|
||||||
deleteMutex.Unlock()
|
deleteMutex.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -325,8 +364,10 @@ func TestEphemeralGarbageCollectorConcurrentScheduleAndClose(t *testing.T) {
|
|||||||
go gc.Start()
|
go gc.Start()
|
||||||
|
|
||||||
// Number of concurrent scheduling goroutines
|
// Number of concurrent scheduling goroutines
|
||||||
const numSchedulers = 10
|
const (
|
||||||
const nodesPerScheduler = 50
|
numSchedulers = 10
|
||||||
|
nodesPerScheduler = 50
|
||||||
|
)
|
||||||
|
|
||||||
const closeAfterNodes = 25 // Close GC after this many nodes per scheduler
|
const closeAfterNodes = 25 // Close GC after this many nodes per scheduler
|
||||||
|
|
||||||
@@ -353,8 +394,8 @@ func TestEphemeralGarbageCollectorConcurrentScheduleAndClose(t *testing.T) {
|
|||||||
case <-stopScheduling:
|
case <-stopScheduling:
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
nodeID := types.NodeID(baseNodeID + j + 1)
|
nodeID := types.NodeID(baseNodeID + j + 1) //nolint:gosec // safe conversion in test
|
||||||
gc.Schedule(nodeID, 1*time.Hour) // Long expiry to ensure it doesn't trigger during test
|
gc.Schedule(nodeID, 1*time.Hour) // Long expiry to ensure it doesn't trigger during test
|
||||||
atomic.AddInt64(&scheduledCount, 1)
|
atomic.AddInt64(&scheduledCount, 1)
|
||||||
|
|
||||||
// Yield to other goroutines to introduce variability
|
// Yield to other goroutines to introduce variability
|
||||||
|
|||||||
@@ -17,7 +17,11 @@ import (
|
|||||||
"tailscale.com/net/tsaddr"
|
"tailscale.com/net/tsaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
var errGeneratedIPBytesInvalid = errors.New("generated ip bytes are invalid ip")
|
var (
|
||||||
|
errGeneratedIPBytesInvalid = errors.New("generated ip bytes are invalid ip")
|
||||||
|
errGeneratedIPNotInPrefix = errors.New("generated ip not in prefix")
|
||||||
|
errIPAllocatorNil = errors.New("ip allocator was nil")
|
||||||
|
)
|
||||||
|
|
||||||
// IPAllocator is a singleton responsible for allocating
|
// IPAllocator is a singleton responsible for allocating
|
||||||
// IP addresses for nodes and making sure the same
|
// IP addresses for nodes and making sure the same
|
||||||
@@ -62,8 +66,10 @@ func NewIPAllocator(
|
|||||||
strategy: strategy,
|
strategy: strategy,
|
||||||
}
|
}
|
||||||
|
|
||||||
var v4s []sql.NullString
|
var (
|
||||||
var v6s []sql.NullString
|
v4s []sql.NullString
|
||||||
|
v6s []sql.NullString
|
||||||
|
)
|
||||||
|
|
||||||
if db != nil {
|
if db != nil {
|
||||||
err := db.Read(func(rx *gorm.DB) error {
|
err := db.Read(func(rx *gorm.DB) error {
|
||||||
@@ -135,15 +141,18 @@ func (i *IPAllocator) Next() (*netip.Addr, *netip.Addr, error) {
|
|||||||
i.mu.Lock()
|
i.mu.Lock()
|
||||||
defer i.mu.Unlock()
|
defer i.mu.Unlock()
|
||||||
|
|
||||||
var err error
|
var (
|
||||||
var ret4 *netip.Addr
|
err error
|
||||||
var ret6 *netip.Addr
|
ret4 *netip.Addr
|
||||||
|
ret6 *netip.Addr
|
||||||
|
)
|
||||||
|
|
||||||
if i.prefix4 != nil {
|
if i.prefix4 != nil {
|
||||||
ret4, err = i.next(i.prev4, i.prefix4)
|
ret4, err = i.next(i.prev4, i.prefix4)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("allocating IPv4 address: %w", err)
|
return nil, nil, fmt.Errorf("allocating IPv4 address: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
i.prev4 = *ret4
|
i.prev4 = *ret4
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -152,6 +161,7 @@ func (i *IPAllocator) Next() (*netip.Addr, *netip.Addr, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("allocating IPv6 address: %w", err)
|
return nil, nil, fmt.Errorf("allocating IPv6 address: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
i.prev6 = *ret6
|
i.prev6 = *ret6
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -168,8 +178,10 @@ func (i *IPAllocator) nextLocked(prev netip.Addr, prefix *netip.Prefix) (*netip.
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (i *IPAllocator) next(prev netip.Addr, prefix *netip.Prefix) (*netip.Addr, error) {
|
func (i *IPAllocator) next(prev netip.Addr, prefix *netip.Prefix) (*netip.Addr, error) {
|
||||||
var err error
|
var (
|
||||||
var ip netip.Addr
|
err error
|
||||||
|
ip netip.Addr
|
||||||
|
)
|
||||||
|
|
||||||
switch i.strategy {
|
switch i.strategy {
|
||||||
case types.IPAllocationStrategySequential:
|
case types.IPAllocationStrategySequential:
|
||||||
@@ -243,7 +255,8 @@ func randomNext(pfx netip.Prefix) (netip.Addr, error) {
|
|||||||
|
|
||||||
if !pfx.Contains(ip) {
|
if !pfx.Contains(ip) {
|
||||||
return netip.Addr{}, fmt.Errorf(
|
return netip.Addr{}, fmt.Errorf(
|
||||||
"generated ip(%s) not in prefix(%s)",
|
"%w: ip(%s) not in prefix(%s)",
|
||||||
|
errGeneratedIPNotInPrefix,
|
||||||
ip.String(),
|
ip.String(),
|
||||||
pfx.String(),
|
pfx.String(),
|
||||||
)
|
)
|
||||||
@@ -268,11 +281,14 @@ func isTailscaleReservedIP(ip netip.Addr) bool {
|
|||||||
// If a prefix type has been removed (IPv4 or IPv6), it
|
// If a prefix type has been removed (IPv4 or IPv6), it
|
||||||
// will remove the IPs in that family from the node.
|
// will remove the IPs in that family from the node.
|
||||||
func (db *HSDatabase) BackfillNodeIPs(i *IPAllocator) ([]string, error) {
|
func (db *HSDatabase) BackfillNodeIPs(i *IPAllocator) ([]string, error) {
|
||||||
var err error
|
var (
|
||||||
var ret []string
|
err error
|
||||||
|
ret []string
|
||||||
|
)
|
||||||
|
|
||||||
err = db.Write(func(tx *gorm.DB) error {
|
err = db.Write(func(tx *gorm.DB) error {
|
||||||
if i == nil {
|
if i == nil {
|
||||||
return errors.New("backfilling IPs: ip allocator was nil")
|
return fmt.Errorf("backfilling IPs: %w", errIPAllocatorNil)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Trace().Caller().Msgf("starting to backfill IPs")
|
log.Trace().Caller().Msgf("starting to backfill IPs")
|
||||||
@@ -283,18 +299,19 @@ func (db *HSDatabase) BackfillNodeIPs(i *IPAllocator) ([]string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, node := range nodes {
|
for _, node := range nodes {
|
||||||
log.Trace().Caller().Uint64("node.id", node.ID.Uint64()).Str("node.name", node.Hostname).Msg("IP backfill check started because node found in database")
|
log.Trace().Caller().EmbedObject(node).Msg("ip backfill check started because node found in database")
|
||||||
|
|
||||||
changed := false
|
changed := false
|
||||||
// IPv4 prefix is set, but node ip is missing, alloc
|
// IPv4 prefix is set, but node ip is missing, alloc
|
||||||
if i.prefix4 != nil && node.IPv4 == nil {
|
if i.prefix4 != nil && node.IPv4 == nil {
|
||||||
ret4, err := i.nextLocked(i.prev4, i.prefix4)
|
ret4, err := i.nextLocked(i.prev4, i.prefix4)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to allocate ipv4 for node(%d): %w", node.ID, err)
|
return fmt.Errorf("allocating IPv4 for node(%d): %w", node.ID, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
node.IPv4 = ret4
|
node.IPv4 = ret4
|
||||||
changed = true
|
changed = true
|
||||||
|
|
||||||
ret = append(ret, fmt.Sprintf("assigned IPv4 %q to Node(%d) %q", ret4.String(), node.ID, node.Hostname))
|
ret = append(ret, fmt.Sprintf("assigned IPv4 %q to Node(%d) %q", ret4.String(), node.ID, node.Hostname))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -302,11 +319,12 @@ func (db *HSDatabase) BackfillNodeIPs(i *IPAllocator) ([]string, error) {
|
|||||||
if i.prefix6 != nil && node.IPv6 == nil {
|
if i.prefix6 != nil && node.IPv6 == nil {
|
||||||
ret6, err := i.nextLocked(i.prev6, i.prefix6)
|
ret6, err := i.nextLocked(i.prev6, i.prefix6)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to allocate ipv6 for node(%d): %w", node.ID, err)
|
return fmt.Errorf("allocating IPv6 for node(%d): %w", node.ID, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
node.IPv6 = ret6
|
node.IPv6 = ret6
|
||||||
changed = true
|
changed = true
|
||||||
|
|
||||||
ret = append(ret, fmt.Sprintf("assigned IPv6 %q to Node(%d) %q", ret6.String(), node.ID, node.Hostname))
|
ret = append(ret, fmt.Sprintf("assigned IPv6 %q to Node(%d) %q", ret6.String(), node.ID, node.Hostname))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"tailscale.com/net/tsaddr"
|
"tailscale.com/net/tsaddr"
|
||||||
"tailscale.com/types/ptr"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var mpp = func(pref string) *netip.Prefix {
|
var mpp = func(pref string) *netip.Prefix {
|
||||||
@@ -21,9 +20,7 @@ var mpp = func(pref string) *netip.Prefix {
|
|||||||
return &p
|
return &p
|
||||||
}
|
}
|
||||||
|
|
||||||
var na = func(pref string) netip.Addr {
|
var na = netip.MustParseAddr
|
||||||
return netip.MustParseAddr(pref)
|
|
||||||
}
|
|
||||||
|
|
||||||
var nap = func(pref string) *netip.Addr {
|
var nap = func(pref string) *netip.Addr {
|
||||||
n := na(pref)
|
n := na(pref)
|
||||||
@@ -158,8 +155,10 @@ func TestIPAllocatorSequential(t *testing.T) {
|
|||||||
types.IPAllocationStrategySequential,
|
types.IPAllocationStrategySequential,
|
||||||
)
|
)
|
||||||
|
|
||||||
var got4s []netip.Addr
|
var (
|
||||||
var got6s []netip.Addr
|
got4s []netip.Addr
|
||||||
|
got6s []netip.Addr
|
||||||
|
)
|
||||||
|
|
||||||
for range tt.getCount {
|
for range tt.getCount {
|
||||||
got4, got6, err := alloc.Next()
|
got4, got6, err := alloc.Next()
|
||||||
@@ -175,6 +174,7 @@ func TestIPAllocatorSequential(t *testing.T) {
|
|||||||
got6s = append(got6s, *got6)
|
got6s = append(got6s, *got6)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if diff := cmp.Diff(tt.want4, got4s, util.Comparers...); diff != "" {
|
if diff := cmp.Diff(tt.want4, got4s, util.Comparers...); diff != "" {
|
||||||
t.Errorf("IPAllocator 4s unexpected result (-want +got):\n%s", diff)
|
t.Errorf("IPAllocator 4s unexpected result (-want +got):\n%s", diff)
|
||||||
}
|
}
|
||||||
@@ -288,6 +288,7 @@ func TestBackfillIPAddresses(t *testing.T) {
|
|||||||
fullNodeP := func(i int) *types.Node {
|
fullNodeP := func(i int) *types.Node {
|
||||||
v4 := fmt.Sprintf("100.64.0.%d", i)
|
v4 := fmt.Sprintf("100.64.0.%d", i)
|
||||||
v6 := fmt.Sprintf("fd7a:115c:a1e0::%d", i)
|
v6 := fmt.Sprintf("fd7a:115c:a1e0::%d", i)
|
||||||
|
|
||||||
return &types.Node{
|
return &types.Node{
|
||||||
IPv4: nap(v4),
|
IPv4: nap(v4),
|
||||||
IPv6: nap(v6),
|
IPv6: nap(v6),
|
||||||
@@ -484,12 +485,13 @@ func TestBackfillIPAddresses(t *testing.T) {
|
|||||||
func TestIPAllocatorNextNoReservedIPs(t *testing.T) {
|
func TestIPAllocatorNextNoReservedIPs(t *testing.T) {
|
||||||
db, err := newSQLiteTestDB()
|
db, err := newSQLiteTestDB()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
defer db.Close()
|
defer db.Close()
|
||||||
|
|
||||||
alloc, err := NewIPAllocator(
|
alloc, err := NewIPAllocator(
|
||||||
db,
|
db,
|
||||||
ptr.To(tsaddr.CGNATRange()),
|
new(tsaddr.CGNATRange()),
|
||||||
ptr.To(tsaddr.TailscaleULARange()),
|
new(tsaddr.TailscaleULARange()),
|
||||||
types.IPAllocationStrategySequential,
|
types.IPAllocationStrategySequential,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -497,17 +499,17 @@ func TestIPAllocatorNextNoReservedIPs(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Validate that we do not give out 100.100.100.100
|
// Validate that we do not give out 100.100.100.100
|
||||||
nextQuad100, err := alloc.next(na("100.100.100.99"), ptr.To(tsaddr.CGNATRange()))
|
nextQuad100, err := alloc.next(na("100.100.100.99"), new(tsaddr.CGNATRange()))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, na("100.100.100.101"), *nextQuad100)
|
assert.Equal(t, na("100.100.100.101"), *nextQuad100)
|
||||||
|
|
||||||
// Validate that we do not give out fd7a:115c:a1e0::53
|
// Validate that we do not give out fd7a:115c:a1e0::53
|
||||||
nextQuad100v6, err := alloc.next(na("fd7a:115c:a1e0::52"), ptr.To(tsaddr.TailscaleULARange()))
|
nextQuad100v6, err := alloc.next(na("fd7a:115c:a1e0::52"), new(tsaddr.TailscaleULARange()))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, na("fd7a:115c:a1e0::54"), *nextQuad100v6)
|
assert.Equal(t, na("fd7a:115c:a1e0::54"), *nextQuad100v6)
|
||||||
|
|
||||||
// Validate that we do not give out fd7a:115c:a1e0::53
|
// Validate that we do not give out fd7a:115c:a1e0::53
|
||||||
nextChrome, err := alloc.next(na("100.115.91.255"), ptr.To(tsaddr.CGNATRange()))
|
nextChrome, err := alloc.next(na("100.115.91.255"), new(tsaddr.CGNATRange()))
|
||||||
t.Logf("chrome: %s", nextChrome.String())
|
t.Logf("chrome: %s", nextChrome.String())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, na("100.115.94.0"), *nextChrome)
|
assert.Equal(t, na("100.115.94.0"), *nextChrome)
|
||||||
|
|||||||
@@ -16,18 +16,24 @@ import (
|
|||||||
|
|
||||||
"github.com/juanfont/headscale/hscontrol/types"
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
"github.com/juanfont/headscale/hscontrol/util"
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/util/zlog/zf"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"tailscale.com/net/tsaddr"
|
"tailscale.com/net/tsaddr"
|
||||||
"tailscale.com/types/key"
|
"tailscale.com/types/key"
|
||||||
"tailscale.com/types/ptr"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
NodeGivenNameHashLength = 8
|
NodeGivenNameHashLength = 8
|
||||||
NodeGivenNameTrimSize = 2
|
NodeGivenNameTrimSize = 2
|
||||||
|
|
||||||
|
// defaultTestNodePrefix is the default hostname prefix for nodes created in tests.
|
||||||
|
defaultTestNodePrefix = "testnode"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ErrNodeNameNotUnique is returned when a node name is not unique.
|
||||||
|
var ErrNodeNameNotUnique = errors.New("node name is not unique")
|
||||||
|
|
||||||
var invalidDNSRegex = regexp.MustCompile("[^a-z0-9-.]+")
|
var invalidDNSRegex = regexp.MustCompile("[^a-z0-9-.]+")
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -51,12 +57,14 @@ func (hsdb *HSDatabase) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID)
|
|||||||
// If at least one peer ID is given, only these peer nodes will be returned.
|
// If at least one peer ID is given, only these peer nodes will be returned.
|
||||||
func ListPeers(tx *gorm.DB, nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) {
|
func ListPeers(tx *gorm.DB, nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) {
|
||||||
nodes := types.Nodes{}
|
nodes := types.Nodes{}
|
||||||
if err := tx.
|
|
||||||
|
err := tx.
|
||||||
Preload("AuthKey").
|
Preload("AuthKey").
|
||||||
Preload("AuthKey.User").
|
Preload("AuthKey.User").
|
||||||
Preload("User").
|
Preload("User").
|
||||||
Where("id <> ?", nodeID).
|
Where("id <> ?", nodeID).
|
||||||
Where(peerIDs).Find(&nodes).Error; err != nil {
|
Where(peerIDs).Find(&nodes).Error
|
||||||
|
if err != nil {
|
||||||
return types.Nodes{}, err
|
return types.Nodes{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -75,11 +83,13 @@ func (hsdb *HSDatabase) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error)
|
|||||||
// or for the given nodes if at least one node ID is given as parameter.
|
// or for the given nodes if at least one node ID is given as parameter.
|
||||||
func ListNodes(tx *gorm.DB, nodeIDs ...types.NodeID) (types.Nodes, error) {
|
func ListNodes(tx *gorm.DB, nodeIDs ...types.NodeID) (types.Nodes, error) {
|
||||||
nodes := types.Nodes{}
|
nodes := types.Nodes{}
|
||||||
if err := tx.
|
|
||||||
|
err := tx.
|
||||||
Preload("AuthKey").
|
Preload("AuthKey").
|
||||||
Preload("AuthKey.User").
|
Preload("AuthKey.User").
|
||||||
Preload("User").
|
Preload("User").
|
||||||
Where(nodeIDs).Find(&nodes).Error; err != nil {
|
Where(nodeIDs).Find(&nodes).Error
|
||||||
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -89,7 +99,9 @@ func ListNodes(tx *gorm.DB, nodeIDs ...types.NodeID) (types.Nodes, error) {
|
|||||||
func (hsdb *HSDatabase) ListEphemeralNodes() (types.Nodes, error) {
|
func (hsdb *HSDatabase) ListEphemeralNodes() (types.Nodes, error) {
|
||||||
return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) {
|
return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) {
|
||||||
nodes := types.Nodes{}
|
nodes := types.Nodes{}
|
||||||
if err := rx.Joins("AuthKey").Where(`"AuthKey"."ephemeral" = true`).Find(&nodes).Error; err != nil {
|
|
||||||
|
err := rx.Joins("AuthKey").Where(`"AuthKey"."ephemeral" = true`).Find(&nodes).Error
|
||||||
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -207,6 +219,7 @@ func SetTags(
|
|||||||
|
|
||||||
slices.Sort(tags)
|
slices.Sort(tags)
|
||||||
tags = slices.Compact(tags)
|
tags = slices.Compact(tags)
|
||||||
|
|
||||||
b, err := json.Marshal(tags)
|
b, err := json.Marshal(tags)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -220,7 +233,7 @@ func SetTags(
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetTags takes a Node struct pointer and update the forced tags.
|
// SetApprovedRoutes takes a Node struct pointer and updates the approved routes.
|
||||||
func SetApprovedRoutes(
|
func SetApprovedRoutes(
|
||||||
tx *gorm.DB,
|
tx *gorm.DB,
|
||||||
nodeID types.NodeID,
|
nodeID types.NodeID,
|
||||||
@@ -228,7 +241,8 @@ func SetApprovedRoutes(
|
|||||||
) error {
|
) error {
|
||||||
if len(routes) == 0 {
|
if len(routes) == 0 {
|
||||||
// if no routes are provided, we remove all
|
// if no routes are provided, we remove all
|
||||||
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("approved_routes", "[]").Error; err != nil {
|
err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("approved_routes", "[]").Error
|
||||||
|
if err != nil {
|
||||||
return fmt.Errorf("removing approved routes: %w", err)
|
return fmt.Errorf("removing approved routes: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -251,7 +265,7 @@ func SetApprovedRoutes(
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("approved_routes", string(b)).Error; err != nil {
|
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("approved_routes", string(b)).Error; err != nil { //nolint:noinlineerr
|
||||||
return fmt.Errorf("updating approved routes: %w", err)
|
return fmt.Errorf("updating approved routes: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -277,22 +291,25 @@ func SetLastSeen(tx *gorm.DB, nodeID types.NodeID, lastSeen time.Time) error {
|
|||||||
func RenameNode(tx *gorm.DB,
|
func RenameNode(tx *gorm.DB,
|
||||||
nodeID types.NodeID, newName string,
|
nodeID types.NodeID, newName string,
|
||||||
) error {
|
) error {
|
||||||
if err := util.ValidateHostname(newName); err != nil {
|
err := util.ValidateHostname(newName)
|
||||||
|
if err != nil {
|
||||||
return fmt.Errorf("renaming node: %w", err)
|
return fmt.Errorf("renaming node: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if the new name is unique
|
// Check if the new name is unique
|
||||||
var count int64
|
var count int64
|
||||||
if err := tx.Model(&types.Node{}).Where("given_name = ? AND id != ?", newName, nodeID).Count(&count).Error; err != nil {
|
|
||||||
return fmt.Errorf("failed to check name uniqueness: %w", err)
|
err = tx.Model(&types.Node{}).Where("given_name = ? AND id != ?", newName, nodeID).Count(&count).Error
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("checking name uniqueness: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if count > 0 {
|
if count > 0 {
|
||||||
return errors.New("name is not unique")
|
return ErrNodeNameNotUnique
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("given_name", newName).Error; err != nil {
|
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("given_name", newName).Error; err != nil { //nolint:noinlineerr
|
||||||
return fmt.Errorf("failed to rename node in the database: %w", err)
|
return fmt.Errorf("renaming node in database: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -323,7 +340,8 @@ func DeleteNode(tx *gorm.DB,
|
|||||||
node *types.Node,
|
node *types.Node,
|
||||||
) error {
|
) error {
|
||||||
// Unscoped causes the node to be fully removed from the database.
|
// Unscoped causes the node to be fully removed from the database.
|
||||||
if err := tx.Unscoped().Delete(&types.Node{}, node.ID).Error; err != nil {
|
err := tx.Unscoped().Delete(&types.Node{}, node.ID).Error
|
||||||
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -337,9 +355,11 @@ func (hsdb *HSDatabase) DeleteEphemeralNode(
|
|||||||
nodeID types.NodeID,
|
nodeID types.NodeID,
|
||||||
) error {
|
) error {
|
||||||
return hsdb.Write(func(tx *gorm.DB) error {
|
return hsdb.Write(func(tx *gorm.DB) error {
|
||||||
if err := tx.Unscoped().Delete(&types.Node{}, nodeID).Error; err != nil {
|
err := tx.Unscoped().Delete(&types.Node{}, nodeID).Error
|
||||||
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -352,19 +372,19 @@ func RegisterNodeForTest(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *n
|
|||||||
}
|
}
|
||||||
|
|
||||||
logEvent := log.Debug().
|
logEvent := log.Debug().
|
||||||
Str("node", node.Hostname).
|
Str(zf.NodeHostname, node.Hostname).
|
||||||
Str("machine_key", node.MachineKey.ShortString()).
|
Str(zf.MachineKey, node.MachineKey.ShortString()).
|
||||||
Str("node_key", node.NodeKey.ShortString())
|
Str(zf.NodeKey, node.NodeKey.ShortString())
|
||||||
|
|
||||||
if node.User != nil {
|
if node.User != nil {
|
||||||
logEvent = logEvent.Str("user", node.User.Username())
|
logEvent = logEvent.Str(zf.UserName, node.User.Username())
|
||||||
} else if node.UserID != nil {
|
} else if node.UserID != nil {
|
||||||
logEvent = logEvent.Uint("user_id", *node.UserID)
|
logEvent = logEvent.Uint(zf.UserID, *node.UserID)
|
||||||
} else {
|
} else {
|
||||||
logEvent = logEvent.Str("user", "none")
|
logEvent = logEvent.Str(zf.UserName, "none")
|
||||||
}
|
}
|
||||||
|
|
||||||
logEvent.Msg("Registering test node")
|
logEvent.Msg("registering test node")
|
||||||
|
|
||||||
// If the a new node is registered with the same machine key, to the same user,
|
// If the a new node is registered with the same machine key, to the same user,
|
||||||
// update the existing node.
|
// update the existing node.
|
||||||
@@ -379,6 +399,7 @@ func RegisterNodeForTest(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *n
|
|||||||
if ipv4 == nil {
|
if ipv4 == nil {
|
||||||
ipv4 = oldNode.IPv4
|
ipv4 = oldNode.IPv4
|
||||||
}
|
}
|
||||||
|
|
||||||
if ipv6 == nil {
|
if ipv6 == nil {
|
||||||
ipv6 = oldNode.IPv6
|
ipv6 = oldNode.IPv6
|
||||||
}
|
}
|
||||||
@@ -388,16 +409,17 @@ func RegisterNodeForTest(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *n
|
|||||||
// so we store the node.Expire and node.Nodekey that has been set when
|
// so we store the node.Expire and node.Nodekey that has been set when
|
||||||
// adding it to the registrationCache
|
// adding it to the registrationCache
|
||||||
if node.IPv4 != nil || node.IPv6 != nil {
|
if node.IPv4 != nil || node.IPv6 != nil {
|
||||||
if err := tx.Save(&node).Error; err != nil {
|
err := tx.Save(&node).Error
|
||||||
return nil, fmt.Errorf("failed register existing node in the database: %w", err)
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("registering existing node in database: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Caller().
|
Caller().
|
||||||
Str("node", node.Hostname).
|
Str(zf.NodeHostname, node.Hostname).
|
||||||
Str("machine_key", node.MachineKey.ShortString()).
|
Str(zf.MachineKey, node.MachineKey.ShortString()).
|
||||||
Str("node_key", node.NodeKey.ShortString()).
|
Str(zf.NodeKey, node.NodeKey.ShortString()).
|
||||||
Str("user", node.User.Username()).
|
Str(zf.UserName, node.User.Username()).
|
||||||
Msg("Test node authorized again")
|
Msg("Test node authorized again")
|
||||||
|
|
||||||
return &node, nil
|
return &node, nil
|
||||||
@@ -407,29 +429,30 @@ func RegisterNodeForTest(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *n
|
|||||||
node.IPv6 = ipv6
|
node.IPv6 = ipv6
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
node.Hostname, err = util.NormaliseHostname(node.Hostname)
|
node.Hostname, err = util.NormaliseHostname(node.Hostname)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
newHostname := util.InvalidString()
|
newHostname := util.InvalidString()
|
||||||
log.Info().Err(err).Str("invalid-hostname", node.Hostname).Str("new-hostname", newHostname).Msgf("Invalid hostname, replacing")
|
log.Info().Err(err).Str(zf.InvalidHostname, node.Hostname).Str(zf.NewHostname, newHostname).Msgf("invalid hostname, replacing")
|
||||||
node.Hostname = newHostname
|
node.Hostname = newHostname
|
||||||
}
|
}
|
||||||
|
|
||||||
if node.GivenName == "" {
|
if node.GivenName == "" {
|
||||||
givenName, err := EnsureUniqueGivenName(tx, node.Hostname)
|
givenName, err := EnsureUniqueGivenName(tx, node.Hostname)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to ensure unique given name: %w", err)
|
return nil, fmt.Errorf("ensuring unique given name: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
node.GivenName = givenName
|
node.GivenName = givenName
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := tx.Save(&node).Error; err != nil {
|
if err := tx.Save(&node).Error; err != nil { //nolint:noinlineerr
|
||||||
return nil, fmt.Errorf("failed register(save) node in the database: %w", err)
|
return nil, fmt.Errorf("saving node to database: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Caller().
|
Caller().
|
||||||
Str("node", node.Hostname).
|
Str(zf.NodeHostname, node.Hostname).
|
||||||
Msg("Test node registered with the database")
|
Msg("Test node registered with the database")
|
||||||
|
|
||||||
return &node, nil
|
return &node, nil
|
||||||
@@ -491,8 +514,10 @@ func generateGivenName(suppliedName string, randomSuffix bool) (string, error) {
|
|||||||
|
|
||||||
func isUniqueName(tx *gorm.DB, name string) (bool, error) {
|
func isUniqueName(tx *gorm.DB, name string) (bool, error) {
|
||||||
nodes := types.Nodes{}
|
nodes := types.Nodes{}
|
||||||
if err := tx.
|
|
||||||
Where("given_name = ?", name).Find(&nodes).Error; err != nil {
|
err := tx.
|
||||||
|
Where("given_name = ?", name).Find(&nodes).Error
|
||||||
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -646,7 +671,7 @@ func (hsdb *HSDatabase) CreateNodeForTest(user *types.User, hostname ...string)
|
|||||||
panic("CreateNodeForTest requires a valid user")
|
panic("CreateNodeForTest requires a valid user")
|
||||||
}
|
}
|
||||||
|
|
||||||
nodeName := "testnode"
|
nodeName := defaultTestNodePrefix
|
||||||
if len(hostname) > 0 && hostname[0] != "" {
|
if len(hostname) > 0 && hostname[0] != "" {
|
||||||
nodeName = hostname[0]
|
nodeName = hostname[0]
|
||||||
}
|
}
|
||||||
@@ -657,6 +682,7 @@ func (hsdb *HSDatabase) CreateNodeForTest(user *types.User, hostname ...string)
|
|||||||
panic(fmt.Sprintf("failed to create preauth key for test node: %v", err))
|
panic(fmt.Sprintf("failed to create preauth key for test node: %v", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pakID := pak.ID
|
||||||
nodeKey := key.NewNode()
|
nodeKey := key.NewNode()
|
||||||
machineKey := key.NewMachine()
|
machineKey := key.NewMachine()
|
||||||
discoKey := key.NewDisco()
|
discoKey := key.NewDisco()
|
||||||
@@ -668,7 +694,7 @@ func (hsdb *HSDatabase) CreateNodeForTest(user *types.User, hostname ...string)
|
|||||||
Hostname: nodeName,
|
Hostname: nodeName,
|
||||||
UserID: &user.ID,
|
UserID: &user.ID,
|
||||||
RegisterMethod: util.RegisterMethodAuthKey,
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
AuthKeyID: ptr.To(pak.ID),
|
AuthKeyID: &pakID,
|
||||||
}
|
}
|
||||||
|
|
||||||
err = hsdb.DB.Save(node).Error
|
err = hsdb.DB.Save(node).Error
|
||||||
@@ -694,9 +720,12 @@ func (hsdb *HSDatabase) CreateRegisteredNodeForTest(user *types.User, hostname .
|
|||||||
}
|
}
|
||||||
|
|
||||||
var registeredNode *types.Node
|
var registeredNode *types.Node
|
||||||
|
|
||||||
err = hsdb.DB.Transaction(func(tx *gorm.DB) error {
|
err = hsdb.DB.Transaction(func(tx *gorm.DB) error {
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
registeredNode, err = RegisterNodeForTest(tx, *node, ipv4, ipv6)
|
registeredNode, err = RegisterNodeForTest(tx, *node, ipv4, ipv6)
|
||||||
|
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -715,7 +744,7 @@ func (hsdb *HSDatabase) CreateNodesForTest(user *types.User, count int, hostname
|
|||||||
panic("CreateNodesForTest requires a valid user")
|
panic("CreateNodesForTest requires a valid user")
|
||||||
}
|
}
|
||||||
|
|
||||||
prefix := "testnode"
|
prefix := defaultTestNodePrefix
|
||||||
if len(hostnamePrefix) > 0 && hostnamePrefix[0] != "" {
|
if len(hostnamePrefix) > 0 && hostnamePrefix[0] != "" {
|
||||||
prefix = hostnamePrefix[0]
|
prefix = hostnamePrefix[0]
|
||||||
}
|
}
|
||||||
@@ -738,7 +767,7 @@ func (hsdb *HSDatabase) CreateRegisteredNodesForTest(user *types.User, count int
|
|||||||
panic("CreateRegisteredNodesForTest requires a valid user")
|
panic("CreateRegisteredNodesForTest requires a valid user")
|
||||||
}
|
}
|
||||||
|
|
||||||
prefix := "testnode"
|
prefix := defaultTestNodePrefix
|
||||||
if len(hostnamePrefix) > 0 && hostnamePrefix[0] != "" {
|
if len(hostnamePrefix) > 0 && hostnamePrefix[0] != "" {
|
||||||
prefix = hostnamePrefix[0]
|
prefix = hostnamePrefix[0]
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -22,7 +22,6 @@ import (
|
|||||||
"tailscale.com/net/tsaddr"
|
"tailscale.com/net/tsaddr"
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
"tailscale.com/types/key"
|
"tailscale.com/types/key"
|
||||||
"tailscale.com/types/ptr"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestGetNode(t *testing.T) {
|
func TestGetNode(t *testing.T) {
|
||||||
@@ -102,6 +101,8 @@ func TestExpireNode(t *testing.T) {
|
|||||||
pak, err := db.CreatePreAuthKey(user.TypedID(), false, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(user.TypedID(), false, false, nil, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
pakID := pak.ID
|
||||||
|
|
||||||
_, err = db.getNode(types.UserID(user.ID), "testnode")
|
_, err = db.getNode(types.UserID(user.ID), "testnode")
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
|
|
||||||
@@ -115,7 +116,7 @@ func TestExpireNode(t *testing.T) {
|
|||||||
Hostname: "testnode",
|
Hostname: "testnode",
|
||||||
UserID: &user.ID,
|
UserID: &user.ID,
|
||||||
RegisterMethod: util.RegisterMethodAuthKey,
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
AuthKeyID: ptr.To(pak.ID),
|
AuthKeyID: &pakID,
|
||||||
Expiry: &time.Time{},
|
Expiry: &time.Time{},
|
||||||
}
|
}
|
||||||
db.DB.Save(node)
|
db.DB.Save(node)
|
||||||
@@ -146,6 +147,8 @@ func TestSetTags(t *testing.T) {
|
|||||||
pak, err := db.CreatePreAuthKey(user.TypedID(), false, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(user.TypedID(), false, false, nil, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
pakID := pak.ID
|
||||||
|
|
||||||
_, err = db.getNode(types.UserID(user.ID), "testnode")
|
_, err = db.getNode(types.UserID(user.ID), "testnode")
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
|
|
||||||
@@ -159,7 +162,7 @@ func TestSetTags(t *testing.T) {
|
|||||||
Hostname: "testnode",
|
Hostname: "testnode",
|
||||||
UserID: &user.ID,
|
UserID: &user.ID,
|
||||||
RegisterMethod: util.RegisterMethodAuthKey,
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
AuthKeyID: ptr.To(pak.ID),
|
AuthKeyID: &pakID,
|
||||||
}
|
}
|
||||||
|
|
||||||
trx := db.DB.Save(node)
|
trx := db.DB.Save(node)
|
||||||
@@ -187,6 +190,7 @@ func TestHeadscale_generateGivenName(t *testing.T) {
|
|||||||
suppliedName string
|
suppliedName string
|
||||||
randomSuffix bool
|
randomSuffix bool
|
||||||
}
|
}
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
args args
|
args args
|
||||||
@@ -443,7 +447,7 @@ func TestAutoApproveRoutes(t *testing.T) {
|
|||||||
Hostinfo: &tailcfg.Hostinfo{
|
Hostinfo: &tailcfg.Hostinfo{
|
||||||
RoutableIPs: tt.routes,
|
RoutableIPs: tt.routes,
|
||||||
},
|
},
|
||||||
IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")),
|
IPv4: new(netip.MustParseAddr("100.64.0.1")),
|
||||||
}
|
}
|
||||||
|
|
||||||
err = adb.DB.Save(&node).Error
|
err = adb.DB.Save(&node).Error
|
||||||
@@ -460,17 +464,17 @@ func TestAutoApproveRoutes(t *testing.T) {
|
|||||||
RoutableIPs: tt.routes,
|
RoutableIPs: tt.routes,
|
||||||
},
|
},
|
||||||
Tags: []string{"tag:exit"},
|
Tags: []string{"tag:exit"},
|
||||||
IPv4: ptr.To(netip.MustParseAddr("100.64.0.2")),
|
IPv4: new(netip.MustParseAddr("100.64.0.2")),
|
||||||
}
|
}
|
||||||
|
|
||||||
err = adb.DB.Save(&nodeTagged).Error
|
err = adb.DB.Save(&nodeTagged).Error
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
users, err := adb.ListUsers()
|
users, err := adb.ListUsers()
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
nodes, err := adb.ListNodes()
|
nodes, err := adb.ListNodes()
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
pm, err := pmf(users, nodes.ViewSlice())
|
pm, err := pmf(users, nodes.ViewSlice())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -498,6 +502,7 @@ func TestAutoApproveRoutes(t *testing.T) {
|
|||||||
if len(expectedRoutes1) == 0 {
|
if len(expectedRoutes1) == 0 {
|
||||||
expectedRoutes1 = nil
|
expectedRoutes1 = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if diff := cmp.Diff(expectedRoutes1, node1ByID.AllApprovedRoutes(), util.Comparers...); diff != "" {
|
if diff := cmp.Diff(expectedRoutes1, node1ByID.AllApprovedRoutes(), util.Comparers...); diff != "" {
|
||||||
t.Errorf("unexpected enabled routes (-want +got):\n%s", diff)
|
t.Errorf("unexpected enabled routes (-want +got):\n%s", diff)
|
||||||
}
|
}
|
||||||
@@ -509,6 +514,7 @@ func TestAutoApproveRoutes(t *testing.T) {
|
|||||||
if len(expectedRoutes2) == 0 {
|
if len(expectedRoutes2) == 0 {
|
||||||
expectedRoutes2 = nil
|
expectedRoutes2 = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if diff := cmp.Diff(expectedRoutes2, node2ByID.AllApprovedRoutes(), util.Comparers...); diff != "" {
|
if diff := cmp.Diff(expectedRoutes2, node2ByID.AllApprovedRoutes(), util.Comparers...); diff != "" {
|
||||||
t.Errorf("unexpected enabled routes (-want +got):\n%s", diff)
|
t.Errorf("unexpected enabled routes (-want +got):\n%s", diff)
|
||||||
}
|
}
|
||||||
@@ -520,6 +526,7 @@ func TestAutoApproveRoutes(t *testing.T) {
|
|||||||
func TestEphemeralGarbageCollectorOrder(t *testing.T) {
|
func TestEphemeralGarbageCollectorOrder(t *testing.T) {
|
||||||
want := []types.NodeID{1, 3}
|
want := []types.NodeID{1, 3}
|
||||||
got := []types.NodeID{}
|
got := []types.NodeID{}
|
||||||
|
|
||||||
var mu sync.Mutex
|
var mu sync.Mutex
|
||||||
|
|
||||||
deletionCount := make(chan struct{}, 10)
|
deletionCount := make(chan struct{}, 10)
|
||||||
@@ -527,6 +534,7 @@ func TestEphemeralGarbageCollectorOrder(t *testing.T) {
|
|||||||
e := NewEphemeralGarbageCollector(func(ni types.NodeID) {
|
e := NewEphemeralGarbageCollector(func(ni types.NodeID) {
|
||||||
mu.Lock()
|
mu.Lock()
|
||||||
defer mu.Unlock()
|
defer mu.Unlock()
|
||||||
|
|
||||||
got = append(got, ni)
|
got = append(got, ni)
|
||||||
|
|
||||||
deletionCount <- struct{}{}
|
deletionCount <- struct{}{}
|
||||||
@@ -576,8 +584,10 @@ func TestEphemeralGarbageCollectorOrder(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestEphemeralGarbageCollectorLoads(t *testing.T) {
|
func TestEphemeralGarbageCollectorLoads(t *testing.T) {
|
||||||
var got []types.NodeID
|
var (
|
||||||
var mu sync.Mutex
|
got []types.NodeID
|
||||||
|
mu sync.Mutex
|
||||||
|
)
|
||||||
|
|
||||||
want := 1000
|
want := 1000
|
||||||
|
|
||||||
@@ -589,6 +599,7 @@ func TestEphemeralGarbageCollectorLoads(t *testing.T) {
|
|||||||
|
|
||||||
// Yield to other goroutines to introduce variability
|
// Yield to other goroutines to introduce variability
|
||||||
runtime.Gosched()
|
runtime.Gosched()
|
||||||
|
|
||||||
got = append(got, ni)
|
got = append(got, ni)
|
||||||
|
|
||||||
atomic.AddInt64(&deletedCount, 1)
|
atomic.AddInt64(&deletedCount, 1)
|
||||||
@@ -616,9 +627,12 @@ func TestEphemeralGarbageCollectorLoads(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func generateRandomNumber(t *testing.T, max int64) int64 {
|
//nolint:unused
|
||||||
|
func generateRandomNumber(t *testing.T, maxVal int64) int64 {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
maxB := big.NewInt(max)
|
|
||||||
|
maxB := big.NewInt(maxVal)
|
||||||
|
|
||||||
n, err := rand.Int(rand.Reader, maxB)
|
n, err := rand.Int(rand.Reader, maxB)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("getting random number: %s", err)
|
t.Fatalf("getting random number: %s", err)
|
||||||
@@ -642,6 +656,9 @@ func TestListEphemeralNodes(t *testing.T) {
|
|||||||
pakEph, err := db.CreatePreAuthKey(user.TypedID(), false, true, nil, nil)
|
pakEph, err := db.CreatePreAuthKey(user.TypedID(), false, true, nil, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
pakID := pak.ID
|
||||||
|
pakEphID := pakEph.ID
|
||||||
|
|
||||||
node := types.Node{
|
node := types.Node{
|
||||||
ID: 0,
|
ID: 0,
|
||||||
MachineKey: key.NewMachine().Public(),
|
MachineKey: key.NewMachine().Public(),
|
||||||
@@ -649,7 +666,7 @@ func TestListEphemeralNodes(t *testing.T) {
|
|||||||
Hostname: "test",
|
Hostname: "test",
|
||||||
UserID: &user.ID,
|
UserID: &user.ID,
|
||||||
RegisterMethod: util.RegisterMethodAuthKey,
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
AuthKeyID: ptr.To(pak.ID),
|
AuthKeyID: &pakID,
|
||||||
}
|
}
|
||||||
|
|
||||||
nodeEph := types.Node{
|
nodeEph := types.Node{
|
||||||
@@ -659,7 +676,7 @@ func TestListEphemeralNodes(t *testing.T) {
|
|||||||
Hostname: "ephemeral",
|
Hostname: "ephemeral",
|
||||||
UserID: &user.ID,
|
UserID: &user.ID,
|
||||||
RegisterMethod: util.RegisterMethodAuthKey,
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
AuthKeyID: ptr.To(pakEph.ID),
|
AuthKeyID: &pakEphID,
|
||||||
}
|
}
|
||||||
|
|
||||||
err = db.DB.Save(&node).Error
|
err = db.DB.Save(&node).Error
|
||||||
@@ -722,7 +739,7 @@ func TestNodeNaming(t *testing.T) {
|
|||||||
nodeInvalidHostname := types.Node{
|
nodeInvalidHostname := types.Node{
|
||||||
MachineKey: key.NewMachine().Public(),
|
MachineKey: key.NewMachine().Public(),
|
||||||
NodeKey: key.NewNode().Public(),
|
NodeKey: key.NewNode().Public(),
|
||||||
Hostname: "我的电脑",
|
Hostname: "我的电脑", //nolint:gosmopolitan // intentional i18n test data
|
||||||
UserID: &user2.ID,
|
UserID: &user2.ID,
|
||||||
RegisterMethod: util.RegisterMethodAuthKey,
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
}
|
}
|
||||||
@@ -746,12 +763,15 @@ func TestNodeNaming(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = RegisterNodeForTest(tx, node2, nil, nil)
|
_, err = RegisterNodeForTest(tx, node2, nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_, err = RegisterNodeForTest(tx, nodeInvalidHostname, ptr.To(mpp("100.64.0.66/32").Addr()), nil)
|
|
||||||
_, err = RegisterNodeForTest(tx, nodeShortHostname, ptr.To(mpp("100.64.0.67/32").Addr()), nil)
|
_, _ = RegisterNodeForTest(tx, nodeInvalidHostname, new(mpp("100.64.0.66/32").Addr()), nil)
|
||||||
|
_, err = RegisterNodeForTest(tx, nodeShortHostname, new(mpp("100.64.0.67/32").Addr()), nil)
|
||||||
|
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -810,25 +830,25 @@ func TestNodeNaming(t *testing.T) {
|
|||||||
err = db.Write(func(tx *gorm.DB) error {
|
err = db.Write(func(tx *gorm.DB) error {
|
||||||
return RenameNode(tx, nodes[0].ID, "test")
|
return RenameNode(tx, nodes[0].ID, "test")
|
||||||
})
|
})
|
||||||
assert.ErrorContains(t, err, "name is not unique")
|
require.ErrorContains(t, err, "name is not unique")
|
||||||
|
|
||||||
// Rename invalid chars
|
// Rename invalid chars
|
||||||
err = db.Write(func(tx *gorm.DB) error {
|
err = db.Write(func(tx *gorm.DB) error {
|
||||||
return RenameNode(tx, nodes[2].ID, "我的电脑")
|
return RenameNode(tx, nodes[2].ID, "我的电脑") //nolint:gosmopolitan // intentional i18n test data
|
||||||
})
|
})
|
||||||
assert.ErrorContains(t, err, "invalid characters")
|
require.ErrorContains(t, err, "invalid characters")
|
||||||
|
|
||||||
// Rename too short
|
// Rename too short
|
||||||
err = db.Write(func(tx *gorm.DB) error {
|
err = db.Write(func(tx *gorm.DB) error {
|
||||||
return RenameNode(tx, nodes[3].ID, "a")
|
return RenameNode(tx, nodes[3].ID, "a")
|
||||||
})
|
})
|
||||||
assert.ErrorContains(t, err, "at least 2 characters")
|
require.ErrorContains(t, err, "at least 2 characters")
|
||||||
|
|
||||||
// Rename with emoji
|
// Rename with emoji
|
||||||
err = db.Write(func(tx *gorm.DB) error {
|
err = db.Write(func(tx *gorm.DB) error {
|
||||||
return RenameNode(tx, nodes[0].ID, "hostname-with-💩")
|
return RenameNode(tx, nodes[0].ID, "hostname-with-💩")
|
||||||
})
|
})
|
||||||
assert.ErrorContains(t, err, "invalid characters")
|
require.ErrorContains(t, err, "invalid characters")
|
||||||
|
|
||||||
// Rename with only emoji
|
// Rename with only emoji
|
||||||
err = db.Write(func(tx *gorm.DB) error {
|
err = db.Write(func(tx *gorm.DB) error {
|
||||||
@@ -896,12 +916,12 @@ func TestRenameNodeComprehensive(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "chinese_chars_with_dash_rejected",
|
name: "chinese_chars_with_dash_rejected",
|
||||||
newName: "server-北京-01",
|
newName: "server-北京-01", //nolint:gosmopolitan // intentional i18n test data
|
||||||
wantErr: "invalid characters",
|
wantErr: "invalid characters",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "chinese_only_rejected",
|
name: "chinese_only_rejected",
|
||||||
newName: "我的电脑",
|
newName: "我的电脑", //nolint:gosmopolitan // intentional i18n test data
|
||||||
wantErr: "invalid characters",
|
wantErr: "invalid characters",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -911,7 +931,7 @@ func TestRenameNodeComprehensive(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "mixed_chinese_emoji_rejected",
|
name: "mixed_chinese_emoji_rejected",
|
||||||
newName: "测试💻机器",
|
newName: "测试💻机器", //nolint:gosmopolitan // intentional i18n test data
|
||||||
wantErr: "invalid characters",
|
wantErr: "invalid characters",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -1000,6 +1020,7 @@ func TestListPeers(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = RegisterNodeForTest(tx, node2, nil, nil)
|
_, err = RegisterNodeForTest(tx, node2, nil, nil)
|
||||||
|
|
||||||
return err
|
return err
|
||||||
@@ -1085,6 +1106,7 @@ func TestListNodes(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = RegisterNodeForTest(tx, node2, nil, nil)
|
_, err = RegisterNodeForTest(tx, node2, nil, nil)
|
||||||
|
|
||||||
return err
|
return err
|
||||||
|
|||||||
@@ -17,7 +17,8 @@ func (hsdb *HSDatabase) SetPolicy(policy string) (*types.Policy, error) {
|
|||||||
Data: policy,
|
Data: policy,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := hsdb.DB.Clauses(clause.Returning{}).Create(&p).Error; err != nil {
|
err := hsdb.DB.Clauses(clause.Returning{}).Create(&p).Error
|
||||||
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -138,8 +138,8 @@ func CreatePreAuthKey(
|
|||||||
Hash: hash, // Store hash
|
Hash: hash, // Store hash
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := tx.Save(&key).Error; err != nil {
|
if err := tx.Save(&key).Error; err != nil { //nolint:noinlineerr
|
||||||
return nil, fmt.Errorf("failed to create key in the database: %w", err)
|
return nil, fmt.Errorf("creating key in database: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &types.PreAuthKeyNew{
|
return &types.PreAuthKeyNew{
|
||||||
@@ -155,9 +155,7 @@ func CreatePreAuthKey(
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (hsdb *HSDatabase) ListPreAuthKeys() ([]types.PreAuthKey, error) {
|
func (hsdb *HSDatabase) ListPreAuthKeys() ([]types.PreAuthKey, error) {
|
||||||
return Read(hsdb.DB, func(rx *gorm.DB) ([]types.PreAuthKey, error) {
|
return Read(hsdb.DB, ListPreAuthKeys)
|
||||||
return ListPreAuthKeys(rx)
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListPreAuthKeys returns all PreAuthKeys in the database.
|
// ListPreAuthKeys returns all PreAuthKeys in the database.
|
||||||
@@ -296,7 +294,7 @@ func DestroyPreAuthKey(tx *gorm.DB, id uint64) error {
|
|||||||
Where("auth_key_id = ?", id).
|
Where("auth_key_id = ?", id).
|
||||||
Update("auth_key_id", nil).Error
|
Update("auth_key_id", nil).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to clear auth_key_id on nodes: %w", err)
|
return fmt.Errorf("clearing auth_key_id on nodes: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Then delete the pre-auth key
|
// Then delete the pre-auth key
|
||||||
@@ -325,14 +323,15 @@ func (hsdb *HSDatabase) DeletePreAuthKey(id uint64) error {
|
|||||||
func UsePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
|
func UsePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
|
||||||
err := tx.Model(k).Update("used", true).Error
|
err := tx.Model(k).Update("used", true).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to update key used status in the database: %w", err)
|
return fmt.Errorf("updating key used status in database: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
k.Used = true
|
k.Used = true
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarkExpirePreAuthKey marks a PreAuthKey as expired.
|
// ExpirePreAuthKey marks a PreAuthKey as expired.
|
||||||
func ExpirePreAuthKey(tx *gorm.DB, id uint64) error {
|
func ExpirePreAuthKey(tx *gorm.DB, id uint64) error {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
return tx.Model(&types.PreAuthKey{}).Where("id = ?", id).Update("expiration", now).Error
|
return tx.Model(&types.PreAuthKey{}).Where("id = ?", id).Update("expiration", now).Error
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ import (
|
|||||||
"github.com/juanfont/headscale/hscontrol/util"
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"tailscale.com/types/ptr"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestCreatePreAuthKey(t *testing.T) {
|
func TestCreatePreAuthKey(t *testing.T) {
|
||||||
@@ -24,7 +23,7 @@ func TestCreatePreAuthKey(t *testing.T) {
|
|||||||
test: func(t *testing.T, db *HSDatabase) {
|
test: func(t *testing.T, db *HSDatabase) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
_, err := db.CreatePreAuthKey(ptr.To(types.UserID(12345)), true, false, nil, nil)
|
_, err := db.CreatePreAuthKey(new(types.UserID(12345)), true, false, nil, nil)
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -127,7 +126,7 @@ func TestCannotDeleteAssignedPreAuthKey(t *testing.T) {
|
|||||||
Hostname: "testest",
|
Hostname: "testest",
|
||||||
UserID: &user.ID,
|
UserID: &user.ID,
|
||||||
RegisterMethod: util.RegisterMethodAuthKey,
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
AuthKeyID: ptr.To(key.ID),
|
AuthKeyID: new(key.ID),
|
||||||
}
|
}
|
||||||
db.DB.Save(&node)
|
db.DB.Save(&node)
|
||||||
|
|
||||||
|
|||||||
@@ -104,3 +104,9 @@ CREATE TABLE policies(
|
|||||||
deleted_at datetime
|
deleted_at datetime
|
||||||
);
|
);
|
||||||
CREATE INDEX idx_policies_deleted_at ON policies(deleted_at);
|
CREATE INDEX idx_policies_deleted_at ON policies(deleted_at);
|
||||||
|
|
||||||
|
CREATE TABLE database_versions(
|
||||||
|
id integer PRIMARY KEY,
|
||||||
|
version text NOT NULL,
|
||||||
|
updated_at datetime
|
||||||
|
);
|
||||||
|
|||||||
@@ -362,7 +362,8 @@ func (c *Config) Validate() error {
|
|||||||
// ToURL builds a properly encoded SQLite connection string using _pragma parameters
|
// ToURL builds a properly encoded SQLite connection string using _pragma parameters
|
||||||
// compatible with modernc.org/sqlite driver.
|
// compatible with modernc.org/sqlite driver.
|
||||||
func (c *Config) ToURL() (string, error) {
|
func (c *Config) ToURL() (string, error) {
|
||||||
if err := c.Validate(); err != nil {
|
err := c.Validate()
|
||||||
|
if err != nil {
|
||||||
return "", fmt.Errorf("invalid config: %w", err)
|
return "", fmt.Errorf("invalid config: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -372,18 +373,23 @@ func (c *Config) ToURL() (string, error) {
|
|||||||
if c.BusyTimeout > 0 {
|
if c.BusyTimeout > 0 {
|
||||||
pragmas = append(pragmas, fmt.Sprintf("busy_timeout=%d", c.BusyTimeout))
|
pragmas = append(pragmas, fmt.Sprintf("busy_timeout=%d", c.BusyTimeout))
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.JournalMode != "" {
|
if c.JournalMode != "" {
|
||||||
pragmas = append(pragmas, fmt.Sprintf("journal_mode=%s", c.JournalMode))
|
pragmas = append(pragmas, fmt.Sprintf("journal_mode=%s", c.JournalMode))
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.AutoVacuum != "" {
|
if c.AutoVacuum != "" {
|
||||||
pragmas = append(pragmas, fmt.Sprintf("auto_vacuum=%s", c.AutoVacuum))
|
pragmas = append(pragmas, fmt.Sprintf("auto_vacuum=%s", c.AutoVacuum))
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.WALAutocheckpoint >= 0 {
|
if c.WALAutocheckpoint >= 0 {
|
||||||
pragmas = append(pragmas, fmt.Sprintf("wal_autocheckpoint=%d", c.WALAutocheckpoint))
|
pragmas = append(pragmas, fmt.Sprintf("wal_autocheckpoint=%d", c.WALAutocheckpoint))
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.Synchronous != "" {
|
if c.Synchronous != "" {
|
||||||
pragmas = append(pragmas, fmt.Sprintf("synchronous=%s", c.Synchronous))
|
pragmas = append(pragmas, fmt.Sprintf("synchronous=%s", c.Synchronous))
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.ForeignKeys {
|
if c.ForeignKeys {
|
||||||
pragmas = append(pragmas, "foreign_keys=ON")
|
pragmas = append(pragmas, "foreign_keys=ON")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -294,6 +294,7 @@ func TestConfigToURL(t *testing.T) {
|
|||||||
t.Errorf("Config.ToURL() error = %v", err)
|
t.Errorf("Config.ToURL() error = %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if got != tt.want {
|
if got != tt.want {
|
||||||
t.Errorf("Config.ToURL() = %q, want %q", got, tt.want)
|
t.Errorf("Config.ToURL() = %q, want %q", got, tt.want)
|
||||||
}
|
}
|
||||||
@@ -306,6 +307,7 @@ func TestConfigToURLInvalid(t *testing.T) {
|
|||||||
Path: "",
|
Path: "",
|
||||||
BusyTimeout: -1,
|
BusyTimeout: -1,
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := config.ToURL()
|
_, err := config.ToURL()
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("Config.ToURL() with invalid config should return error")
|
t.Error("Config.ToURL() with invalid config should return error")
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package sqliteconfig
|
package sqliteconfig
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -101,7 +102,10 @@ func TestSQLiteDriverPragmaIntegration(t *testing.T) {
|
|||||||
defer db.Close()
|
defer db.Close()
|
||||||
|
|
||||||
// Test connection
|
// Test connection
|
||||||
if err := db.Ping(); err != nil {
|
ctx := context.Background()
|
||||||
|
|
||||||
|
err = db.PingContext(ctx)
|
||||||
|
if err != nil {
|
||||||
t.Fatalf("Failed to ping database: %v", err)
|
t.Fatalf("Failed to ping database: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -109,8 +113,10 @@ func TestSQLiteDriverPragmaIntegration(t *testing.T) {
|
|||||||
for pragma, expectedValue := range tt.expected {
|
for pragma, expectedValue := range tt.expected {
|
||||||
t.Run("pragma_"+pragma, func(t *testing.T) {
|
t.Run("pragma_"+pragma, func(t *testing.T) {
|
||||||
var actualValue any
|
var actualValue any
|
||||||
|
|
||||||
query := "PRAGMA " + pragma
|
query := "PRAGMA " + pragma
|
||||||
err := db.QueryRow(query).Scan(&actualValue)
|
|
||||||
|
err := db.QueryRowContext(ctx, query).Scan(&actualValue)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to query %s: %v", query, err)
|
t.Fatalf("Failed to query %s: %v", query, err)
|
||||||
}
|
}
|
||||||
@@ -163,6 +169,8 @@ func TestForeignKeyConstraintEnforcement(t *testing.T) {
|
|||||||
}
|
}
|
||||||
defer db.Close()
|
defer db.Close()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
// Create test tables with foreign key relationship
|
// Create test tables with foreign key relationship
|
||||||
schema := `
|
schema := `
|
||||||
CREATE TABLE parent (
|
CREATE TABLE parent (
|
||||||
@@ -178,23 +186,25 @@ func TestForeignKeyConstraintEnforcement(t *testing.T) {
|
|||||||
);
|
);
|
||||||
`
|
`
|
||||||
|
|
||||||
if _, err := db.Exec(schema); err != nil {
|
_, err = db.ExecContext(ctx, schema)
|
||||||
|
if err != nil {
|
||||||
t.Fatalf("Failed to create schema: %v", err)
|
t.Fatalf("Failed to create schema: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Insert parent record
|
// Insert parent record
|
||||||
if _, err := db.Exec("INSERT INTO parent (id, name) VALUES (1, 'Parent 1')"); err != nil {
|
_, err = db.ExecContext(ctx, "INSERT INTO parent (id, name) VALUES (1, 'Parent 1')")
|
||||||
|
if err != nil {
|
||||||
t.Fatalf("Failed to insert parent: %v", err)
|
t.Fatalf("Failed to insert parent: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test 1: Valid foreign key should work
|
// Test 1: Valid foreign key should work
|
||||||
_, err = db.Exec("INSERT INTO child (id, parent_id, name) VALUES (1, 1, 'Child 1')")
|
_, err = db.ExecContext(ctx, "INSERT INTO child (id, parent_id, name) VALUES (1, 1, 'Child 1')")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Valid foreign key insert failed: %v", err)
|
t.Fatalf("Valid foreign key insert failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test 2: Invalid foreign key should fail
|
// Test 2: Invalid foreign key should fail
|
||||||
_, err = db.Exec("INSERT INTO child (id, parent_id, name) VALUES (2, 999, 'Child 2')")
|
_, err = db.ExecContext(ctx, "INSERT INTO child (id, parent_id, name) VALUES (2, 999, 'Child 2')")
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("Expected foreign key constraint violation, but insert succeeded")
|
t.Error("Expected foreign key constraint violation, but insert succeeded")
|
||||||
} else if !contains(err.Error(), "FOREIGN KEY constraint failed") {
|
} else if !contains(err.Error(), "FOREIGN KEY constraint failed") {
|
||||||
@@ -204,7 +214,7 @@ func TestForeignKeyConstraintEnforcement(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Test 3: Deleting referenced parent should fail
|
// Test 3: Deleting referenced parent should fail
|
||||||
_, err = db.Exec("DELETE FROM parent WHERE id = 1")
|
_, err = db.ExecContext(ctx, "DELETE FROM parent WHERE id = 1")
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("Expected foreign key constraint violation when deleting referenced parent")
|
t.Error("Expected foreign key constraint violation when deleting referenced parent")
|
||||||
} else if !contains(err.Error(), "FOREIGN KEY constraint failed") {
|
} else if !contains(err.Error(), "FOREIGN KEY constraint failed") {
|
||||||
@@ -249,7 +259,8 @@ func TestJournalModeValidation(t *testing.T) {
|
|||||||
defer db.Close()
|
defer db.Close()
|
||||||
|
|
||||||
var actualMode string
|
var actualMode string
|
||||||
err = db.QueryRow("PRAGMA journal_mode").Scan(&actualMode)
|
|
||||||
|
err = db.QueryRowContext(context.Background(), "PRAGMA journal_mode").Scan(&actualMode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to query journal_mode: %v", err)
|
t.Fatalf("Failed to query journal_mode: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -53,16 +53,19 @@ func newPostgresDBForTest(t *testing.T) *url.URL {
|
|||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
ctx := t.Context()
|
ctx := t.Context()
|
||||||
|
|
||||||
srv, err := postgrestest.Start(ctx)
|
srv, err := postgrestest.Start(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Cleanup(srv.Cleanup)
|
t.Cleanup(srv.Cleanup)
|
||||||
|
|
||||||
u, err := srv.CreateDatabase(ctx)
|
u, err := srv.CreateDatabase(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Logf("created local postgres: %s", u)
|
t.Logf("created local postgres: %s", u)
|
||||||
pu, _ := url.Parse(u)
|
pu, _ := url.Parse(u)
|
||||||
|
|
||||||
|
|||||||
@@ -3,12 +3,19 @@ package db
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding"
|
"encoding"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
|
||||||
"gorm.io/gorm/schema"
|
"gorm.io/gorm/schema"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
errUnmarshalTextValue = errors.New("unmarshalling text value")
|
||||||
|
errUnsupportedType = errors.New("unsupported type")
|
||||||
|
errTextMarshalerOnly = errors.New("only encoding.TextMarshaler is supported")
|
||||||
|
)
|
||||||
|
|
||||||
// Got from https://github.com/xdg-go/strum/blob/main/types.go
|
// Got from https://github.com/xdg-go/strum/blob/main/types.go
|
||||||
var textUnmarshalerType = reflect.TypeFor[encoding.TextUnmarshaler]()
|
var textUnmarshalerType = reflect.TypeFor[encoding.TextUnmarshaler]()
|
||||||
|
|
||||||
@@ -24,7 +31,7 @@ func maybeInstantiatePtr(rv reflect.Value) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func decodingError(name string, err error) error {
|
func decodingError(name string, err error) error {
|
||||||
return fmt.Errorf("error decoding to %s: %w", name, err)
|
return fmt.Errorf("decoding to %s: %w", name, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TextSerialiser implements the Serialiser interface for fields that
|
// TextSerialiser implements the Serialiser interface for fields that
|
||||||
@@ -42,22 +49,26 @@ func (TextSerialiser) Scan(ctx context.Context, field *schema.Field, dst reflect
|
|||||||
|
|
||||||
if dbValue != nil {
|
if dbValue != nil {
|
||||||
var bytes []byte
|
var bytes []byte
|
||||||
|
|
||||||
switch v := dbValue.(type) {
|
switch v := dbValue.(type) {
|
||||||
case []byte:
|
case []byte:
|
||||||
bytes = v
|
bytes = v
|
||||||
case string:
|
case string:
|
||||||
bytes = []byte(v)
|
bytes = []byte(v)
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("failed to unmarshal text value: %#v", dbValue)
|
return fmt.Errorf("%w: %#v", errUnmarshalTextValue, dbValue)
|
||||||
}
|
}
|
||||||
|
|
||||||
if isTextUnmarshaler(fieldValue) {
|
if isTextUnmarshaler(fieldValue) {
|
||||||
maybeInstantiatePtr(fieldValue)
|
maybeInstantiatePtr(fieldValue)
|
||||||
f := fieldValue.MethodByName("UnmarshalText")
|
f := fieldValue.MethodByName("UnmarshalText")
|
||||||
args := []reflect.Value{reflect.ValueOf(bytes)}
|
args := []reflect.Value{reflect.ValueOf(bytes)}
|
||||||
|
|
||||||
ret := f.Call(args)
|
ret := f.Call(args)
|
||||||
if !ret[0].IsNil() {
|
if !ret[0].IsNil() {
|
||||||
return decodingError(field.Name, ret[0].Interface().(error))
|
if err, ok := ret[0].Interface().(error); ok {
|
||||||
|
return decodingError(field.Name, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// If the underlying field is to a pointer type, we need to
|
// If the underlying field is to a pointer type, we need to
|
||||||
@@ -73,7 +84,7 @@ func (TextSerialiser) Scan(ctx context.Context, field *schema.Field, dst reflect
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
} else {
|
} else {
|
||||||
return fmt.Errorf("unsupported type: %T", fieldValue.Interface())
|
return fmt.Errorf("%w: %T", errUnsupportedType, fieldValue.Interface())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -87,8 +98,9 @@ func (TextSerialiser) Value(ctx context.Context, field *schema.Field, dst reflec
|
|||||||
// always comparable, particularly when reflection is involved:
|
// always comparable, particularly when reflection is involved:
|
||||||
// https://dev.to/arxeiss/in-go-nil-is-not-equal-to-nil-sometimes-jn8
|
// https://dev.to/arxeiss/in-go-nil-is-not-equal-to-nil-sometimes-jn8
|
||||||
if v == nil || (reflect.ValueOf(v).Kind() == reflect.Ptr && reflect.ValueOf(v).IsNil()) {
|
if v == nil || (reflect.ValueOf(v).Kind() == reflect.Ptr && reflect.ValueOf(v).IsNil()) {
|
||||||
return nil, nil
|
return nil, nil //nolint:nilnil // intentional: nil value for GORM serializer
|
||||||
}
|
}
|
||||||
|
|
||||||
b, err := v.MarshalText()
|
b, err := v.MarshalText()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -96,6 +108,6 @@ func (TextSerialiser) Value(ctx context.Context, field *schema.Field, dst reflec
|
|||||||
|
|
||||||
return string(b), nil
|
return string(b), nil
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("only encoding.TextMarshaler is supported, got %t", v)
|
return nil, fmt.Errorf("%w, got %T", errTextMarshalerOnly, v)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,9 +12,11 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrUserExists = errors.New("user already exists")
|
ErrUserExists = errors.New("user already exists")
|
||||||
ErrUserNotFound = errors.New("user not found")
|
ErrUserNotFound = errors.New("user not found")
|
||||||
ErrUserStillHasNodes = errors.New("user not empty: node(s) found")
|
ErrUserStillHasNodes = errors.New("user not empty: node(s) found")
|
||||||
|
ErrUserWhereInvalidCount = errors.New("expect 0 or 1 where User structs")
|
||||||
|
ErrUserNotUnique = errors.New("expected exactly one user")
|
||||||
)
|
)
|
||||||
|
|
||||||
func (hsdb *HSDatabase) CreateUser(user types.User) (*types.User, error) {
|
func (hsdb *HSDatabase) CreateUser(user types.User) (*types.User, error) {
|
||||||
@@ -26,10 +28,13 @@ func (hsdb *HSDatabase) CreateUser(user types.User) (*types.User, error) {
|
|||||||
// CreateUser creates a new User. Returns error if could not be created
|
// CreateUser creates a new User. Returns error if could not be created
|
||||||
// or another user already exists.
|
// or another user already exists.
|
||||||
func CreateUser(tx *gorm.DB, user types.User) (*types.User, error) {
|
func CreateUser(tx *gorm.DB, user types.User) (*types.User, error) {
|
||||||
if err := util.ValidateHostname(user.Name); err != nil {
|
err := util.ValidateHostname(user.Name)
|
||||||
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if err := tx.Create(&user).Error; err != nil {
|
|
||||||
|
err = tx.Create(&user).Error
|
||||||
|
if err != nil {
|
||||||
return nil, fmt.Errorf("creating user: %w", err)
|
return nil, fmt.Errorf("creating user: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -54,6 +59,7 @@ func DestroyUser(tx *gorm.DB, uid types.UserID) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(nodes) > 0 {
|
if len(nodes) > 0 {
|
||||||
return ErrUserStillHasNodes
|
return ErrUserStillHasNodes
|
||||||
}
|
}
|
||||||
@@ -62,6 +68,7 @@ func DestroyUser(tx *gorm.DB, uid types.UserID) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, key := range keys {
|
for _, key := range keys {
|
||||||
err = DestroyPreAuthKey(tx, key.ID)
|
err = DestroyPreAuthKey(tx, key.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -88,11 +95,13 @@ var ErrCannotChangeOIDCUser = errors.New("cannot edit OIDC user")
|
|||||||
// not exist or if another User exists with the new name.
|
// not exist or if another User exists with the new name.
|
||||||
func RenameUser(tx *gorm.DB, uid types.UserID, newName string) error {
|
func RenameUser(tx *gorm.DB, uid types.UserID, newName string) error {
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
oldUser, err := GetUserByID(tx, uid)
|
oldUser, err := GetUserByID(tx, uid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err = util.ValidateHostname(newName); err != nil {
|
|
||||||
|
if err = util.ValidateHostname(newName); err != nil { //nolint:noinlineerr
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -151,7 +160,7 @@ func (hsdb *HSDatabase) ListUsers(where ...*types.User) ([]types.User, error) {
|
|||||||
// ListUsers gets all the existing users.
|
// ListUsers gets all the existing users.
|
||||||
func ListUsers(tx *gorm.DB, where ...*types.User) ([]types.User, error) {
|
func ListUsers(tx *gorm.DB, where ...*types.User) ([]types.User, error) {
|
||||||
if len(where) > 1 {
|
if len(where) > 1 {
|
||||||
return nil, fmt.Errorf("expect 0 or 1 where User structs, got %d", len(where))
|
return nil, fmt.Errorf("%w, got %d", ErrUserWhereInvalidCount, len(where))
|
||||||
}
|
}
|
||||||
|
|
||||||
var user *types.User
|
var user *types.User
|
||||||
@@ -160,7 +169,9 @@ func ListUsers(tx *gorm.DB, where ...*types.User) ([]types.User, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
users := []types.User{}
|
users := []types.User{}
|
||||||
if err := tx.Where(user).Find(&users).Error; err != nil {
|
|
||||||
|
err := tx.Where(user).Find(&users).Error
|
||||||
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -180,7 +191,7 @@ func (hsdb *HSDatabase) GetUserByName(name string) (*types.User, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(users) != 1 {
|
if len(users) != 1 {
|
||||||
return nil, fmt.Errorf("expected exactly one user, found %d", len(users))
|
return nil, fmt.Errorf("%w, found %d", ErrUserNotUnique, len(users))
|
||||||
}
|
}
|
||||||
|
|
||||||
return &users[0], nil
|
return &users[0], nil
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"tailscale.com/types/ptr"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestCreateAndDestroyUser(t *testing.T) {
|
func TestCreateAndDestroyUser(t *testing.T) {
|
||||||
@@ -74,12 +73,14 @@ func TestDestroyUserErrors(t *testing.T) {
|
|||||||
pak, err := db.CreatePreAuthKey(user.TypedID(), false, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(user.TypedID(), false, false, nil, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
pakID := pak.ID
|
||||||
|
|
||||||
node := types.Node{
|
node := types.Node{
|
||||||
ID: 0,
|
ID: 0,
|
||||||
Hostname: "testnode",
|
Hostname: "testnode",
|
||||||
UserID: &user.ID,
|
UserID: &user.ID,
|
||||||
RegisterMethod: util.RegisterMethodAuthKey,
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
AuthKeyID: ptr.To(pak.ID),
|
AuthKeyID: &pakID,
|
||||||
}
|
}
|
||||||
trx := db.DB.Save(&node)
|
trx := db.DB.Save(&node)
|
||||||
require.NoError(t, trx.Error)
|
require.NoError(t, trx.Error)
|
||||||
|
|||||||
251
hscontrol/db/versioncheck.go
Normal file
251
hscontrol/db/versioncheck.go
Normal file
@@ -0,0 +1,251 @@
|
|||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
var errVersionUpgrade = errors.New("version upgrade not supported")
|
||||||
|
|
||||||
|
var errVersionDowngrade = errors.New("version downgrade not supported")
|
||||||
|
|
||||||
|
var errVersionMajorChange = errors.New("major version change not supported")
|
||||||
|
|
||||||
|
var errVersionParse = errors.New("cannot parse version")
|
||||||
|
|
||||||
|
var errVersionFormat = errors.New(
|
||||||
|
"version does not follow semver major.minor.patch format",
|
||||||
|
)
|
||||||
|
|
||||||
|
// DatabaseVersion tracks the headscale version that last
|
||||||
|
// successfully started against this database.
|
||||||
|
// It is a single-row table (ID is always 1).
|
||||||
|
type DatabaseVersion struct {
|
||||||
|
ID uint `gorm:"primaryKey"`
|
||||||
|
Version string `gorm:"not null"`
|
||||||
|
UpdatedAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// semver holds parsed major.minor.patch components.
|
||||||
|
type semver struct {
|
||||||
|
Major int
|
||||||
|
Minor int
|
||||||
|
Patch int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s semver) String() string {
|
||||||
|
return fmt.Sprintf("v%d.%d.%d", s.Major, s.Minor, s.Patch)
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseVersion parses a version string like "v0.25.0", "0.25.1",
|
||||||
|
// "v0.25.0-beta.1", or "v0.25.0-rc1+build123" into its major, minor,
|
||||||
|
// patch components. Pre-release and build metadata suffixes are stripped.
|
||||||
|
func parseVersion(s string) (semver, error) {
|
||||||
|
if s == "" || s == "dev" {
|
||||||
|
return semver{}, fmt.Errorf("%q: %w", s, errVersionParse)
|
||||||
|
}
|
||||||
|
|
||||||
|
v := strings.TrimPrefix(s, "v")
|
||||||
|
|
||||||
|
// Strip pre-release suffix (everything after first '-')
|
||||||
|
// and build metadata (everything after first '+').
|
||||||
|
if idx := strings.IndexAny(v, "-+"); idx != -1 {
|
||||||
|
v = v[:idx]
|
||||||
|
}
|
||||||
|
|
||||||
|
parts := strings.Split(v, ".")
|
||||||
|
if len(parts) != 3 {
|
||||||
|
return semver{}, fmt.Errorf("%q: %w", s, errVersionFormat)
|
||||||
|
}
|
||||||
|
|
||||||
|
major, err := strconv.Atoi(parts[0])
|
||||||
|
if err != nil {
|
||||||
|
return semver{}, fmt.Errorf("invalid major version in %q: %w", s, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
minor, err := strconv.Atoi(parts[1])
|
||||||
|
if err != nil {
|
||||||
|
return semver{}, fmt.Errorf("invalid minor version in %q: %w", s, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
patch, err := strconv.Atoi(parts[2])
|
||||||
|
if err != nil {
|
||||||
|
return semver{}, fmt.Errorf("invalid patch version in %q: %w", s, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return semver{Major: major, Minor: minor, Patch: patch}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ensureDatabaseVersionTable creates the database_versions table if it
|
||||||
|
// does not already exist. Uses GORM AutoMigrate to handle dialect
|
||||||
|
// differences between SQLite (datetime) and PostgreSQL (timestamp).
|
||||||
|
// This runs before gormigrate migrations.
|
||||||
|
func ensureDatabaseVersionTable(db *gorm.DB) error {
|
||||||
|
err := db.AutoMigrate(&DatabaseVersion{})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("creating database version table: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getDatabaseVersion reads the stored version from the database.
|
||||||
|
// Returns an empty string if no version has been stored yet.
|
||||||
|
func getDatabaseVersion(db *gorm.DB) (string, error) {
|
||||||
|
var version string
|
||||||
|
|
||||||
|
result := db.Raw("SELECT version FROM database_versions WHERE id = 1").Scan(&version)
|
||||||
|
if result.Error != nil {
|
||||||
|
return "", fmt.Errorf("reading database version: %w", result.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.RowsAffected == 0 {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return version, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// setDatabaseVersion upserts the version row in the database.
|
||||||
|
func setDatabaseVersion(db *gorm.DB, version string) error {
|
||||||
|
now := time.Now().UTC()
|
||||||
|
|
||||||
|
// Try update first, then insert if no rows affected.
|
||||||
|
result := db.Exec(
|
||||||
|
"UPDATE database_versions SET version = ?, updated_at = ? WHERE id = 1",
|
||||||
|
version, now,
|
||||||
|
)
|
||||||
|
if result.Error != nil {
|
||||||
|
return fmt.Errorf("updating database version: %w", result.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.RowsAffected == 0 {
|
||||||
|
err := db.Exec(
|
||||||
|
"INSERT INTO database_versions (id, version, updated_at) VALUES (1, ?, ?)",
|
||||||
|
version, now,
|
||||||
|
).Error
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("inserting database version: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// isDev reports whether a version string represents a development build
|
||||||
|
// that should skip version checking.
|
||||||
|
func isDev(version string) bool {
|
||||||
|
return version == "" || version == "dev" || version == "(devel)"
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkVersionUpgradePath verifies that the running headscale version
|
||||||
|
// is compatible with the version that last used this database.
|
||||||
|
//
|
||||||
|
// Rules:
|
||||||
|
// - If the running binary has no version ("dev" or empty), warn and skip.
|
||||||
|
// - If no version is stored in the database, allow (first run with this feature).
|
||||||
|
// - If the stored version is "dev", allow (previous run was unversioned).
|
||||||
|
// - Same minor version: always allowed (patch changes in either direction).
|
||||||
|
// - Single minor version upgrade (stored.minor+1 == current.minor): allowed.
|
||||||
|
// - Multi-minor upgrade or any minor downgrade: blocked with a fatal error.
|
||||||
|
func checkVersionUpgradePath(db *gorm.DB) error {
|
||||||
|
err := ensureDatabaseVersionTable(db)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
currentVersion := types.GetVersionInfo().Version
|
||||||
|
|
||||||
|
// Running binary has no real version — skip the check but
|
||||||
|
// preserve whatever version is already stored.
|
||||||
|
if isDev(currentVersion) {
|
||||||
|
storedVersion, err := getDatabaseVersion(db)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if storedVersion != "" && !isDev(storedVersion) {
|
||||||
|
log.Warn().
|
||||||
|
Str("database_version", storedVersion).
|
||||||
|
Msg("running a development build of headscale without a version number, " +
|
||||||
|
"database version check is skipped, the stored database version is preserved")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
storedVersion, err := getDatabaseVersion(db)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// No stored version — first run with this feature. Allow startup;
|
||||||
|
// the version will be stored after migrations succeed.
|
||||||
|
if storedVersion == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Previous run was an unversioned build — no meaningful comparison.
|
||||||
|
if isDev(storedVersion) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
current, err := parseVersion(currentVersion)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parsing current version: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
stored, err := parseVersion(storedVersion)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parsing stored database version: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if current.Major != stored.Major {
|
||||||
|
return fmt.Errorf(
|
||||||
|
"headscale version %s cannot be used with a database last used by %s: %w",
|
||||||
|
currentVersion, storedVersion, errVersionMajorChange,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
minorDiff := current.Minor - stored.Minor
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case minorDiff == 0:
|
||||||
|
// Same minor version — patch changes are always fine.
|
||||||
|
return nil
|
||||||
|
|
||||||
|
case minorDiff == 1:
|
||||||
|
// Single minor version upgrade — allowed.
|
||||||
|
return nil
|
||||||
|
|
||||||
|
case minorDiff > 1:
|
||||||
|
// Multi-minor upgrade — blocked.
|
||||||
|
return fmt.Errorf(
|
||||||
|
"headscale version %s cannot be used with a database last used by %s, "+
|
||||||
|
"upgrading more than one minor version at a time is not supported, "+
|
||||||
|
"please upgrade to the latest v%d.%d.x release first, then to %s, "+
|
||||||
|
"release page: https://github.com/juanfont/headscale/releases: %w",
|
||||||
|
currentVersion, storedVersion,
|
||||||
|
stored.Major, stored.Minor+1,
|
||||||
|
current.String(),
|
||||||
|
errVersionUpgrade,
|
||||||
|
)
|
||||||
|
|
||||||
|
default:
|
||||||
|
// minorDiff < 0 — any minor downgrade is blocked.
|
||||||
|
return fmt.Errorf(
|
||||||
|
"headscale version %s cannot be used with a database last used by %s, "+
|
||||||
|
"downgrading to a previous minor version is not supported, "+
|
||||||
|
"release page: https://github.com/juanfont/headscale/releases: %w",
|
||||||
|
currentVersion, storedVersion,
|
||||||
|
errVersionDowngrade,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
318
hscontrol/db/versioncheck_test.go
Normal file
318
hscontrol/db/versioncheck_test.go
Normal file
@@ -0,0 +1,318 @@
|
|||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/glebarez/sqlite"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestParseVersion(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
want semver
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{input: "v0.25.0", want: semver{0, 25, 0}},
|
||||||
|
{input: "0.25.0", want: semver{0, 25, 0}},
|
||||||
|
{input: "v0.25.1", want: semver{0, 25, 1}},
|
||||||
|
{input: "v1.0.0", want: semver{1, 0, 0}},
|
||||||
|
{input: "v0.28.3", want: semver{0, 28, 3}},
|
||||||
|
// Pre-release suffixes stripped
|
||||||
|
{input: "v0.25.0-beta.1", want: semver{0, 25, 0}},
|
||||||
|
{input: "v0.25.0-rc1", want: semver{0, 25, 0}},
|
||||||
|
// Build metadata stripped
|
||||||
|
{input: "v0.25.0+build123", want: semver{0, 25, 0}},
|
||||||
|
{input: "v0.25.0-beta.1+build123", want: semver{0, 25, 0}},
|
||||||
|
// Invalid inputs
|
||||||
|
{input: "", wantErr: true},
|
||||||
|
{input: "dev", wantErr: true},
|
||||||
|
{input: "vfoo.bar.baz", wantErr: true},
|
||||||
|
{input: "v1.2", wantErr: true},
|
||||||
|
{input: "v1", wantErr: true},
|
||||||
|
{input: "not-a-version", wantErr: true},
|
||||||
|
{input: "v1.2.3.4", wantErr: true},
|
||||||
|
{input: "(devel)", wantErr: true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.input, func(t *testing.T) {
|
||||||
|
got, err := parseVersion(tt.input)
|
||||||
|
if tt.wantErr {
|
||||||
|
assert.Error(t, err)
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, tt.want, got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSemverString(t *testing.T) {
|
||||||
|
s := semver{0, 28, 3}
|
||||||
|
assert.Equal(t, "v0.28.3", s.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsDev(t *testing.T) {
|
||||||
|
assert.True(t, isDev(""))
|
||||||
|
assert.True(t, isDev("dev"))
|
||||||
|
assert.True(t, isDev("(devel)"))
|
||||||
|
assert.False(t, isDev("v0.28.0"))
|
||||||
|
assert.False(t, isDev("0.28.0"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// versionTestDB creates an in-memory SQLite database with the
|
||||||
|
// database_versions table already bootstrapped.
|
||||||
|
func versionTestDB(t *testing.T) *gorm.DB {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
db, err := gorm.Open(sqlite.Open("file::memory:"), &gorm.Config{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = ensureDatabaseVersionTable(db)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
return db
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetAndGetDatabaseVersion(t *testing.T) {
|
||||||
|
db := versionTestDB(t)
|
||||||
|
|
||||||
|
// Initially empty
|
||||||
|
v, err := getDatabaseVersion(db)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Empty(t, v)
|
||||||
|
|
||||||
|
// Set a version
|
||||||
|
err = setDatabaseVersion(db, "v0.27.0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
v, err = getDatabaseVersion(db)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "v0.27.0", v)
|
||||||
|
|
||||||
|
// Update the version (upsert)
|
||||||
|
err = setDatabaseVersion(db, "v0.28.0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
v, err = getDatabaseVersion(db)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "v0.28.0", v)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnsureDatabaseVersionTableIdempotent(t *testing.T) {
|
||||||
|
db, err := gorm.Open(sqlite.Open("file::memory:"), &gorm.Config{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Call twice — should not error
|
||||||
|
err = ensureDatabaseVersionTable(db)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = ensureDatabaseVersionTable(db)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestCheckVersionUpgradePathDirect tests the version comparison logic
|
||||||
|
// by directly seeding the database, bypassing types.GetVersionInfo()
|
||||||
|
// (which returns "dev" in test environments and cannot be overridden).
|
||||||
|
func TestCheckVersionUpgradePathDirect(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
storedVersion string // empty means no row stored
|
||||||
|
currentVersion string
|
||||||
|
wantErr bool
|
||||||
|
errContains string
|
||||||
|
}{
|
||||||
|
// Fresh database (no stored version)
|
||||||
|
{
|
||||||
|
name: "fresh db allows any version",
|
||||||
|
storedVersion: "",
|
||||||
|
currentVersion: "v0.28.0",
|
||||||
|
},
|
||||||
|
|
||||||
|
// Stored is dev
|
||||||
|
{
|
||||||
|
name: "real version over dev db",
|
||||||
|
storedVersion: "dev",
|
||||||
|
currentVersion: "v0.28.0",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "devel version in db",
|
||||||
|
storedVersion: "(devel)",
|
||||||
|
currentVersion: "v0.28.0",
|
||||||
|
},
|
||||||
|
|
||||||
|
// Same version
|
||||||
|
{
|
||||||
|
name: "same version",
|
||||||
|
storedVersion: "v0.27.0",
|
||||||
|
currentVersion: "v0.27.0",
|
||||||
|
},
|
||||||
|
|
||||||
|
// Patch changes within same minor
|
||||||
|
{
|
||||||
|
name: "patch upgrade",
|
||||||
|
storedVersion: "v0.27.0",
|
||||||
|
currentVersion: "v0.27.3",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "patch downgrade within same minor",
|
||||||
|
storedVersion: "v0.27.3",
|
||||||
|
currentVersion: "v0.27.0",
|
||||||
|
},
|
||||||
|
|
||||||
|
// Single minor upgrade
|
||||||
|
{
|
||||||
|
name: "single minor upgrade",
|
||||||
|
storedVersion: "v0.27.0",
|
||||||
|
currentVersion: "v0.28.0",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single minor upgrade with different patches",
|
||||||
|
storedVersion: "v0.27.3",
|
||||||
|
currentVersion: "v0.28.1",
|
||||||
|
},
|
||||||
|
|
||||||
|
// Multi-minor upgrade (blocked)
|
||||||
|
{
|
||||||
|
name: "two minor versions ahead",
|
||||||
|
storedVersion: "v0.25.0",
|
||||||
|
currentVersion: "v0.27.0",
|
||||||
|
wantErr: true,
|
||||||
|
errContains: "latest v0.26.x",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "three minor versions ahead",
|
||||||
|
storedVersion: "v0.25.0",
|
||||||
|
currentVersion: "v0.28.0",
|
||||||
|
wantErr: true,
|
||||||
|
errContains: "latest v0.26.x",
|
||||||
|
},
|
||||||
|
|
||||||
|
// Minor downgrades (blocked)
|
||||||
|
{
|
||||||
|
name: "single minor downgrade",
|
||||||
|
storedVersion: "v0.28.0",
|
||||||
|
currentVersion: "v0.27.0",
|
||||||
|
wantErr: true,
|
||||||
|
errContains: "downgrading",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multi minor downgrade",
|
||||||
|
storedVersion: "v0.28.0",
|
||||||
|
currentVersion: "v0.25.0",
|
||||||
|
wantErr: true,
|
||||||
|
errContains: "downgrading",
|
||||||
|
},
|
||||||
|
|
||||||
|
// Major version mismatch
|
||||||
|
{
|
||||||
|
name: "major version upgrade",
|
||||||
|
storedVersion: "v0.28.0",
|
||||||
|
currentVersion: "v1.0.0",
|
||||||
|
wantErr: true,
|
||||||
|
errContains: "major version",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "major version downgrade",
|
||||||
|
storedVersion: "v1.0.0",
|
||||||
|
currentVersion: "v0.28.0",
|
||||||
|
wantErr: true,
|
||||||
|
errContains: "major version",
|
||||||
|
},
|
||||||
|
|
||||||
|
// Pre-release versions
|
||||||
|
{
|
||||||
|
name: "pre-release single minor upgrade",
|
||||||
|
storedVersion: "v0.27.0",
|
||||||
|
currentVersion: "v0.28.0-beta.1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "pre-release multi minor upgrade blocked",
|
||||||
|
storedVersion: "v0.25.0",
|
||||||
|
currentVersion: "v0.27.0-rc1",
|
||||||
|
wantErr: true,
|
||||||
|
errContains: "latest v0.26.x",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
db := versionTestDB(t)
|
||||||
|
|
||||||
|
// Seed the stored version if provided
|
||||||
|
if tt.storedVersion != "" {
|
||||||
|
err := setDatabaseVersion(db, tt.storedVersion)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err := checkVersionUpgradePathFromVersions(db, tt.currentVersion)
|
||||||
|
if tt.wantErr {
|
||||||
|
require.Error(t, err)
|
||||||
|
|
||||||
|
if tt.errContains != "" {
|
||||||
|
assert.Contains(t, err.Error(), tt.errContains)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkVersionUpgradePathFromVersions is a test helper that runs the
|
||||||
|
// version comparison logic with a specific currentVersion string,
|
||||||
|
// bypassing types.GetVersionInfo(). It replicates the logic from
|
||||||
|
// checkVersionUpgradePath but accepts the version as a parameter.
|
||||||
|
func checkVersionUpgradePathFromVersions(db *gorm.DB, currentVersion string) error {
|
||||||
|
if isDev(currentVersion) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
storedVersion, err := getDatabaseVersion(db)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if storedVersion == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if isDev(storedVersion) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
current, err := parseVersion(currentVersion)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
stored, err := parseVersion(storedVersion)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if current.Major != stored.Major {
|
||||||
|
return errVersionMajorChange
|
||||||
|
}
|
||||||
|
|
||||||
|
minorDiff := current.Minor - stored.Minor
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case minorDiff == 0:
|
||||||
|
return nil
|
||||||
|
case minorDiff == 1:
|
||||||
|
return nil
|
||||||
|
case minorDiff > 1:
|
||||||
|
return fmt.Errorf(
|
||||||
|
"please upgrade to the latest v%d.%d.x release first: %w",
|
||||||
|
stored.Major, stored.Minor+1,
|
||||||
|
errVersionUpgrade,
|
||||||
|
)
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("downgrading: %w", errVersionDowngrade)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -25,34 +25,39 @@ func (h *Headscale) debugHTTPServer() *http.Server {
|
|||||||
|
|
||||||
if wantsJSON {
|
if wantsJSON {
|
||||||
overview := h.state.DebugOverviewJSON()
|
overview := h.state.DebugOverviewJSON()
|
||||||
|
|
||||||
overviewJSON, err := json.MarshalIndent(overview, "", " ")
|
overviewJSON, err := json.MarshalIndent(overview, "", " ")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpError(w, err)
|
httpError(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
w.Write(overviewJSON)
|
_, _ = w.Write(overviewJSON)
|
||||||
} else {
|
} else {
|
||||||
// Default to text/plain for backward compatibility
|
// Default to text/plain for backward compatibility
|
||||||
overview := h.state.DebugOverview()
|
overview := h.state.DebugOverview()
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "text/plain")
|
w.Header().Set("Content-Type", "text/plain")
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
w.Write([]byte(overview))
|
_, _ = w.Write([]byte(overview))
|
||||||
}
|
}
|
||||||
}))
|
}))
|
||||||
|
|
||||||
// Configuration endpoint
|
// Configuration endpoint
|
||||||
debug.Handle("config", "Current configuration", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
debug.Handle("config", "Current configuration", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
config := h.state.DebugConfig()
|
config := h.state.DebugConfig()
|
||||||
|
|
||||||
configJSON, err := json.MarshalIndent(config, "", " ")
|
configJSON, err := json.MarshalIndent(config, "", " ")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpError(w, err)
|
httpError(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
w.Write(configJSON)
|
_, _ = w.Write(configJSON)
|
||||||
}))
|
}))
|
||||||
|
|
||||||
// Policy endpoint
|
// Policy endpoint
|
||||||
@@ -70,8 +75,9 @@ func (h *Headscale) debugHTTPServer() *http.Server {
|
|||||||
} else {
|
} else {
|
||||||
w.Header().Set("Content-Type", "text/plain")
|
w.Header().Set("Content-Type", "text/plain")
|
||||||
}
|
}
|
||||||
|
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
w.Write([]byte(policy))
|
_, _ = w.Write([]byte(policy))
|
||||||
}))
|
}))
|
||||||
|
|
||||||
// Filter rules endpoint
|
// Filter rules endpoint
|
||||||
@@ -81,27 +87,31 @@ func (h *Headscale) debugHTTPServer() *http.Server {
|
|||||||
httpError(w, err)
|
httpError(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
filterJSON, err := json.MarshalIndent(filter, "", " ")
|
filterJSON, err := json.MarshalIndent(filter, "", " ")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpError(w, err)
|
httpError(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
w.Write(filterJSON)
|
_, _ = w.Write(filterJSON)
|
||||||
}))
|
}))
|
||||||
|
|
||||||
// SSH policies endpoint
|
// SSH policies endpoint
|
||||||
debug.Handle("ssh", "SSH policies per node", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
debug.Handle("ssh", "SSH policies per node", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
sshPolicies := h.state.DebugSSHPolicies()
|
sshPolicies := h.state.DebugSSHPolicies()
|
||||||
|
|
||||||
sshJSON, err := json.MarshalIndent(sshPolicies, "", " ")
|
sshJSON, err := json.MarshalIndent(sshPolicies, "", " ")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpError(w, err)
|
httpError(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
w.Write(sshJSON)
|
_, _ = w.Write(sshJSON)
|
||||||
}))
|
}))
|
||||||
|
|
||||||
// DERP map endpoint
|
// DERP map endpoint
|
||||||
@@ -112,20 +122,23 @@ func (h *Headscale) debugHTTPServer() *http.Server {
|
|||||||
|
|
||||||
if wantsJSON {
|
if wantsJSON {
|
||||||
derpInfo := h.state.DebugDERPJSON()
|
derpInfo := h.state.DebugDERPJSON()
|
||||||
|
|
||||||
derpJSON, err := json.MarshalIndent(derpInfo, "", " ")
|
derpJSON, err := json.MarshalIndent(derpInfo, "", " ")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpError(w, err)
|
httpError(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
w.Write(derpJSON)
|
_, _ = w.Write(derpJSON)
|
||||||
} else {
|
} else {
|
||||||
// Default to text/plain for backward compatibility
|
// Default to text/plain for backward compatibility
|
||||||
derpInfo := h.state.DebugDERPMap()
|
derpInfo := h.state.DebugDERPMap()
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "text/plain")
|
w.Header().Set("Content-Type", "text/plain")
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
w.Write([]byte(derpInfo))
|
_, _ = w.Write([]byte(derpInfo))
|
||||||
}
|
}
|
||||||
}))
|
}))
|
||||||
|
|
||||||
@@ -137,34 +150,39 @@ func (h *Headscale) debugHTTPServer() *http.Server {
|
|||||||
|
|
||||||
if wantsJSON {
|
if wantsJSON {
|
||||||
nodeStoreNodes := h.state.DebugNodeStoreJSON()
|
nodeStoreNodes := h.state.DebugNodeStoreJSON()
|
||||||
|
|
||||||
nodeStoreJSON, err := json.MarshalIndent(nodeStoreNodes, "", " ")
|
nodeStoreJSON, err := json.MarshalIndent(nodeStoreNodes, "", " ")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpError(w, err)
|
httpError(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
w.Write(nodeStoreJSON)
|
_, _ = w.Write(nodeStoreJSON)
|
||||||
} else {
|
} else {
|
||||||
// Default to text/plain for backward compatibility
|
// Default to text/plain for backward compatibility
|
||||||
nodeStoreInfo := h.state.DebugNodeStore()
|
nodeStoreInfo := h.state.DebugNodeStore()
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "text/plain")
|
w.Header().Set("Content-Type", "text/plain")
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
w.Write([]byte(nodeStoreInfo))
|
_, _ = w.Write([]byte(nodeStoreInfo))
|
||||||
}
|
}
|
||||||
}))
|
}))
|
||||||
|
|
||||||
// Registration cache endpoint
|
// Registration cache endpoint
|
||||||
debug.Handle("registration-cache", "Registration cache information", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
debug.Handle("registration-cache", "Registration cache information", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
cacheInfo := h.state.DebugRegistrationCache()
|
cacheInfo := h.state.DebugRegistrationCache()
|
||||||
|
|
||||||
cacheJSON, err := json.MarshalIndent(cacheInfo, "", " ")
|
cacheJSON, err := json.MarshalIndent(cacheInfo, "", " ")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpError(w, err)
|
httpError(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
w.Write(cacheJSON)
|
_, _ = w.Write(cacheJSON)
|
||||||
}))
|
}))
|
||||||
|
|
||||||
// Routes endpoint
|
// Routes endpoint
|
||||||
@@ -175,20 +193,23 @@ func (h *Headscale) debugHTTPServer() *http.Server {
|
|||||||
|
|
||||||
if wantsJSON {
|
if wantsJSON {
|
||||||
routes := h.state.DebugRoutes()
|
routes := h.state.DebugRoutes()
|
||||||
|
|
||||||
routesJSON, err := json.MarshalIndent(routes, "", " ")
|
routesJSON, err := json.MarshalIndent(routes, "", " ")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpError(w, err)
|
httpError(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
w.Write(routesJSON)
|
_, _ = w.Write(routesJSON)
|
||||||
} else {
|
} else {
|
||||||
// Default to text/plain for backward compatibility
|
// Default to text/plain for backward compatibility
|
||||||
routes := h.state.DebugRoutesString()
|
routes := h.state.DebugRoutesString()
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "text/plain")
|
w.Header().Set("Content-Type", "text/plain")
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
w.Write([]byte(routes))
|
_, _ = w.Write([]byte(routes))
|
||||||
}
|
}
|
||||||
}))
|
}))
|
||||||
|
|
||||||
@@ -200,20 +221,23 @@ func (h *Headscale) debugHTTPServer() *http.Server {
|
|||||||
|
|
||||||
if wantsJSON {
|
if wantsJSON {
|
||||||
policyManagerInfo := h.state.DebugPolicyManagerJSON()
|
policyManagerInfo := h.state.DebugPolicyManagerJSON()
|
||||||
|
|
||||||
policyManagerJSON, err := json.MarshalIndent(policyManagerInfo, "", " ")
|
policyManagerJSON, err := json.MarshalIndent(policyManagerInfo, "", " ")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpError(w, err)
|
httpError(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
w.Write(policyManagerJSON)
|
_, _ = w.Write(policyManagerJSON)
|
||||||
} else {
|
} else {
|
||||||
// Default to text/plain for backward compatibility
|
// Default to text/plain for backward compatibility
|
||||||
policyManagerInfo := h.state.DebugPolicyManager()
|
policyManagerInfo := h.state.DebugPolicyManager()
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "text/plain")
|
w.Header().Set("Content-Type", "text/plain")
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
w.Write([]byte(policyManagerInfo))
|
_, _ = w.Write([]byte(policyManagerInfo))
|
||||||
}
|
}
|
||||||
}))
|
}))
|
||||||
|
|
||||||
@@ -226,7 +250,8 @@ func (h *Headscale) debugHTTPServer() *http.Server {
|
|||||||
|
|
||||||
if res == nil {
|
if res == nil {
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
w.Write([]byte("HEADSCALE_DEBUG_DUMP_MAPRESPONSE_PATH not set"))
|
_, _ = w.Write([]byte("HEADSCALE_DEBUG_DUMP_MAPRESPONSE_PATH not set"))
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -235,9 +260,10 @@ func (h *Headscale) debugHTTPServer() *http.Server {
|
|||||||
httpError(w, err)
|
httpError(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
w.Write(resJSON)
|
_, _ = w.Write(resJSON)
|
||||||
}))
|
}))
|
||||||
|
|
||||||
// Batcher endpoint
|
// Batcher endpoint
|
||||||
@@ -257,14 +283,14 @@ func (h *Headscale) debugHTTPServer() *http.Server {
|
|||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
w.Write(batcherJSON)
|
_, _ = w.Write(batcherJSON)
|
||||||
} else {
|
} else {
|
||||||
// Default to text/plain for backward compatibility
|
// Default to text/plain for backward compatibility
|
||||||
batcherInfo := h.debugBatcher()
|
batcherInfo := h.debugBatcher()
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "text/plain")
|
w.Header().Set("Content-Type", "text/plain")
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
w.Write([]byte(batcherInfo))
|
_, _ = w.Write([]byte(batcherInfo))
|
||||||
}
|
}
|
||||||
}))
|
}))
|
||||||
|
|
||||||
@@ -313,6 +339,7 @@ func (h *Headscale) debugBatcher() string {
|
|||||||
activeConnections: info.ActiveConnections,
|
activeConnections: info.ActiveConnections,
|
||||||
})
|
})
|
||||||
totalNodes++
|
totalNodes++
|
||||||
|
|
||||||
if info.Connected {
|
if info.Connected {
|
||||||
connectedCount++
|
connectedCount++
|
||||||
}
|
}
|
||||||
@@ -327,9 +354,11 @@ func (h *Headscale) debugBatcher() string {
|
|||||||
activeConnections: 0,
|
activeConnections: 0,
|
||||||
})
|
})
|
||||||
totalNodes++
|
totalNodes++
|
||||||
|
|
||||||
if connected {
|
if connected {
|
||||||
connectedCount++
|
connectedCount++
|
||||||
}
|
}
|
||||||
|
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -400,6 +429,7 @@ func (h *Headscale) debugBatcherJSON() DebugBatcherInfo {
|
|||||||
ActiveConnections: 0,
|
ActiveConnections: 0,
|
||||||
}
|
}
|
||||||
info.TotalNodes++
|
info.TotalNodes++
|
||||||
|
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -28,11 +28,14 @@ func loadDERPMapFromPath(path string) (*tailcfg.DERPMap, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer derpFile.Close()
|
defer derpFile.Close()
|
||||||
|
|
||||||
var derpMap tailcfg.DERPMap
|
var derpMap tailcfg.DERPMap
|
||||||
|
|
||||||
b, err := io.ReadAll(derpFile)
|
b, err := io.ReadAll(derpFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = yaml.Unmarshal(b, &derpMap)
|
err = yaml.Unmarshal(b, &derpMap)
|
||||||
|
|
||||||
return &derpMap, err
|
return &derpMap, err
|
||||||
@@ -57,12 +60,14 @@ func loadDERPMapFromURL(addr url.URL) (*tailcfg.DERPMap, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
body, err := io.ReadAll(resp.Body)
|
body, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var derpMap tailcfg.DERPMap
|
var derpMap tailcfg.DERPMap
|
||||||
|
|
||||||
err = json.Unmarshal(body, &derpMap)
|
err = json.Unmarshal(body, &derpMap)
|
||||||
|
|
||||||
return &derpMap, err
|
return &derpMap, err
|
||||||
@@ -134,6 +139,7 @@ func shuffleDERPMap(dm *tailcfg.DERPMap) {
|
|||||||
for id := range dm.Regions {
|
for id := range dm.Regions {
|
||||||
ids = append(ids, id)
|
ids = append(ids, id)
|
||||||
}
|
}
|
||||||
|
|
||||||
slices.Sort(ids)
|
slices.Sort(ids)
|
||||||
|
|
||||||
for _, id := range ids {
|
for _, id := range ids {
|
||||||
@@ -160,16 +166,18 @@ func derpRandom() *rand.Rand {
|
|||||||
|
|
||||||
derpRandomOnce.Do(func() {
|
derpRandomOnce.Do(func() {
|
||||||
seed := cmp.Or(viper.GetString("dns.base_domain"), time.Now().String())
|
seed := cmp.Or(viper.GetString("dns.base_domain"), time.Now().String())
|
||||||
rnd := rand.New(rand.NewSource(0))
|
rnd := rand.New(rand.NewSource(0)) //nolint:gosec // weak random is fine for DERP scrambling
|
||||||
rnd.Seed(int64(crc64.Checksum([]byte(seed), crc64Table)))
|
rnd.Seed(int64(crc64.Checksum([]byte(seed), crc64Table))) //nolint:gosec // safe conversion
|
||||||
derpRandomInst = rnd
|
derpRandomInst = rnd
|
||||||
})
|
})
|
||||||
|
|
||||||
return derpRandomInst
|
return derpRandomInst
|
||||||
}
|
}
|
||||||
|
|
||||||
func resetDerpRandomForTesting() {
|
func resetDerpRandomForTesting() {
|
||||||
derpRandomMu.Lock()
|
derpRandomMu.Lock()
|
||||||
defer derpRandomMu.Unlock()
|
defer derpRandomMu.Unlock()
|
||||||
|
|
||||||
derpRandomOnce = sync.Once{}
|
derpRandomOnce = sync.Once{}
|
||||||
derpRandomInst = nil
|
derpRandomInst = nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -242,7 +242,9 @@ func TestShuffleDERPMapDeterministic(t *testing.T) {
|
|||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
viper.Set("dns.base_domain", tt.baseDomain)
|
viper.Set("dns.base_domain", tt.baseDomain)
|
||||||
|
|
||||||
defer viper.Reset()
|
defer viper.Reset()
|
||||||
|
|
||||||
resetDerpRandomForTesting()
|
resetDerpRandomForTesting()
|
||||||
|
|
||||||
testMap := tt.derpMap.View().AsStruct()
|
testMap := tt.derpMap.View().AsStruct()
|
||||||
|
|||||||
@@ -54,7 +54,7 @@ func NewDERPServer(
|
|||||||
derpKey key.NodePrivate,
|
derpKey key.NodePrivate,
|
||||||
cfg *types.DERPConfig,
|
cfg *types.DERPConfig,
|
||||||
) (*DERPServer, error) {
|
) (*DERPServer, error) {
|
||||||
log.Trace().Caller().Msg("Creating new embedded DERP server")
|
log.Trace().Caller().Msg("creating new embedded DERP server")
|
||||||
server := derpserver.New(derpKey, util.TSLogfWrapper()) // nolint // zerolinter complains
|
server := derpserver.New(derpKey, util.TSLogfWrapper()) // nolint // zerolinter complains
|
||||||
|
|
||||||
if cfg.ServerVerifyClients {
|
if cfg.ServerVerifyClients {
|
||||||
@@ -75,9 +75,12 @@ func (d *DERPServer) GenerateRegion() (tailcfg.DERPRegion, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return tailcfg.DERPRegion{}, err
|
return tailcfg.DERPRegion{}, err
|
||||||
}
|
}
|
||||||
var host string
|
|
||||||
var port int
|
var (
|
||||||
var portStr string
|
host string
|
||||||
|
port int
|
||||||
|
portStr string
|
||||||
|
)
|
||||||
|
|
||||||
// Extract hostname and port from URL
|
// Extract hostname and port from URL
|
||||||
host, portStr, err = net.SplitHostPort(serverURL.Host)
|
host, portStr, err = net.SplitHostPort(serverURL.Host)
|
||||||
@@ -98,13 +101,13 @@ func (d *DERPServer) GenerateRegion() (tailcfg.DERPRegion, error) {
|
|||||||
|
|
||||||
// If debug flag is set, resolve hostname to IP address
|
// If debug flag is set, resolve hostname to IP address
|
||||||
if debugUseDERPIP {
|
if debugUseDERPIP {
|
||||||
ips, err := net.LookupIP(host)
|
ips, err := new(net.Resolver).LookupIPAddr(context.Background(), host)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Caller().Err(err).Msgf("Failed to resolve DERP hostname %s to IP, using hostname", host)
|
log.Error().Caller().Err(err).Msgf("failed to resolve DERP hostname %s to IP, using hostname", host)
|
||||||
} else if len(ips) > 0 {
|
} else if len(ips) > 0 {
|
||||||
// Use the first IP address
|
// Use the first IP address
|
||||||
ipStr := ips[0].String()
|
ipStr := ips[0].IP.String()
|
||||||
log.Info().Caller().Msgf("HEADSCALE_DEBUG_DERP_USE_IP: Resolved %s to %s", host, ipStr)
|
log.Info().Caller().Msgf("HEADSCALE_DEBUG_DERP_USE_IP: resolved %s to %s", host, ipStr)
|
||||||
host = ipStr
|
host = ipStr
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -130,14 +133,16 @@ func (d *DERPServer) GenerateRegion() (tailcfg.DERPRegion, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return tailcfg.DERPRegion{}, err
|
return tailcfg.DERPRegion{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
portSTUN, err := strconv.Atoi(portSTUNStr)
|
portSTUN, err := strconv.Atoi(portSTUNStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return tailcfg.DERPRegion{}, err
|
return tailcfg.DERPRegion{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
localDERPregion.Nodes[0].STUNPort = portSTUN
|
localDERPregion.Nodes[0].STUNPort = portSTUN
|
||||||
|
|
||||||
log.Info().Caller().Msgf("DERP region: %+v", localDERPregion)
|
log.Info().Caller().Msgf("derp region: %+v", localDERPregion)
|
||||||
log.Info().Caller().Msgf("DERP Nodes[0]: %+v", localDERPregion.Nodes[0])
|
log.Info().Caller().Msgf("derp nodes[0]: %+v", localDERPregion.Nodes[0])
|
||||||
|
|
||||||
return localDERPregion, nil
|
return localDERPregion, nil
|
||||||
}
|
}
|
||||||
@@ -155,8 +160,10 @@ func (d *DERPServer) DERPHandler(
|
|||||||
Caller().
|
Caller().
|
||||||
Msg("No Upgrade header in DERP server request. If headscale is behind a reverse proxy, make sure it is configured to pass WebSockets through.")
|
Msg("No Upgrade header in DERP server request. If headscale is behind a reverse proxy, make sure it is configured to pass WebSockets through.")
|
||||||
}
|
}
|
||||||
|
|
||||||
writer.Header().Set("Content-Type", "text/plain")
|
writer.Header().Set("Content-Type", "text/plain")
|
||||||
writer.WriteHeader(http.StatusUpgradeRequired)
|
writer.WriteHeader(http.StatusUpgradeRequired)
|
||||||
|
|
||||||
_, err := writer.Write([]byte("DERP requires connection upgrade"))
|
_, err := writer.Write([]byte("DERP requires connection upgrade"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
@@ -206,6 +213,7 @@ func (d *DERPServer) serveWebsocket(writer http.ResponseWriter, req *http.Reques
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer websocketConn.Close(websocket.StatusInternalError, "closing")
|
defer websocketConn.Close(websocket.StatusInternalError, "closing")
|
||||||
|
|
||||||
if websocketConn.Subprotocol() != "derp" {
|
if websocketConn.Subprotocol() != "derp" {
|
||||||
websocketConn.Close(websocket.StatusPolicyViolation, "client must speak the derp subprotocol")
|
websocketConn.Close(websocket.StatusPolicyViolation, "client must speak the derp subprotocol")
|
||||||
|
|
||||||
@@ -222,9 +230,10 @@ func (d *DERPServer) servePlain(writer http.ResponseWriter, req *http.Request) {
|
|||||||
|
|
||||||
hijacker, ok := writer.(http.Hijacker)
|
hijacker, ok := writer.(http.Hijacker)
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Error().Caller().Msg("DERP requires Hijacker interface from Gin")
|
log.Error().Caller().Msg("derp requires Hijacker interface from Gin")
|
||||||
writer.Header().Set("Content-Type", "text/plain")
|
writer.Header().Set("Content-Type", "text/plain")
|
||||||
writer.WriteHeader(http.StatusInternalServerError)
|
writer.WriteHeader(http.StatusInternalServerError)
|
||||||
|
|
||||||
_, err := writer.Write([]byte("HTTP does not support general TCP support"))
|
_, err := writer.Write([]byte("HTTP does not support general TCP support"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
@@ -238,9 +247,10 @@ func (d *DERPServer) servePlain(writer http.ResponseWriter, req *http.Request) {
|
|||||||
|
|
||||||
netConn, conn, err := hijacker.Hijack()
|
netConn, conn, err := hijacker.Hijack()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Caller().Err(err).Msgf("Hijack failed")
|
log.Error().Caller().Err(err).Msgf("hijack failed")
|
||||||
writer.Header().Set("Content-Type", "text/plain")
|
writer.Header().Set("Content-Type", "text/plain")
|
||||||
writer.WriteHeader(http.StatusInternalServerError)
|
writer.WriteHeader(http.StatusInternalServerError)
|
||||||
|
|
||||||
_, err = writer.Write([]byte("HTTP does not support general TCP support"))
|
_, err = writer.Write([]byte("HTTP does not support general TCP support"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
@@ -251,7 +261,8 @@ func (d *DERPServer) servePlain(writer http.ResponseWriter, req *http.Request) {
|
|||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Trace().Caller().Msgf("Hijacked connection from %v", req.RemoteAddr)
|
|
||||||
|
log.Trace().Caller().Msgf("hijacked connection from %v", req.RemoteAddr)
|
||||||
|
|
||||||
if !fastStart {
|
if !fastStart {
|
||||||
pubKey := d.key.Public()
|
pubKey := d.key.Public()
|
||||||
@@ -280,6 +291,7 @@ func DERPProbeHandler(
|
|||||||
writer.WriteHeader(http.StatusOK)
|
writer.WriteHeader(http.StatusOK)
|
||||||
default:
|
default:
|
||||||
writer.WriteHeader(http.StatusMethodNotAllowed)
|
writer.WriteHeader(http.StatusMethodNotAllowed)
|
||||||
|
|
||||||
_, err := writer.Write([]byte("bogus probe method"))
|
_, err := writer.Write([]byte("bogus probe method"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
@@ -309,9 +321,11 @@ func DERPBootstrapDNSHandler(
|
|||||||
|
|
||||||
resolvCtx, cancel := context.WithTimeout(req.Context(), time.Minute)
|
resolvCtx, cancel := context.WithTimeout(req.Context(), time.Minute)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
var resolver net.Resolver
|
var resolver net.Resolver
|
||||||
for _, region := range derpMap.Regions().All() {
|
|
||||||
for _, node := range region.Nodes().All() { // we don't care if we override some nodes
|
for _, region := range derpMap.Regions().All() { //nolint:unqueryvet // not SQLBoiler, tailcfg iterator
|
||||||
|
for _, node := range region.Nodes().All() { //nolint:unqueryvet // not SQLBoiler, tailcfg iterator
|
||||||
addrs, err := resolver.LookupIP(resolvCtx, "ip", node.HostName())
|
addrs, err := resolver.LookupIP(resolvCtx, "ip", node.HostName())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Trace().
|
log.Trace().
|
||||||
@@ -321,11 +335,14 @@ func DERPBootstrapDNSHandler(
|
|||||||
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
dnsEntries[node.HostName()] = addrs
|
dnsEntries[node.HostName()] = addrs
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
writer.Header().Set("Content-Type", "application/json")
|
writer.Header().Set("Content-Type", "application/json")
|
||||||
writer.WriteHeader(http.StatusOK)
|
writer.WriteHeader(http.StatusOK)
|
||||||
|
|
||||||
err := json.NewEncoder(writer).Encode(dnsEntries)
|
err := json.NewEncoder(writer).Encode(dnsEntries)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
@@ -338,33 +355,37 @@ func DERPBootstrapDNSHandler(
|
|||||||
|
|
||||||
// ServeSTUN starts a STUN server on the configured addr.
|
// ServeSTUN starts a STUN server on the configured addr.
|
||||||
func (d *DERPServer) ServeSTUN() {
|
func (d *DERPServer) ServeSTUN() {
|
||||||
packetConn, err := net.ListenPacket("udp", d.cfg.STUNAddr)
|
packetConn, err := new(net.ListenConfig).ListenPacket(context.Background(), "udp", d.cfg.STUNAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal().Msgf("failed to open STUN listener: %v", err)
|
log.Fatal().Msgf("failed to open STUN listener: %v", err)
|
||||||
}
|
}
|
||||||
log.Info().Msgf("STUN server started at %s", packetConn.LocalAddr())
|
|
||||||
|
log.Info().Msgf("stun server started at %s", packetConn.LocalAddr())
|
||||||
|
|
||||||
udpConn, ok := packetConn.(*net.UDPConn)
|
udpConn, ok := packetConn.(*net.UDPConn)
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Fatal().Msg("STUN listener is not a UDP listener")
|
log.Fatal().Msg("stun listener is not a UDP listener")
|
||||||
}
|
}
|
||||||
|
|
||||||
serverSTUNListener(context.Background(), udpConn)
|
serverSTUNListener(context.Background(), udpConn)
|
||||||
}
|
}
|
||||||
|
|
||||||
func serverSTUNListener(ctx context.Context, packetConn *net.UDPConn) {
|
func serverSTUNListener(ctx context.Context, packetConn *net.UDPConn) {
|
||||||
var buf [64 << 10]byte
|
|
||||||
var (
|
var (
|
||||||
|
buf [64 << 10]byte
|
||||||
bytesRead int
|
bytesRead int
|
||||||
udpAddr *net.UDPAddr
|
udpAddr *net.UDPAddr
|
||||||
err error
|
err error
|
||||||
)
|
)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
bytesRead, udpAddr, err = packetConn.ReadFromUDP(buf[:])
|
bytesRead, udpAddr, err = packetConn.ReadFromUDP(buf[:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if ctx.Err() != nil {
|
if ctx.Err() != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Error().Caller().Err(err).Msgf("STUN ReadFrom")
|
|
||||||
|
log.Error().Caller().Err(err).Msgf("stun ReadFrom")
|
||||||
|
|
||||||
// Rate limit error logging - wait before retrying, but respect context cancellation
|
// Rate limit error logging - wait before retrying, but respect context cancellation
|
||||||
select {
|
select {
|
||||||
@@ -375,25 +396,29 @@ func serverSTUNListener(ctx context.Context, packetConn *net.UDPConn) {
|
|||||||
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
log.Trace().Caller().Msgf("STUN request from %v", udpAddr)
|
|
||||||
|
log.Trace().Caller().Msgf("stun request from %v", udpAddr)
|
||||||
|
|
||||||
pkt := buf[:bytesRead]
|
pkt := buf[:bytesRead]
|
||||||
if !stun.Is(pkt) {
|
if !stun.Is(pkt) {
|
||||||
log.Trace().Caller().Msgf("UDP packet is not STUN")
|
log.Trace().Caller().Msgf("udp packet is not stun")
|
||||||
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
txid, err := stun.ParseBindingRequest(pkt)
|
txid, err := stun.ParseBindingRequest(pkt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Trace().Caller().Err(err).Msgf("STUN parse error")
|
log.Trace().Caller().Err(err).Msgf("stun parse error")
|
||||||
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
addr, _ := netip.AddrFromSlice(udpAddr.IP)
|
addr, _ := netip.AddrFromSlice(udpAddr.IP)
|
||||||
res := stun.Response(txid, netip.AddrPortFrom(addr, uint16(udpAddr.Port)))
|
res := stun.Response(txid, netip.AddrPortFrom(addr, uint16(udpAddr.Port))) //nolint:gosec // port is always <=65535
|
||||||
|
|
||||||
_, err = packetConn.WriteTo(res, udpAddr)
|
_, err = packetConn.WriteTo(res, udpAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Trace().Caller().Err(err).Msgf("Issue writing to UDP")
|
log.Trace().Caller().Err(err).Msgf("issue writing to UDP")
|
||||||
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -412,8 +437,10 @@ type DERPVerifyTransport struct {
|
|||||||
|
|
||||||
func (t *DERPVerifyTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
func (t *DERPVerifyTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
buf := new(bytes.Buffer)
|
buf := new(bytes.Buffer)
|
||||||
if err := t.handleVerifyRequest(req, buf); err != nil {
|
|
||||||
log.Error().Caller().Err(err).Msg("Failed to handle client verify request: ")
|
err := t.handleVerifyRequest(req, buf)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Caller().Err(err).Msg("failed to handle client verify request")
|
||||||
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -15,6 +16,9 @@ import (
|
|||||||
"tailscale.com/util/set"
|
"tailscale.com/util/set"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ErrPathIsDirectory is returned when a directory path is provided where a file is expected.
|
||||||
|
var ErrPathIsDirectory = errors.New("path is a directory, only file is supported")
|
||||||
|
|
||||||
type ExtraRecordsMan struct {
|
type ExtraRecordsMan struct {
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
records set.Set[tailcfg.DNSRecord]
|
records set.Set[tailcfg.DNSRecord]
|
||||||
@@ -39,7 +43,7 @@ func NewExtraRecordsManager(path string) (*ExtraRecordsMan, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if fi.IsDir() {
|
if fi.IsDir() {
|
||||||
return nil, fmt.Errorf("path is a directory, only file is supported: %s", path)
|
return nil, fmt.Errorf("%w: %s", ErrPathIsDirectory, path)
|
||||||
}
|
}
|
||||||
|
|
||||||
records, hash, err := readExtraRecordsFromPath(path)
|
records, hash, err := readExtraRecordsFromPath(path)
|
||||||
@@ -85,19 +89,22 @@ func (e *ExtraRecordsMan) Run() {
|
|||||||
log.Error().Caller().Msgf("file watcher event channel closing")
|
log.Error().Caller().Msgf("file watcher event channel closing")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
switch event.Op {
|
switch event.Op {
|
||||||
case fsnotify.Create, fsnotify.Write, fsnotify.Chmod:
|
case fsnotify.Create, fsnotify.Write, fsnotify.Chmod:
|
||||||
log.Trace().Caller().Str("path", event.Name).Str("op", event.Op.String()).Msg("extra records received filewatch event")
|
log.Trace().Caller().Str("path", event.Name).Str("op", event.Op.String()).Msg("extra records received filewatch event")
|
||||||
|
|
||||||
if event.Name != e.path {
|
if event.Name != e.path {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
e.updateRecords()
|
e.updateRecords()
|
||||||
|
|
||||||
// If a file is removed or renamed, fsnotify will loose track of it
|
// If a file is removed or renamed, fsnotify will loose track of it
|
||||||
// and not watch it. We will therefore attempt to re-add it with a backoff.
|
// and not watch it. We will therefore attempt to re-add it with a backoff.
|
||||||
case fsnotify.Remove, fsnotify.Rename:
|
case fsnotify.Remove, fsnotify.Rename:
|
||||||
_, err := backoff.Retry(context.Background(), func() (struct{}, error) {
|
_, err := backoff.Retry(context.Background(), func() (struct{}, error) {
|
||||||
if _, err := os.Stat(e.path); err != nil {
|
if _, err := os.Stat(e.path); err != nil { //nolint:noinlineerr
|
||||||
return struct{}{}, err
|
return struct{}{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -123,6 +130,7 @@ func (e *ExtraRecordsMan) Run() {
|
|||||||
log.Error().Caller().Msgf("file watcher error channel closing")
|
log.Error().Caller().Msgf("file watcher error channel closing")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Error().Caller().Err(err).Msgf("extra records filewatcher returned error: %q", err)
|
log.Error().Caller().Err(err).Msgf("extra records filewatcher returned error: %q", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -165,6 +173,7 @@ func (e *ExtraRecordsMan) updateRecords() {
|
|||||||
e.hashes[e.path] = newHash
|
e.hashes[e.path] = newHash
|
||||||
|
|
||||||
log.Trace().Caller().Interface("records", e.records).Msgf("extra records updated from path, count old: %d, new: %d", oldCount, e.records.Len())
|
log.Trace().Caller().Interface("records", e.records).Msgf("extra records updated from path, count old: %d, new: %d", oldCount, e.records.Len())
|
||||||
|
|
||||||
e.updateCh <- e.records.Slice()
|
e.updateCh <- e.records.Slice()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -183,6 +192,7 @@ func readExtraRecordsFromPath(path string) ([]tailcfg.DNSRecord, [32]byte, error
|
|||||||
}
|
}
|
||||||
|
|
||||||
var records []tailcfg.DNSRecord
|
var records []tailcfg.DNSRecord
|
||||||
|
|
||||||
err = json.Unmarshal(b, &records)
|
err = json.Unmarshal(b, &records)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, [32]byte{}, fmt.Errorf("unmarshalling records, content: %q: %w", string(b), err)
|
return nil, [32]byte{}, fmt.Errorf("unmarshalling records, content: %q: %w", string(b), err)
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ import (
|
|||||||
"github.com/juanfont/headscale/hscontrol/state"
|
"github.com/juanfont/headscale/hscontrol/state"
|
||||||
"github.com/juanfont/headscale/hscontrol/types"
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
"github.com/juanfont/headscale/hscontrol/util"
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/util/zlog/zf"
|
||||||
)
|
)
|
||||||
|
|
||||||
type headscaleV1APIServer struct { // v1.HeadscaleServiceServer
|
type headscaleV1APIServer struct { // v1.HeadscaleServiceServer
|
||||||
@@ -54,7 +55,7 @@ func (api headscaleV1APIServer) CreateUser(
|
|||||||
}
|
}
|
||||||
user, policyChanged, err := api.h.state.CreateUser(newUser)
|
user, policyChanged, err := api.h.state.CreateUser(newUser)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(codes.Internal, "failed to create user: %s", err)
|
return nil, status.Errorf(codes.Internal, "creating user: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateUser returns a policy change response if the user creation affected policy.
|
// CreateUser returns a policy change response if the user creation affected policy.
|
||||||
@@ -235,16 +236,16 @@ func (api headscaleV1APIServer) RegisterNode(
|
|||||||
// Generate ephemeral registration key for tracking this registration flow in logs
|
// Generate ephemeral registration key for tracking this registration flow in logs
|
||||||
registrationKey, err := util.GenerateRegistrationKey()
|
registrationKey, err := util.GenerateRegistrationKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warn().Err(err).Msg("Failed to generate registration key")
|
log.Warn().Err(err).Msg("failed to generate registration key")
|
||||||
registrationKey = "" // Continue without key if generation fails
|
registrationKey = "" // Continue without key if generation fails
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Caller().
|
Caller().
|
||||||
Str("user", request.GetUser()).
|
Str(zf.UserName, request.GetUser()).
|
||||||
Str("registration_id", request.GetKey()).
|
Str(zf.RegistrationID, request.GetKey()).
|
||||||
Str("registration_key", registrationKey).
|
Str(zf.RegistrationKey, registrationKey).
|
||||||
Msg("Registering node")
|
Msg("registering node")
|
||||||
|
|
||||||
registrationId, err := types.RegistrationIDFromString(request.GetKey())
|
registrationId, err := types.RegistrationIDFromString(request.GetKey())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -264,17 +265,16 @@ func (api headscaleV1APIServer) RegisterNode(
|
|||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("registration_key", registrationKey).
|
Str(zf.RegistrationKey, registrationKey).
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Failed to register node")
|
Msg("failed to register node")
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Info().
|
log.Info().
|
||||||
Str("registration_key", registrationKey).
|
Str(zf.RegistrationKey, registrationKey).
|
||||||
Str("node_id", fmt.Sprintf("%d", node.ID())).
|
EmbedObject(node).
|
||||||
Str("hostname", node.Hostname()).
|
Msg("node registered successfully")
|
||||||
Msg("Node registered successfully")
|
|
||||||
|
|
||||||
// This is a bit of a back and forth, but we have a bit of a chicken and egg
|
// This is a bit of a back and forth, but we have a bit of a chicken and egg
|
||||||
// dependency here.
|
// dependency here.
|
||||||
@@ -355,9 +355,9 @@ func (api headscaleV1APIServer) SetTags(
|
|||||||
|
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Caller().
|
Caller().
|
||||||
Str("node", node.Hostname()).
|
EmbedObject(node).
|
||||||
Strs("tags", request.GetTags()).
|
Strs("tags", request.GetTags()).
|
||||||
Msg("Changing tags of node")
|
Msg("changing tags of node")
|
||||||
|
|
||||||
return &v1.SetTagsResponse{Node: node.Proto()}, nil
|
return &v1.SetTagsResponse{Node: node.Proto()}, nil
|
||||||
}
|
}
|
||||||
@@ -368,7 +368,7 @@ func (api headscaleV1APIServer) SetApprovedRoutes(
|
|||||||
) (*v1.SetApprovedRoutesResponse, error) {
|
) (*v1.SetApprovedRoutesResponse, error) {
|
||||||
log.Debug().
|
log.Debug().
|
||||||
Caller().
|
Caller().
|
||||||
Uint64("node.id", request.GetNodeId()).
|
Uint64(zf.NodeID, request.GetNodeId()).
|
||||||
Strs("requestedRoutes", request.GetRoutes()).
|
Strs("requestedRoutes", request.GetRoutes()).
|
||||||
Msg("gRPC SetApprovedRoutes called")
|
Msg("gRPC SetApprovedRoutes called")
|
||||||
|
|
||||||
@@ -387,7 +387,7 @@ func (api headscaleV1APIServer) SetApprovedRoutes(
|
|||||||
newApproved = append(newApproved, prefix)
|
newApproved = append(newApproved, prefix)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
tsaddr.SortPrefixes(newApproved)
|
slices.SortFunc(newApproved, netip.Prefix.Compare)
|
||||||
newApproved = slices.Compact(newApproved)
|
newApproved = slices.Compact(newApproved)
|
||||||
|
|
||||||
node, nodeChange, err := api.h.state.SetApprovedRoutes(types.NodeID(request.GetNodeId()), newApproved)
|
node, nodeChange, err := api.h.state.SetApprovedRoutes(types.NodeID(request.GetNodeId()), newApproved)
|
||||||
@@ -406,7 +406,7 @@ func (api headscaleV1APIServer) SetApprovedRoutes(
|
|||||||
|
|
||||||
log.Debug().
|
log.Debug().
|
||||||
Caller().
|
Caller().
|
||||||
Uint64("node.id", node.ID().Uint64()).
|
EmbedObject(node).
|
||||||
Strs("approvedRoutes", util.PrefixesToString(node.ApprovedRoutes().AsSlice())).
|
Strs("approvedRoutes", util.PrefixesToString(node.ApprovedRoutes().AsSlice())).
|
||||||
Strs("primaryRoutes", util.PrefixesToString(primaryRoutes)).
|
Strs("primaryRoutes", util.PrefixesToString(primaryRoutes)).
|
||||||
Strs("finalSubnetRoutes", proto.SubnetRoutes).
|
Strs("finalSubnetRoutes", proto.SubnetRoutes).
|
||||||
@@ -423,7 +423,7 @@ func validateTag(tag string) error {
|
|||||||
return errors.New("tag should be lowercase")
|
return errors.New("tag should be lowercase")
|
||||||
}
|
}
|
||||||
if len(strings.Fields(tag)) > 1 {
|
if len(strings.Fields(tag)) > 1 {
|
||||||
return errors.New("tag should not contains space")
|
return errors.New("tags must not contain spaces")
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -466,8 +466,8 @@ func (api headscaleV1APIServer) ExpireNode(
|
|||||||
|
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Caller().
|
Caller().
|
||||||
Str("node", node.Hostname()).
|
EmbedObject(node).
|
||||||
Time("expiry", *node.AsStruct().Expiry).
|
Time(zf.ExpiresAt, *node.AsStruct().Expiry).
|
||||||
Msg("node expired")
|
Msg("node expired")
|
||||||
|
|
||||||
return &v1.ExpireNodeResponse{Node: node.Proto()}, nil
|
return &v1.ExpireNodeResponse{Node: node.Proto()}, nil
|
||||||
@@ -487,8 +487,8 @@ func (api headscaleV1APIServer) RenameNode(
|
|||||||
|
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Caller().
|
Caller().
|
||||||
Str("node", node.Hostname()).
|
EmbedObject(node).
|
||||||
Str("new_name", request.GetNewName()).
|
Str(zf.NewName, request.GetNewName()).
|
||||||
Msg("node renamed")
|
Msg("node renamed")
|
||||||
|
|
||||||
return &v1.RenameNodeResponse{Node: node.Proto()}, nil
|
return &v1.RenameNodeResponse{Node: node.Proto()}, nil
|
||||||
@@ -546,7 +546,7 @@ func (api headscaleV1APIServer) BackfillNodeIPs(
|
|||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *v1.BackfillNodeIPsRequest,
|
request *v1.BackfillNodeIPsRequest,
|
||||||
) (*v1.BackfillNodeIPsResponse, error) {
|
) (*v1.BackfillNodeIPsResponse, error) {
|
||||||
log.Trace().Caller().Msg("Backfill called")
|
log.Trace().Caller().Msg("backfill called")
|
||||||
|
|
||||||
if !request.Confirmed {
|
if !request.Confirmed {
|
||||||
return nil, errors.New("not confirmed, aborting")
|
return nil, errors.New("not confirmed, aborting")
|
||||||
@@ -817,13 +817,13 @@ func (api headscaleV1APIServer) Health(
|
|||||||
response := &v1.HealthResponse{}
|
response := &v1.HealthResponse{}
|
||||||
|
|
||||||
if err := api.h.state.PingDB(ctx); err != nil {
|
if err := api.h.state.PingDB(ctx); err != nil {
|
||||||
healthErr = fmt.Errorf("database ping failed: %w", err)
|
healthErr = fmt.Errorf("pinging database: %w", err)
|
||||||
} else {
|
} else {
|
||||||
response.DatabaseConnectivity = true
|
response.DatabaseConnectivity = true
|
||||||
}
|
}
|
||||||
|
|
||||||
if healthErr != nil {
|
if healthErr != nil {
|
||||||
log.Error().Err(healthErr).Msg("Health check failed")
|
log.Error().Err(healthErr).Msg("health check failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
return response, healthErr
|
return response, healthErr
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ func Test_validateTag(t *testing.T) {
|
|||||||
type args struct {
|
type args struct {
|
||||||
tag string
|
tag string
|
||||||
}
|
}
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
args args
|
args args
|
||||||
@@ -45,7 +46,8 @@ func Test_validateTag(t *testing.T) {
|
|||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
if err := validateTag(tt.args.tag); (err != nil) != tt.wantErr {
|
err := validateTag(tt.args.tag)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
t.Errorf("validateTag() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("validateTag() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// The CapabilityVersion is used by Tailscale clients to indicate
|
// NoiseCapabilityVersion is used by Tailscale clients to indicate
|
||||||
// their codebase version. Tailscale clients can communicate over TS2021
|
// their codebase version. Tailscale clients can communicate over TS2021
|
||||||
// from CapabilityVersion 28, but we only have good support for it
|
// from CapabilityVersion 28, but we only have good support for it
|
||||||
// since https://github.com/tailscale/tailscale/pull/4323 (Noise in any HTTPS port).
|
// since https://github.com/tailscale/tailscale/pull/4323 (Noise in any HTTPS port).
|
||||||
@@ -36,8 +36,7 @@ const (
|
|||||||
|
|
||||||
// httpError logs an error and sends an HTTP error response with the given.
|
// httpError logs an error and sends an HTTP error response with the given.
|
||||||
func httpError(w http.ResponseWriter, err error) {
|
func httpError(w http.ResponseWriter, err error) {
|
||||||
var herr HTTPError
|
if herr, ok := errors.AsType[HTTPError](err); ok {
|
||||||
if errors.As(err, &herr) {
|
|
||||||
http.Error(w, herr.Msg, herr.Code)
|
http.Error(w, herr.Msg, herr.Code)
|
||||||
log.Error().Err(herr.Err).Int("code", herr.Code).Msgf("user msg: %s", herr.Msg)
|
log.Error().Err(herr.Err).Int("code", herr.Code).Msgf("user msg: %s", herr.Msg)
|
||||||
} else {
|
} else {
|
||||||
@@ -56,7 +55,7 @@ type HTTPError struct {
|
|||||||
func (e HTTPError) Error() string { return fmt.Sprintf("http error[%d]: %s, %s", e.Code, e.Msg, e.Err) }
|
func (e HTTPError) Error() string { return fmt.Sprintf("http error[%d]: %s, %s", e.Code, e.Msg, e.Err) }
|
||||||
func (e HTTPError) Unwrap() error { return e.Err }
|
func (e HTTPError) Unwrap() error { return e.Err }
|
||||||
|
|
||||||
// Error returns an HTTPError containing the given information.
|
// NewHTTPError returns an HTTPError containing the given information.
|
||||||
func NewHTTPError(code int, msg string, err error) HTTPError {
|
func NewHTTPError(code int, msg string, err error) HTTPError {
|
||||||
return HTTPError{Code: code, Msg: msg, Err: err}
|
return HTTPError{Code: code, Msg: msg, Err: err}
|
||||||
}
|
}
|
||||||
@@ -64,7 +63,7 @@ func NewHTTPError(code int, msg string, err error) HTTPError {
|
|||||||
var errMethodNotAllowed = NewHTTPError(http.StatusMethodNotAllowed, "method not allowed", nil)
|
var errMethodNotAllowed = NewHTTPError(http.StatusMethodNotAllowed, "method not allowed", nil)
|
||||||
|
|
||||||
var ErrRegisterMethodCLIDoesNotSupportExpire = errors.New(
|
var ErrRegisterMethodCLIDoesNotSupportExpire = errors.New(
|
||||||
"machines registered with CLI does not support expire",
|
"machines registered with CLI do not support expiry",
|
||||||
)
|
)
|
||||||
|
|
||||||
func parseCapabilityVersion(req *http.Request) (tailcfg.CapabilityVersion, error) {
|
func parseCapabilityVersion(req *http.Request) (tailcfg.CapabilityVersion, error) {
|
||||||
@@ -76,7 +75,7 @@ func parseCapabilityVersion(req *http.Request) (tailcfg.CapabilityVersion, error
|
|||||||
|
|
||||||
clientCapabilityVersion, err := strconv.Atoi(clientCapabilityStr)
|
clientCapabilityVersion, err := strconv.Atoi(clientCapabilityStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, NewHTTPError(http.StatusBadRequest, "invalid capability version", fmt.Errorf("failed to parse capability version: %w", err))
|
return 0, NewHTTPError(http.StatusBadRequest, "invalid capability version", fmt.Errorf("parsing capability version: %w", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
return tailcfg.CapabilityVersion(clientCapabilityVersion), nil
|
return tailcfg.CapabilityVersion(clientCapabilityVersion), nil
|
||||||
@@ -88,12 +87,12 @@ func (h *Headscale) handleVerifyRequest(
|
|||||||
) error {
|
) error {
|
||||||
body, err := io.ReadAll(req.Body)
|
body, err := io.ReadAll(req.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("cannot read request body: %w", err)
|
return fmt.Errorf("reading request body: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var derpAdmitClientRequest tailcfg.DERPAdmitClientRequest
|
var derpAdmitClientRequest tailcfg.DERPAdmitClientRequest
|
||||||
if err := json.Unmarshal(body, &derpAdmitClientRequest); err != nil {
|
if err := json.Unmarshal(body, &derpAdmitClientRequest); err != nil { //nolint:noinlineerr
|
||||||
return NewHTTPError(http.StatusBadRequest, "Bad Request: invalid JSON", fmt.Errorf("cannot parse derpAdmitClientRequest: %w", err))
|
return NewHTTPError(http.StatusBadRequest, "Bad Request: invalid JSON", fmt.Errorf("parsing DERP client request: %w", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
nodes := h.state.ListNodes()
|
nodes := h.state.ListNodes()
|
||||||
@@ -155,7 +154,11 @@ func (h *Headscale) KeyHandler(
|
|||||||
}
|
}
|
||||||
|
|
||||||
writer.Header().Set("Content-Type", "application/json")
|
writer.Header().Set("Content-Type", "application/json")
|
||||||
json.NewEncoder(writer).Encode(resp)
|
|
||||||
|
err := json.NewEncoder(writer).Encode(resp)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msg("failed to encode public key response")
|
||||||
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -180,8 +183,12 @@ func (h *Headscale) HealthHandler(
|
|||||||
res.Status = "fail"
|
res.Status = "fail"
|
||||||
}
|
}
|
||||||
|
|
||||||
json.NewEncoder(writer).Encode(res)
|
encErr := json.NewEncoder(writer).Encode(res)
|
||||||
|
if encErr != nil {
|
||||||
|
log.Error().Err(encErr).Msg("failed to encode health response")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err := h.state.PingDB(req.Context())
|
err := h.state.PingDB(req.Context())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
respond(err)
|
respond(err)
|
||||||
@@ -218,6 +225,7 @@ func (h *Headscale) VersionHandler(
|
|||||||
writer.WriteHeader(http.StatusOK)
|
writer.WriteHeader(http.StatusOK)
|
||||||
|
|
||||||
versionInfo := types.GetVersionInfo()
|
versionInfo := types.GetVersionInfo()
|
||||||
|
|
||||||
err := json.NewEncoder(writer).Encode(versionInfo)
|
err := json.NewEncoder(writer).Encode(versionInfo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
@@ -244,7 +252,7 @@ func (a *AuthProviderWeb) AuthURL(registrationId types.RegistrationID) string {
|
|||||||
registrationId.String())
|
registrationId.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
// RegisterWebAPI shows a simple message in the browser to point to the CLI
|
// RegisterHandler shows a simple message in the browser to point to the CLI
|
||||||
// Listens in /register/:registration_id.
|
// Listens in /register/:registration_id.
|
||||||
//
|
//
|
||||||
// This is not part of the Tailscale control API, as we could send whatever URL
|
// This is not part of the Tailscale control API, as we could send whatever URL
|
||||||
@@ -267,7 +275,11 @@ func (a *AuthProviderWeb) RegisterHandler(
|
|||||||
|
|
||||||
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
|
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||||
writer.WriteHeader(http.StatusOK)
|
writer.WriteHeader(http.StatusOK)
|
||||||
writer.Write([]byte(templates.RegisterWeb(registrationId).Render()))
|
|
||||||
|
_, err = writer.Write([]byte(templates.RegisterWeb(registrationId).Render()))
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msg("failed to write register response")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func FaviconHandler(writer http.ResponseWriter, req *http.Request) {
|
func FaviconHandler(writer http.ResponseWriter, req *http.Request) {
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"github.com/juanfont/headscale/hscontrol/state"
|
"github.com/juanfont/headscale/hscontrol/state"
|
||||||
"github.com/juanfont/headscale/hscontrol/types"
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
"github.com/juanfont/headscale/hscontrol/types/change"
|
"github.com/juanfont/headscale/hscontrol/types/change"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/util/zlog/zf"
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
"github.com/prometheus/client_golang/prometheus/promauto"
|
"github.com/prometheus/client_golang/prometheus/promauto"
|
||||||
"github.com/puzpuzpuz/xsync/v4"
|
"github.com/puzpuzpuz/xsync/v4"
|
||||||
@@ -15,6 +16,14 @@ import (
|
|||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Mapper errors.
|
||||||
|
var (
|
||||||
|
ErrInvalidNodeID = errors.New("invalid nodeID")
|
||||||
|
ErrMapperNil = errors.New("mapper is nil")
|
||||||
|
ErrNodeConnectionNil = errors.New("nodeConnection is nil")
|
||||||
|
ErrNodeNotFoundMapper = errors.New("node not found")
|
||||||
|
)
|
||||||
|
|
||||||
var mapResponseGenerated = promauto.NewCounterVec(prometheus.CounterOpts{
|
var mapResponseGenerated = promauto.NewCounterVec(prometheus.CounterOpts{
|
||||||
Namespace: "headscale",
|
Namespace: "headscale",
|
||||||
Name: "mapresponse_generated_total",
|
Name: "mapresponse_generated_total",
|
||||||
@@ -80,11 +89,11 @@ func generateMapResponse(nc nodeConnection, mapper *mapper, r change.Change) (*t
|
|||||||
}
|
}
|
||||||
|
|
||||||
if nodeID == 0 {
|
if nodeID == 0 {
|
||||||
return nil, fmt.Errorf("invalid nodeID: %d", nodeID)
|
return nil, fmt.Errorf("%w: %d", ErrInvalidNodeID, nodeID)
|
||||||
}
|
}
|
||||||
|
|
||||||
if mapper == nil {
|
if mapper == nil {
|
||||||
return nil, fmt.Errorf("mapper is nil for nodeID %d", nodeID)
|
return nil, fmt.Errorf("%w for nodeID %d", ErrMapperNil, nodeID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle self-only responses
|
// Handle self-only responses
|
||||||
@@ -135,12 +144,12 @@ func generateMapResponse(nc nodeConnection, mapper *mapper, r change.Change) (*t
|
|||||||
// handleNodeChange generates and sends a [tailcfg.MapResponse] for a given node and [change.Change].
|
// handleNodeChange generates and sends a [tailcfg.MapResponse] for a given node and [change.Change].
|
||||||
func handleNodeChange(nc nodeConnection, mapper *mapper, r change.Change) error {
|
func handleNodeChange(nc nodeConnection, mapper *mapper, r change.Change) error {
|
||||||
if nc == nil {
|
if nc == nil {
|
||||||
return errors.New("nodeConnection is nil")
|
return ErrNodeConnectionNil
|
||||||
}
|
}
|
||||||
|
|
||||||
nodeID := nc.nodeID()
|
nodeID := nc.nodeID()
|
||||||
|
|
||||||
log.Debug().Caller().Uint64("node.id", nodeID.Uint64()).Str("reason", r.Reason).Msg("Node change processing started because change notification received")
|
log.Debug().Caller().Uint64(zf.NodeID, nodeID.Uint64()).Str(zf.Reason, r.Reason).Msg("node change processing started")
|
||||||
|
|
||||||
data, err := generateMapResponse(nc, mapper, r)
|
data, err := generateMapResponse(nc, mapper, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package mapper
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
|
"encoding/hex"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -10,13 +11,20 @@ import (
|
|||||||
|
|
||||||
"github.com/juanfont/headscale/hscontrol/types"
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
"github.com/juanfont/headscale/hscontrol/types/change"
|
"github.com/juanfont/headscale/hscontrol/types/change"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/util/zlog/zf"
|
||||||
"github.com/puzpuzpuz/xsync/v4"
|
"github.com/puzpuzpuz/xsync/v4"
|
||||||
|
"github.com/rs/zerolog"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
"tailscale.com/types/ptr"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var errConnectionClosed = errors.New("connection channel already closed")
|
// LockFreeBatcher errors.
|
||||||
|
var (
|
||||||
|
errConnectionClosed = errors.New("connection channel already closed")
|
||||||
|
ErrInitialMapSendTimeout = errors.New("sending initial map: timeout")
|
||||||
|
ErrBatcherShuttingDown = errors.New("batcher shutting down")
|
||||||
|
ErrConnectionSendTimeout = errors.New("timeout sending to channel (likely stale connection)")
|
||||||
|
)
|
||||||
|
|
||||||
// LockFreeBatcher uses atomic operations and concurrent maps to eliminate mutex contention.
|
// LockFreeBatcher uses atomic operations and concurrent maps to eliminate mutex contention.
|
||||||
type LockFreeBatcher struct {
|
type LockFreeBatcher struct {
|
||||||
@@ -48,6 +56,7 @@ type LockFreeBatcher struct {
|
|||||||
// and notifies other nodes that this node has come online.
|
// and notifies other nodes that this node has come online.
|
||||||
func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion) error {
|
func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion) error {
|
||||||
addNodeStart := time.Now()
|
addNodeStart := time.Now()
|
||||||
|
nlog := log.With().Uint64(zf.NodeID, id.Uint64()).Logger()
|
||||||
|
|
||||||
// Generate connection ID
|
// Generate connection ID
|
||||||
connID := generateConnectionID()
|
connID := generateConnectionID()
|
||||||
@@ -76,9 +85,10 @@ func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse
|
|||||||
// Use the worker pool for controlled concurrency instead of direct generation
|
// Use the worker pool for controlled concurrency instead of direct generation
|
||||||
initialMap, err := b.MapResponseFromChange(id, change.FullSelf(id))
|
initialMap, err := b.MapResponseFromChange(id, change.FullSelf(id))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Uint64("node.id", id.Uint64()).Err(err).Msg("Initial map generation failed")
|
nlog.Error().Err(err).Msg("initial map generation failed")
|
||||||
nodeConn.removeConnectionByChannel(c)
|
nodeConn.removeConnectionByChannel(c)
|
||||||
return fmt.Errorf("failed to generate initial map for node %d: %w", id, err)
|
|
||||||
|
return fmt.Errorf("generating initial map for node %d: %w", id, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use a blocking send with timeout for initial map since the channel should be ready
|
// Use a blocking send with timeout for initial map since the channel should be ready
|
||||||
@@ -86,12 +96,13 @@ func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse
|
|||||||
select {
|
select {
|
||||||
case c <- initialMap:
|
case c <- initialMap:
|
||||||
// Success
|
// Success
|
||||||
case <-time.After(5 * time.Second):
|
case <-time.After(5 * time.Second): //nolint:mnd
|
||||||
log.Error().Uint64("node.id", id.Uint64()).Err(fmt.Errorf("timeout")).Msg("Initial map send timeout")
|
nlog.Error().Err(ErrInitialMapSendTimeout).Msg("initial map send timeout")
|
||||||
log.Debug().Caller().Uint64("node.id", id.Uint64()).Dur("timeout.duration", 5*time.Second).
|
nlog.Debug().Caller().Dur("timeout.duration", 5*time.Second). //nolint:mnd
|
||||||
Msg("Initial map send timed out because channel was blocked or receiver not ready")
|
Msg("initial map send timed out because channel was blocked or receiver not ready")
|
||||||
nodeConn.removeConnectionByChannel(c)
|
nodeConn.removeConnectionByChannel(c)
|
||||||
return fmt.Errorf("failed to send initial map to node %d: timeout", id)
|
|
||||||
|
return fmt.Errorf("%w for node %d", ErrInitialMapSendTimeout, id)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update connection status
|
// Update connection status
|
||||||
@@ -100,9 +111,9 @@ func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse
|
|||||||
// Node will automatically receive updates through the normal flow
|
// Node will automatically receive updates through the normal flow
|
||||||
// The initial full map already contains all current state
|
// The initial full map already contains all current state
|
||||||
|
|
||||||
log.Debug().Caller().Uint64("node.id", id.Uint64()).Dur("total.duration", time.Since(addNodeStart)).
|
nlog.Debug().Caller().Dur(zf.TotalDuration, time.Since(addNodeStart)).
|
||||||
Int("active.connections", nodeConn.getActiveConnectionCount()).
|
Int("active.connections", nodeConn.getActiveConnectionCount()).
|
||||||
Msg("Node connection established in batcher because AddNode completed successfully")
|
Msg("node connection established in batcher")
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -112,31 +123,34 @@ func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse
|
|||||||
// and keeps the node entry alive for rapid reconnections instead of aggressive deletion.
|
// and keeps the node entry alive for rapid reconnections instead of aggressive deletion.
|
||||||
// Reports if the node still has active connections after removal.
|
// Reports if the node still has active connections after removal.
|
||||||
func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse) bool {
|
func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse) bool {
|
||||||
|
nlog := log.With().Uint64(zf.NodeID, id.Uint64()).Logger()
|
||||||
|
|
||||||
nodeConn, exists := b.nodes.Load(id)
|
nodeConn, exists := b.nodes.Load(id)
|
||||||
if !exists {
|
if !exists {
|
||||||
log.Debug().Caller().Uint64("node.id", id.Uint64()).Msg("RemoveNode called for non-existent node because node not found in batcher")
|
nlog.Debug().Caller().Msg("removeNode called for non-existent node")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove specific connection
|
// Remove specific connection
|
||||||
removed := nodeConn.removeConnectionByChannel(c)
|
removed := nodeConn.removeConnectionByChannel(c)
|
||||||
if !removed {
|
if !removed {
|
||||||
log.Debug().Caller().Uint64("node.id", id.Uint64()).Msg("RemoveNode: channel not found because connection already removed or invalid")
|
nlog.Debug().Caller().Msg("removeNode: channel not found, connection already removed or invalid")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if node has any remaining active connections
|
// Check if node has any remaining active connections
|
||||||
if nodeConn.hasActiveConnections() {
|
if nodeConn.hasActiveConnections() {
|
||||||
log.Debug().Caller().Uint64("node.id", id.Uint64()).
|
nlog.Debug().Caller().
|
||||||
Int("active.connections", nodeConn.getActiveConnectionCount()).
|
Int("active.connections", nodeConn.getActiveConnectionCount()).
|
||||||
Msg("Node connection removed but keeping online because other connections remain")
|
Msg("node connection removed but keeping online, other connections remain")
|
||||||
|
|
||||||
return true // Node still has active connections
|
return true // Node still has active connections
|
||||||
}
|
}
|
||||||
|
|
||||||
// No active connections - keep the node entry alive for rapid reconnections
|
// No active connections - keep the node entry alive for rapid reconnections
|
||||||
// The node will get a fresh full map when it reconnects
|
// The node will get a fresh full map when it reconnects
|
||||||
log.Debug().Caller().Uint64("node.id", id.Uint64()).Msg("Node disconnected from batcher because all connections removed, keeping entry for rapid reconnection")
|
nlog.Debug().Caller().Msg("node disconnected from batcher, keeping entry for rapid reconnection")
|
||||||
b.connected.Store(id, ptr.To(time.Now()))
|
b.connected.Store(id, new(time.Now()))
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@@ -196,11 +210,13 @@ func (b *LockFreeBatcher) doWork() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (b *LockFreeBatcher) worker(workerID int) {
|
func (b *LockFreeBatcher) worker(workerID int) {
|
||||||
|
wlog := log.With().Int(zf.WorkerID, workerID).Logger()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case w, ok := <-b.workCh:
|
case w, ok := <-b.workCh:
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Debug().Int("worker.id", workerID).Msgf("worker channel closing, shutting down worker %d", workerID)
|
wlog.Debug().Msg("worker channel closing, shutting down")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -212,29 +228,29 @@ func (b *LockFreeBatcher) worker(workerID int) {
|
|||||||
// This is used for synchronous map generation.
|
// This is used for synchronous map generation.
|
||||||
if w.resultCh != nil {
|
if w.resultCh != nil {
|
||||||
var result workResult
|
var result workResult
|
||||||
|
|
||||||
if nc, exists := b.nodes.Load(w.nodeID); exists {
|
if nc, exists := b.nodes.Load(w.nodeID); exists {
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
result.mapResponse, err = generateMapResponse(nc, b.mapper, w.c)
|
result.mapResponse, err = generateMapResponse(nc, b.mapper, w.c)
|
||||||
|
|
||||||
result.err = err
|
result.err = err
|
||||||
if result.err != nil {
|
if result.err != nil {
|
||||||
b.workErrors.Add(1)
|
b.workErrors.Add(1)
|
||||||
log.Error().Err(result.err).
|
wlog.Error().Err(result.err).
|
||||||
Int("worker.id", workerID).
|
Uint64(zf.NodeID, w.nodeID.Uint64()).
|
||||||
Uint64("node.id", w.nodeID.Uint64()).
|
Str(zf.Reason, w.c.Reason).
|
||||||
Str("reason", w.c.Reason).
|
|
||||||
Msg("failed to generate map response for synchronous work")
|
Msg("failed to generate map response for synchronous work")
|
||||||
} else if result.mapResponse != nil {
|
} else if result.mapResponse != nil {
|
||||||
// Update peer tracking for synchronous responses too
|
// Update peer tracking for synchronous responses too
|
||||||
nc.updateSentPeers(result.mapResponse)
|
nc.updateSentPeers(result.mapResponse)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
result.err = fmt.Errorf("node %d not found", w.nodeID)
|
result.err = fmt.Errorf("%w: %d", ErrNodeNotFoundMapper, w.nodeID)
|
||||||
|
|
||||||
b.workErrors.Add(1)
|
b.workErrors.Add(1)
|
||||||
log.Error().Err(result.err).
|
wlog.Error().Err(result.err).
|
||||||
Int("worker.id", workerID).
|
Uint64(zf.NodeID, w.nodeID.Uint64()).
|
||||||
Uint64("node.id", w.nodeID.Uint64()).
|
|
||||||
Msg("node not found for synchronous work")
|
Msg("node not found for synchronous work")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -257,15 +273,14 @@ func (b *LockFreeBatcher) worker(workerID int) {
|
|||||||
err := nc.change(w.c)
|
err := nc.change(w.c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
b.workErrors.Add(1)
|
b.workErrors.Add(1)
|
||||||
log.Error().Err(err).
|
wlog.Error().Err(err).
|
||||||
Int("worker.id", workerID).
|
Uint64(zf.NodeID, w.nodeID.Uint64()).
|
||||||
Uint64("node.id", w.nodeID.Uint64()).
|
Str(zf.Reason, w.c.Reason).
|
||||||
Str("reason", w.c.Reason).
|
|
||||||
Msg("failed to apply change")
|
Msg("failed to apply change")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case <-b.done:
|
case <-b.done:
|
||||||
log.Debug().Int("worker.id", workerID).Msg("batcher shutting down, exiting worker")
|
wlog.Debug().Msg("batcher shutting down, exiting worker")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -310,8 +325,8 @@ func (b *LockFreeBatcher) addToBatch(changes ...change.Change) {
|
|||||||
if _, existed := b.nodes.LoadAndDelete(removedID); existed {
|
if _, existed := b.nodes.LoadAndDelete(removedID); existed {
|
||||||
b.totalNodes.Add(-1)
|
b.totalNodes.Add(-1)
|
||||||
log.Debug().
|
log.Debug().
|
||||||
Uint64("node.id", removedID.Uint64()).
|
Uint64(zf.NodeID, removedID.Uint64()).
|
||||||
Msg("Removed deleted node from batcher")
|
Msg("removed deleted node from batcher")
|
||||||
}
|
}
|
||||||
|
|
||||||
b.connected.Delete(removedID)
|
b.connected.Delete(removedID)
|
||||||
@@ -398,14 +413,15 @@ func (b *LockFreeBatcher) cleanupOfflineNodes() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
|
|
||||||
// Clean up the identified nodes
|
// Clean up the identified nodes
|
||||||
for _, nodeID := range nodesToCleanup {
|
for _, nodeID := range nodesToCleanup {
|
||||||
log.Info().Uint64("node.id", nodeID.Uint64()).
|
log.Info().Uint64(zf.NodeID, nodeID.Uint64()).
|
||||||
Dur("offline_duration", cleanupThreshold).
|
Dur("offline_duration", cleanupThreshold).
|
||||||
Msg("Cleaning up node that has been offline for too long")
|
Msg("cleaning up node that has been offline for too long")
|
||||||
|
|
||||||
b.nodes.Delete(nodeID)
|
b.nodes.Delete(nodeID)
|
||||||
b.connected.Delete(nodeID)
|
b.connected.Delete(nodeID)
|
||||||
@@ -413,8 +429,8 @@ func (b *LockFreeBatcher) cleanupOfflineNodes() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(nodesToCleanup) > 0 {
|
if len(nodesToCleanup) > 0 {
|
||||||
log.Info().Int("cleaned_nodes", len(nodesToCleanup)).
|
log.Info().Int(zf.CleanedNodes, len(nodesToCleanup)).
|
||||||
Msg("Completed cleanup of long-offline nodes")
|
Msg("completed cleanup of long-offline nodes")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -450,6 +466,7 @@ func (b *LockFreeBatcher) ConnectedMap() *xsync.Map[types.NodeID, bool] {
|
|||||||
if nodeConn.hasActiveConnections() {
|
if nodeConn.hasActiveConnections() {
|
||||||
ret.Store(id, true)
|
ret.Store(id, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -465,6 +482,7 @@ func (b *LockFreeBatcher) ConnectedMap() *xsync.Map[types.NodeID, bool] {
|
|||||||
ret.Store(id, false)
|
ret.Store(id, false)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -484,7 +502,7 @@ func (b *LockFreeBatcher) MapResponseFromChange(id types.NodeID, ch change.Chang
|
|||||||
case result := <-resultCh:
|
case result := <-resultCh:
|
||||||
return result.mapResponse, result.err
|
return result.mapResponse, result.err
|
||||||
case <-b.done:
|
case <-b.done:
|
||||||
return nil, fmt.Errorf("batcher shutting down while generating map response for node %d", id)
|
return nil, fmt.Errorf("%w while generating map response for node %d", ErrBatcherShuttingDown, id)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -502,6 +520,7 @@ type connectionEntry struct {
|
|||||||
type multiChannelNodeConn struct {
|
type multiChannelNodeConn struct {
|
||||||
id types.NodeID
|
id types.NodeID
|
||||||
mapper *mapper
|
mapper *mapper
|
||||||
|
log zerolog.Logger
|
||||||
|
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
connections []*connectionEntry
|
connections []*connectionEntry
|
||||||
@@ -518,8 +537,9 @@ type multiChannelNodeConn struct {
|
|||||||
// generateConnectionID generates a unique connection identifier.
|
// generateConnectionID generates a unique connection identifier.
|
||||||
func generateConnectionID() string {
|
func generateConnectionID() string {
|
||||||
bytes := make([]byte, 8)
|
bytes := make([]byte, 8)
|
||||||
rand.Read(bytes)
|
_, _ = rand.Read(bytes)
|
||||||
return fmt.Sprintf("%x", bytes)
|
|
||||||
|
return hex.EncodeToString(bytes)
|
||||||
}
|
}
|
||||||
|
|
||||||
// newMultiChannelNodeConn creates a new multi-channel node connection.
|
// newMultiChannelNodeConn creates a new multi-channel node connection.
|
||||||
@@ -528,6 +548,7 @@ func newMultiChannelNodeConn(id types.NodeID, mapper *mapper) *multiChannelNodeC
|
|||||||
id: id,
|
id: id,
|
||||||
mapper: mapper,
|
mapper: mapper,
|
||||||
lastSentPeers: xsync.NewMap[tailcfg.NodeID, struct{}](),
|
lastSentPeers: xsync.NewMap[tailcfg.NodeID, struct{}](),
|
||||||
|
log: log.With().Uint64(zf.NodeID, id.Uint64()).Logger(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -546,18 +567,21 @@ func (mc *multiChannelNodeConn) close() {
|
|||||||
// addConnection adds a new connection.
|
// addConnection adds a new connection.
|
||||||
func (mc *multiChannelNodeConn) addConnection(entry *connectionEntry) {
|
func (mc *multiChannelNodeConn) addConnection(entry *connectionEntry) {
|
||||||
mutexWaitStart := time.Now()
|
mutexWaitStart := time.Now()
|
||||||
log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", entry.c)).Str("conn.id", entry.id).
|
|
||||||
|
mc.log.Debug().Caller().Str(zf.Chan, fmt.Sprintf("%p", entry.c)).Str(zf.ConnID, entry.id).
|
||||||
Msg("addConnection: waiting for mutex - POTENTIAL CONTENTION POINT")
|
Msg("addConnection: waiting for mutex - POTENTIAL CONTENTION POINT")
|
||||||
|
|
||||||
mc.mutex.Lock()
|
mc.mutex.Lock()
|
||||||
|
|
||||||
mutexWaitDur := time.Since(mutexWaitStart)
|
mutexWaitDur := time.Since(mutexWaitStart)
|
||||||
|
|
||||||
defer mc.mutex.Unlock()
|
defer mc.mutex.Unlock()
|
||||||
|
|
||||||
mc.connections = append(mc.connections, entry)
|
mc.connections = append(mc.connections, entry)
|
||||||
log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", entry.c)).Str("conn.id", entry.id).
|
mc.log.Debug().Caller().Str(zf.Chan, fmt.Sprintf("%p", entry.c)).Str(zf.ConnID, entry.id).
|
||||||
Int("total_connections", len(mc.connections)).
|
Int("total_connections", len(mc.connections)).
|
||||||
Dur("mutex_wait_time", mutexWaitDur).
|
Dur("mutex_wait_time", mutexWaitDur).
|
||||||
Msg("Successfully added connection after mutex wait")
|
Msg("successfully added connection after mutex wait")
|
||||||
}
|
}
|
||||||
|
|
||||||
// removeConnectionByChannel removes a connection by matching channel pointer.
|
// removeConnectionByChannel removes a connection by matching channel pointer.
|
||||||
@@ -569,12 +593,14 @@ func (mc *multiChannelNodeConn) removeConnectionByChannel(c chan<- *tailcfg.MapR
|
|||||||
if entry.c == c {
|
if entry.c == c {
|
||||||
// Remove this connection
|
// Remove this connection
|
||||||
mc.connections = append(mc.connections[:i], mc.connections[i+1:]...)
|
mc.connections = append(mc.connections[:i], mc.connections[i+1:]...)
|
||||||
log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", c)).
|
mc.log.Debug().Caller().Str(zf.Chan, fmt.Sprintf("%p", c)).
|
||||||
Int("remaining_connections", len(mc.connections)).
|
Int("remaining_connections", len(mc.connections)).
|
||||||
Msg("Successfully removed connection")
|
Msg("successfully removed connection")
|
||||||
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -606,36 +632,41 @@ func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error {
|
|||||||
if len(mc.connections) == 0 {
|
if len(mc.connections) == 0 {
|
||||||
// During rapid reconnection, nodes may temporarily have no active connections
|
// During rapid reconnection, nodes may temporarily have no active connections
|
||||||
// This is not an error - the node will receive a full map when it reconnects
|
// This is not an error - the node will receive a full map when it reconnects
|
||||||
log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).
|
mc.log.Debug().Caller().
|
||||||
Msg("send: skipping send to node with no active connections (likely rapid reconnection)")
|
Msg("send: skipping send to node with no active connections (likely rapid reconnection)")
|
||||||
|
|
||||||
return nil // Return success instead of error
|
return nil // Return success instead of error
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).
|
mc.log.Debug().Caller().
|
||||||
Int("total_connections", len(mc.connections)).
|
Int("total_connections", len(mc.connections)).
|
||||||
Msg("send: broadcasting to all connections")
|
Msg("send: broadcasting to all connections")
|
||||||
|
|
||||||
var lastErr error
|
var lastErr error
|
||||||
|
|
||||||
successCount := 0
|
successCount := 0
|
||||||
|
|
||||||
var failedConnections []int // Track failed connections for removal
|
var failedConnections []int // Track failed connections for removal
|
||||||
|
|
||||||
// Send to all connections
|
// Send to all connections
|
||||||
for i, conn := range mc.connections {
|
for i, conn := range mc.connections {
|
||||||
log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", conn.c)).
|
mc.log.Debug().Caller().Str(zf.Chan, fmt.Sprintf("%p", conn.c)).
|
||||||
Str("conn.id", conn.id).Int("connection_index", i).
|
Str(zf.ConnID, conn.id).Int(zf.ConnectionIndex, i).
|
||||||
Msg("send: attempting to send to connection")
|
Msg("send: attempting to send to connection")
|
||||||
|
|
||||||
if err := conn.send(data); err != nil {
|
err := conn.send(data)
|
||||||
|
if err != nil {
|
||||||
lastErr = err
|
lastErr = err
|
||||||
|
|
||||||
failedConnections = append(failedConnections, i)
|
failedConnections = append(failedConnections, i)
|
||||||
log.Warn().Err(err).
|
mc.log.Warn().Err(err).Str(zf.Chan, fmt.Sprintf("%p", conn.c)).
|
||||||
Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", conn.c)).
|
Str(zf.ConnID, conn.id).Int(zf.ConnectionIndex, i).
|
||||||
Str("conn.id", conn.id).Int("connection_index", i).
|
|
||||||
Msg("send: connection send failed")
|
Msg("send: connection send failed")
|
||||||
} else {
|
} else {
|
||||||
successCount++
|
successCount++
|
||||||
log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", conn.c)).
|
|
||||||
Str("conn.id", conn.id).Int("connection_index", i).
|
mc.log.Debug().Caller().Str(zf.Chan, fmt.Sprintf("%p", conn.c)).
|
||||||
|
Str(zf.ConnID, conn.id).Int(zf.ConnectionIndex, i).
|
||||||
Msg("send: successfully sent to connection")
|
Msg("send: successfully sent to connection")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -643,15 +674,15 @@ func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error {
|
|||||||
// Remove failed connections (in reverse order to maintain indices)
|
// Remove failed connections (in reverse order to maintain indices)
|
||||||
for i := len(failedConnections) - 1; i >= 0; i-- {
|
for i := len(failedConnections) - 1; i >= 0; i-- {
|
||||||
idx := failedConnections[i]
|
idx := failedConnections[i]
|
||||||
log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).
|
mc.log.Debug().Caller().
|
||||||
Str("conn.id", mc.connections[idx].id).
|
Str(zf.ConnID, mc.connections[idx].id).
|
||||||
Msg("send: removing failed connection")
|
Msg("send: removing failed connection")
|
||||||
mc.connections = append(mc.connections[:idx], mc.connections[idx+1:]...)
|
mc.connections = append(mc.connections[:idx], mc.connections[idx+1:]...)
|
||||||
}
|
}
|
||||||
|
|
||||||
mc.updateCount.Add(1)
|
mc.updateCount.Add(1)
|
||||||
|
|
||||||
log.Debug().Uint64("node.id", mc.id.Uint64()).
|
mc.log.Debug().
|
||||||
Int("successful_sends", successCount).
|
Int("successful_sends", successCount).
|
||||||
Int("failed_connections", len(failedConnections)).
|
Int("failed_connections", len(failedConnections)).
|
||||||
Int("remaining_connections", len(mc.connections)).
|
Int("remaining_connections", len(mc.connections)).
|
||||||
@@ -688,7 +719,7 @@ func (entry *connectionEntry) send(data *tailcfg.MapResponse) error {
|
|||||||
case <-time.After(50 * time.Millisecond):
|
case <-time.After(50 * time.Millisecond):
|
||||||
// Connection is likely stale - client isn't reading from channel
|
// Connection is likely stale - client isn't reading from channel
|
||||||
// This catches the case where Docker containers are killed but channels remain open
|
// This catches the case where Docker containers are killed but channels remain open
|
||||||
return fmt.Errorf("connection %s: timeout sending to channel (likely stale connection)", entry.id)
|
return fmt.Errorf("connection %s: %w", entry.id, ErrConnectionSendTimeout)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -798,6 +829,7 @@ func (b *LockFreeBatcher) Debug() map[types.NodeID]DebugNodeInfo {
|
|||||||
Connected: connected,
|
Connected: connected,
|
||||||
ActiveConnections: activeConnCount,
|
ActiveConnections: activeConnCount,
|
||||||
}
|
}
|
||||||
|
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -812,6 +844,7 @@ func (b *LockFreeBatcher) Debug() map[types.NodeID]DebugNodeInfo {
|
|||||||
ActiveConnections: 0,
|
ActiveConnections: 0,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ type batcherTestCase struct {
|
|||||||
// that would normally be sent by poll.go in production.
|
// that would normally be sent by poll.go in production.
|
||||||
type testBatcherWrapper struct {
|
type testBatcherWrapper struct {
|
||||||
Batcher
|
Batcher
|
||||||
|
|
||||||
state *state.State
|
state *state.State
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -80,12 +81,7 @@ func (t *testBatcherWrapper) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapRe
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Finally remove from the real batcher
|
// Finally remove from the real batcher
|
||||||
removed := t.Batcher.RemoveNode(id, c)
|
return t.Batcher.RemoveNode(id, c)
|
||||||
if !removed {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// wrapBatcherForTest wraps a batcher with test-specific behavior.
|
// wrapBatcherForTest wraps a batcher with test-specific behavior.
|
||||||
@@ -129,8 +125,6 @@ const (
|
|||||||
SMALL_BUFFER_SIZE = 3
|
SMALL_BUFFER_SIZE = 3
|
||||||
TINY_BUFFER_SIZE = 1 // For maximum contention
|
TINY_BUFFER_SIZE = 1 // For maximum contention
|
||||||
LARGE_BUFFER_SIZE = 200
|
LARGE_BUFFER_SIZE = 200
|
||||||
|
|
||||||
reservedResponseHeaderSize = 4
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// TestData contains all test entities created for a test scenario.
|
// TestData contains all test entities created for a test scenario.
|
||||||
@@ -241,8 +235,8 @@ func setupBatcherWithTestData(
|
|||||||
}
|
}
|
||||||
|
|
||||||
derpMap, err := derp.GetDERPMap(cfg.DERP)
|
derpMap, err := derp.GetDERPMap(cfg.DERP)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.NotNil(t, derpMap)
|
require.NotNil(t, derpMap)
|
||||||
|
|
||||||
state.SetDERPMap(derpMap)
|
state.SetDERPMap(derpMap)
|
||||||
|
|
||||||
@@ -319,6 +313,8 @@ func (ut *updateTracker) recordUpdate(nodeID types.NodeID, updateSize int) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// getStats returns a copy of the statistics for a node.
|
// getStats returns a copy of the statistics for a node.
|
||||||
|
//
|
||||||
|
//nolint:unused
|
||||||
func (ut *updateTracker) getStats(nodeID types.NodeID) UpdateStats {
|
func (ut *updateTracker) getStats(nodeID types.NodeID) UpdateStats {
|
||||||
ut.mu.RLock()
|
ut.mu.RLock()
|
||||||
defer ut.mu.RUnlock()
|
defer ut.mu.RUnlock()
|
||||||
@@ -386,16 +382,14 @@ type UpdateInfo struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// parseUpdateAndAnalyze parses an update and returns detailed information.
|
// parseUpdateAndAnalyze parses an update and returns detailed information.
|
||||||
func parseUpdateAndAnalyze(resp *tailcfg.MapResponse) (UpdateInfo, error) {
|
func parseUpdateAndAnalyze(resp *tailcfg.MapResponse) UpdateInfo {
|
||||||
info := UpdateInfo{
|
return UpdateInfo{
|
||||||
PeerCount: len(resp.Peers),
|
PeerCount: len(resp.Peers),
|
||||||
PatchCount: len(resp.PeersChangedPatch),
|
PatchCount: len(resp.PeersChangedPatch),
|
||||||
IsFull: len(resp.Peers) > 0,
|
IsFull: len(resp.Peers) > 0,
|
||||||
IsPatch: len(resp.PeersChangedPatch) > 0,
|
IsPatch: len(resp.PeersChangedPatch) > 0,
|
||||||
IsDERP: resp.DERPMap != nil,
|
IsDERP: resp.DERPMap != nil,
|
||||||
}
|
}
|
||||||
|
|
||||||
return info, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// start begins consuming updates from the node's channel and tracking stats.
|
// start begins consuming updates from the node's channel and tracking stats.
|
||||||
@@ -417,7 +411,8 @@ func (n *node) start() {
|
|||||||
atomic.AddInt64(&n.updateCount, 1)
|
atomic.AddInt64(&n.updateCount, 1)
|
||||||
|
|
||||||
// Parse update and track detailed stats
|
// Parse update and track detailed stats
|
||||||
if info, err := parseUpdateAndAnalyze(data); err == nil {
|
info := parseUpdateAndAnalyze(data)
|
||||||
|
{
|
||||||
// Track update types
|
// Track update types
|
||||||
if info.IsFull {
|
if info.IsFull {
|
||||||
atomic.AddInt64(&n.fullCount, 1)
|
atomic.AddInt64(&n.fullCount, 1)
|
||||||
@@ -548,7 +543,7 @@ func TestEnhancedTrackingWithBatcher(t *testing.T) {
|
|||||||
testNode.start()
|
testNode.start()
|
||||||
|
|
||||||
// Connect the node to the batcher
|
// Connect the node to the batcher
|
||||||
batcher.AddNode(testNode.n.ID, testNode.ch, tailcfg.CapabilityVersion(100))
|
_ = batcher.AddNode(testNode.n.ID, testNode.ch, tailcfg.CapabilityVersion(100))
|
||||||
|
|
||||||
// Wait for connection to be established
|
// Wait for connection to be established
|
||||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||||
@@ -657,7 +652,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
|
|||||||
|
|
||||||
for i := range allNodes {
|
for i := range allNodes {
|
||||||
node := &allNodes[i]
|
node := &allNodes[i]
|
||||||
batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
|
_ = batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
|
||||||
|
|
||||||
// Issue full update after each join to ensure connectivity
|
// Issue full update after each join to ensure connectivity
|
||||||
batcher.AddWork(change.FullUpdate())
|
batcher.AddWork(change.FullUpdate())
|
||||||
@@ -676,6 +671,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
|
|||||||
|
|
||||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||||
connectedCount := 0
|
connectedCount := 0
|
||||||
|
|
||||||
for i := range allNodes {
|
for i := range allNodes {
|
||||||
node := &allNodes[i]
|
node := &allNodes[i]
|
||||||
|
|
||||||
@@ -693,6 +689,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
|
|||||||
}, 5*time.Minute, 5*time.Second, "waiting for full connectivity")
|
}, 5*time.Minute, 5*time.Second, "waiting for full connectivity")
|
||||||
|
|
||||||
t.Logf("✅ All nodes achieved full connectivity!")
|
t.Logf("✅ All nodes achieved full connectivity!")
|
||||||
|
|
||||||
totalTime := time.Since(startTime)
|
totalTime := time.Since(startTime)
|
||||||
|
|
||||||
// Disconnect all nodes
|
// Disconnect all nodes
|
||||||
@@ -820,11 +817,11 @@ func TestBatcherBasicOperations(t *testing.T) {
|
|||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
batcher := testData.Batcher
|
batcher := testData.Batcher
|
||||||
tn := testData.Nodes[0]
|
tn := &testData.Nodes[0]
|
||||||
tn2 := testData.Nodes[1]
|
tn2 := &testData.Nodes[1]
|
||||||
|
|
||||||
// Test AddNode with real node ID
|
// Test AddNode with real node ID
|
||||||
batcher.AddNode(tn.n.ID, tn.ch, 100)
|
_ = batcher.AddNode(tn.n.ID, tn.ch, 100)
|
||||||
|
|
||||||
if !batcher.IsConnected(tn.n.ID) {
|
if !batcher.IsConnected(tn.n.ID) {
|
||||||
t.Error("Node should be connected after AddNode")
|
t.Error("Node should be connected after AddNode")
|
||||||
@@ -842,10 +839,10 @@ func TestBatcherBasicOperations(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Drain any initial messages from first node
|
// Drain any initial messages from first node
|
||||||
drainChannelTimeout(tn.ch, "first node before second", 100*time.Millisecond)
|
drainChannelTimeout(tn.ch, 100*time.Millisecond)
|
||||||
|
|
||||||
// Add the second node and verify update message
|
// Add the second node and verify update message
|
||||||
batcher.AddNode(tn2.n.ID, tn2.ch, 100)
|
_ = batcher.AddNode(tn2.n.ID, tn2.ch, 100)
|
||||||
assert.True(t, batcher.IsConnected(tn2.n.ID))
|
assert.True(t, batcher.IsConnected(tn2.n.ID))
|
||||||
|
|
||||||
// First node should get an update that second node has connected.
|
// First node should get an update that second node has connected.
|
||||||
@@ -911,18 +908,14 @@ func TestBatcherBasicOperations(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func drainChannelTimeout(ch <-chan *tailcfg.MapResponse, name string, timeout time.Duration) {
|
func drainChannelTimeout(ch <-chan *tailcfg.MapResponse, timeout time.Duration) {
|
||||||
count := 0
|
|
||||||
|
|
||||||
timer := time.NewTimer(timeout)
|
timer := time.NewTimer(timeout)
|
||||||
defer timer.Stop()
|
defer timer.Stop()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case data := <-ch:
|
case <-ch:
|
||||||
count++
|
// Drain message
|
||||||
// Optional: add debug output if needed
|
|
||||||
_ = data
|
|
||||||
case <-timer.C:
|
case <-timer.C:
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -1050,7 +1043,7 @@ func TestBatcherWorkQueueBatching(t *testing.T) {
|
|||||||
testNodes := testData.Nodes
|
testNodes := testData.Nodes
|
||||||
|
|
||||||
ch := make(chan *tailcfg.MapResponse, 10)
|
ch := make(chan *tailcfg.MapResponse, 10)
|
||||||
batcher.AddNode(testNodes[0].n.ID, ch, tailcfg.CapabilityVersion(100))
|
_ = batcher.AddNode(testNodes[0].n.ID, ch, tailcfg.CapabilityVersion(100))
|
||||||
|
|
||||||
// Track update content for validation
|
// Track update content for validation
|
||||||
var receivedUpdates []*tailcfg.MapResponse
|
var receivedUpdates []*tailcfg.MapResponse
|
||||||
@@ -1131,6 +1124,8 @@ func TestBatcherWorkQueueBatching(t *testing.T) {
|
|||||||
// even when real node updates are being processed, ensuring no race conditions
|
// even when real node updates are being processed, ensuring no race conditions
|
||||||
// occur during channel replacement with actual workload.
|
// occur during channel replacement with actual workload.
|
||||||
func XTestBatcherChannelClosingRace(t *testing.T) {
|
func XTestBatcherChannelClosingRace(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
for _, batcherFunc := range allBatcherFunctions {
|
for _, batcherFunc := range allBatcherFunctions {
|
||||||
t.Run(batcherFunc.name, func(t *testing.T) {
|
t.Run(batcherFunc.name, func(t *testing.T) {
|
||||||
// Create test environment with real database and nodes
|
// Create test environment with real database and nodes
|
||||||
@@ -1138,7 +1133,7 @@ func XTestBatcherChannelClosingRace(t *testing.T) {
|
|||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
batcher := testData.Batcher
|
batcher := testData.Batcher
|
||||||
testNode := testData.Nodes[0]
|
testNode := &testData.Nodes[0]
|
||||||
|
|
||||||
var (
|
var (
|
||||||
channelIssues int
|
channelIssues int
|
||||||
@@ -1154,7 +1149,7 @@ func XTestBatcherChannelClosingRace(t *testing.T) {
|
|||||||
ch1 := make(chan *tailcfg.MapResponse, 1)
|
ch1 := make(chan *tailcfg.MapResponse, 1)
|
||||||
|
|
||||||
wg.Go(func() {
|
wg.Go(func() {
|
||||||
batcher.AddNode(testNode.n.ID, ch1, tailcfg.CapabilityVersion(100))
|
_ = batcher.AddNode(testNode.n.ID, ch1, tailcfg.CapabilityVersion(100))
|
||||||
})
|
})
|
||||||
|
|
||||||
// Add real work during connection chaos
|
// Add real work during connection chaos
|
||||||
@@ -1167,7 +1162,8 @@ func XTestBatcherChannelClosingRace(t *testing.T) {
|
|||||||
|
|
||||||
wg.Go(func() {
|
wg.Go(func() {
|
||||||
runtime.Gosched() // Yield to introduce timing variability
|
runtime.Gosched() // Yield to introduce timing variability
|
||||||
batcher.AddNode(testNode.n.ID, ch2, tailcfg.CapabilityVersion(100))
|
|
||||||
|
_ = batcher.AddNode(testNode.n.ID, ch2, tailcfg.CapabilityVersion(100))
|
||||||
})
|
})
|
||||||
|
|
||||||
// Remove second connection
|
// Remove second connection
|
||||||
@@ -1231,7 +1227,7 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
|
|||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
batcher := testData.Batcher
|
batcher := testData.Batcher
|
||||||
testNode := testData.Nodes[0]
|
testNode := &testData.Nodes[0]
|
||||||
|
|
||||||
var (
|
var (
|
||||||
panics int
|
panics int
|
||||||
@@ -1258,7 +1254,7 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
|
|||||||
ch := make(chan *tailcfg.MapResponse, 5)
|
ch := make(chan *tailcfg.MapResponse, 5)
|
||||||
|
|
||||||
// Add node and immediately queue real work
|
// Add node and immediately queue real work
|
||||||
batcher.AddNode(testNode.n.ID, ch, tailcfg.CapabilityVersion(100))
|
_ = batcher.AddNode(testNode.n.ID, ch, tailcfg.CapabilityVersion(100))
|
||||||
batcher.AddWork(change.DERPMap())
|
batcher.AddWork(change.DERPMap())
|
||||||
|
|
||||||
// Consumer goroutine to validate data and detect channel issues
|
// Consumer goroutine to validate data and detect channel issues
|
||||||
@@ -1308,6 +1304,7 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
|
|||||||
for range i % 3 {
|
for range i % 3 {
|
||||||
runtime.Gosched() // Introduce timing variability
|
runtime.Gosched() // Introduce timing variability
|
||||||
}
|
}
|
||||||
|
|
||||||
batcher.RemoveNode(testNode.n.ID, ch)
|
batcher.RemoveNode(testNode.n.ID, ch)
|
||||||
|
|
||||||
// Yield to allow workers to process and close channels
|
// Yield to allow workers to process and close channels
|
||||||
@@ -1350,6 +1347,8 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
|
|||||||
// real node data. The test validates that stable clients continue to function
|
// real node data. The test validates that stable clients continue to function
|
||||||
// normally and receive proper updates despite the connection churn from other clients,
|
// normally and receive proper updates despite the connection churn from other clients,
|
||||||
// ensuring system stability under concurrent load.
|
// ensuring system stability under concurrent load.
|
||||||
|
//
|
||||||
|
//nolint:gocyclo // complex concurrent test scenario
|
||||||
func TestBatcherConcurrentClients(t *testing.T) {
|
func TestBatcherConcurrentClients(t *testing.T) {
|
||||||
if testing.Short() {
|
if testing.Short() {
|
||||||
t.Skip("Skipping concurrent client test in short mode")
|
t.Skip("Skipping concurrent client test in short mode")
|
||||||
@@ -1377,10 +1376,11 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
|||||||
stableNodes := allNodes[:len(allNodes)/2] // Use first half as stable
|
stableNodes := allNodes[:len(allNodes)/2] // Use first half as stable
|
||||||
stableChannels := make(map[types.NodeID]chan *tailcfg.MapResponse)
|
stableChannels := make(map[types.NodeID]chan *tailcfg.MapResponse)
|
||||||
|
|
||||||
for _, node := range stableNodes {
|
for i := range stableNodes {
|
||||||
|
node := &stableNodes[i]
|
||||||
ch := make(chan *tailcfg.MapResponse, NORMAL_BUFFER_SIZE)
|
ch := make(chan *tailcfg.MapResponse, NORMAL_BUFFER_SIZE)
|
||||||
stableChannels[node.n.ID] = ch
|
stableChannels[node.n.ID] = ch
|
||||||
batcher.AddNode(node.n.ID, ch, tailcfg.CapabilityVersion(100))
|
_ = batcher.AddNode(node.n.ID, ch, tailcfg.CapabilityVersion(100))
|
||||||
|
|
||||||
// Monitor updates for each stable client
|
// Monitor updates for each stable client
|
||||||
go func(nodeID types.NodeID, channel chan *tailcfg.MapResponse) {
|
go func(nodeID types.NodeID, channel chan *tailcfg.MapResponse) {
|
||||||
@@ -1391,6 +1391,7 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
|||||||
// Channel was closed, exit gracefully
|
// Channel was closed, exit gracefully
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if valid, reason := validateUpdateContent(data); valid {
|
if valid, reason := validateUpdateContent(data); valid {
|
||||||
tracker.recordUpdate(
|
tracker.recordUpdate(
|
||||||
nodeID,
|
nodeID,
|
||||||
@@ -1427,7 +1428,9 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
|||||||
|
|
||||||
// Connection churn cycles - rapidly connect/disconnect to test concurrency safety
|
// Connection churn cycles - rapidly connect/disconnect to test concurrency safety
|
||||||
for i := range numCycles {
|
for i := range numCycles {
|
||||||
for _, node := range churningNodes {
|
for j := range churningNodes {
|
||||||
|
node := &churningNodes[j]
|
||||||
|
|
||||||
wg.Add(2)
|
wg.Add(2)
|
||||||
|
|
||||||
// Connect churning node
|
// Connect churning node
|
||||||
@@ -1448,10 +1451,12 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
|||||||
ch := make(chan *tailcfg.MapResponse, SMALL_BUFFER_SIZE)
|
ch := make(chan *tailcfg.MapResponse, SMALL_BUFFER_SIZE)
|
||||||
|
|
||||||
churningChannelsMutex.Lock()
|
churningChannelsMutex.Lock()
|
||||||
|
|
||||||
churningChannels[nodeID] = ch
|
churningChannels[nodeID] = ch
|
||||||
|
|
||||||
churningChannelsMutex.Unlock()
|
churningChannelsMutex.Unlock()
|
||||||
|
|
||||||
batcher.AddNode(nodeID, ch, tailcfg.CapabilityVersion(100))
|
_ = batcher.AddNode(nodeID, ch, tailcfg.CapabilityVersion(100))
|
||||||
|
|
||||||
// Consume updates to prevent blocking
|
// Consume updates to prevent blocking
|
||||||
go func() {
|
go func() {
|
||||||
@@ -1462,6 +1467,7 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
|||||||
// Channel was closed, exit gracefully
|
// Channel was closed, exit gracefully
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if valid, _ := validateUpdateContent(data); valid {
|
if valid, _ := validateUpdateContent(data); valid {
|
||||||
tracker.recordUpdate(
|
tracker.recordUpdate(
|
||||||
nodeID,
|
nodeID,
|
||||||
@@ -1494,6 +1500,7 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
|||||||
for range i % 5 {
|
for range i % 5 {
|
||||||
runtime.Gosched() // Introduce timing variability
|
runtime.Gosched() // Introduce timing variability
|
||||||
}
|
}
|
||||||
|
|
||||||
churningChannelsMutex.Lock()
|
churningChannelsMutex.Lock()
|
||||||
|
|
||||||
ch, exists := churningChannels[nodeID]
|
ch, exists := churningChannels[nodeID]
|
||||||
@@ -1519,7 +1526,7 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
|||||||
|
|
||||||
if i%7 == 0 && len(allNodes) > 0 {
|
if i%7 == 0 && len(allNodes) > 0 {
|
||||||
// Node-specific changes using real nodes
|
// Node-specific changes using real nodes
|
||||||
node := allNodes[i%len(allNodes)]
|
node := &allNodes[i%len(allNodes)]
|
||||||
// Use a valid expiry time for testing since test nodes don't have expiry set
|
// Use a valid expiry time for testing since test nodes don't have expiry set
|
||||||
testExpiry := time.Now().Add(24 * time.Hour)
|
testExpiry := time.Now().Add(24 * time.Hour)
|
||||||
batcher.AddWork(change.KeyExpiryFor(node.n.ID, testExpiry))
|
batcher.AddWork(change.KeyExpiryFor(node.n.ID, testExpiry))
|
||||||
@@ -1567,7 +1574,8 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
|||||||
t.Logf("Work generated: %d DERP + %d Full + %d KeyExpiry = %d total AddWork calls",
|
t.Logf("Work generated: %d DERP + %d Full + %d KeyExpiry = %d total AddWork calls",
|
||||||
expectedDerpUpdates, expectedFullUpdates, expectedKeyUpdates, totalGeneratedWork)
|
expectedDerpUpdates, expectedFullUpdates, expectedKeyUpdates, totalGeneratedWork)
|
||||||
|
|
||||||
for _, node := range stableNodes {
|
for i := range stableNodes {
|
||||||
|
node := &stableNodes[i]
|
||||||
if stats, exists := allStats[node.n.ID]; exists {
|
if stats, exists := allStats[node.n.ID]; exists {
|
||||||
stableUpdateCount += stats.TotalUpdates
|
stableUpdateCount += stats.TotalUpdates
|
||||||
t.Logf("Stable node %d: %d updates",
|
t.Logf("Stable node %d: %d updates",
|
||||||
@@ -1580,7 +1588,8 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, node := range churningNodes {
|
for i := range churningNodes {
|
||||||
|
node := &churningNodes[i]
|
||||||
if stats, exists := allStats[node.n.ID]; exists {
|
if stats, exists := allStats[node.n.ID]; exists {
|
||||||
churningUpdateCount += stats.TotalUpdates
|
churningUpdateCount += stats.TotalUpdates
|
||||||
}
|
}
|
||||||
@@ -1605,7 +1614,8 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Verify all stable clients are still functional
|
// Verify all stable clients are still functional
|
||||||
for _, node := range stableNodes {
|
for i := range stableNodes {
|
||||||
|
node := &stableNodes[i]
|
||||||
if !batcher.IsConnected(node.n.ID) {
|
if !batcher.IsConnected(node.n.ID) {
|
||||||
t.Errorf("Stable node %d lost connection during racing", node.n.ID)
|
t.Errorf("Stable node %d lost connection during racing", node.n.ID)
|
||||||
}
|
}
|
||||||
@@ -1623,6 +1633,8 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
|||||||
// It validates that the system remains stable with no deadlocks, panics, or
|
// It validates that the system remains stable with no deadlocks, panics, or
|
||||||
// missed updates under sustained high load. The test uses real node data to
|
// missed updates under sustained high load. The test uses real node data to
|
||||||
// generate authentic update scenarios and tracks comprehensive statistics.
|
// generate authentic update scenarios and tracks comprehensive statistics.
|
||||||
|
//
|
||||||
|
//nolint:gocyclo,thelper // complex scalability test scenario
|
||||||
func XTestBatcherScalability(t *testing.T) {
|
func XTestBatcherScalability(t *testing.T) {
|
||||||
if testing.Short() {
|
if testing.Short() {
|
||||||
t.Skip("Skipping scalability test in short mode")
|
t.Skip("Skipping scalability test in short mode")
|
||||||
@@ -1651,7 +1663,7 @@ func XTestBatcherScalability(t *testing.T) {
|
|||||||
description string
|
description string
|
||||||
}
|
}
|
||||||
|
|
||||||
var testCases []testCase
|
testCases := make([]testCase, 0, len(chaosTypes)*len(bufferSizes)*len(cycles)*len(nodes))
|
||||||
|
|
||||||
// Generate all combinations of the test matrix
|
// Generate all combinations of the test matrix
|
||||||
for _, nodeCount := range nodes {
|
for _, nodeCount := range nodes {
|
||||||
@@ -1762,7 +1774,8 @@ func XTestBatcherScalability(t *testing.T) {
|
|||||||
|
|
||||||
for i := range testNodes {
|
for i := range testNodes {
|
||||||
node := &testNodes[i]
|
node := &testNodes[i]
|
||||||
batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
|
_ = batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
|
||||||
|
|
||||||
connectedNodesMutex.Lock()
|
connectedNodesMutex.Lock()
|
||||||
|
|
||||||
connectedNodes[node.n.ID] = true
|
connectedNodes[node.n.ID] = true
|
||||||
@@ -1824,7 +1837,8 @@ func XTestBatcherScalability(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Connection/disconnection cycles for subset of nodes
|
// Connection/disconnection cycles for subset of nodes
|
||||||
for i, node := range chaosNodes {
|
for i := range chaosNodes {
|
||||||
|
node := &chaosNodes[i]
|
||||||
// Only add work if this is connection chaos or mixed
|
// Only add work if this is connection chaos or mixed
|
||||||
if tc.chaosType == "connection" || tc.chaosType == "mixed" {
|
if tc.chaosType == "connection" || tc.chaosType == "mixed" {
|
||||||
wg.Add(2)
|
wg.Add(2)
|
||||||
@@ -1878,6 +1892,7 @@ func XTestBatcherScalability(t *testing.T) {
|
|||||||
channel,
|
channel,
|
||||||
tailcfg.CapabilityVersion(100),
|
tailcfg.CapabilityVersion(100),
|
||||||
)
|
)
|
||||||
|
|
||||||
connectedNodesMutex.Lock()
|
connectedNodesMutex.Lock()
|
||||||
|
|
||||||
connectedNodes[nodeID] = true
|
connectedNodes[nodeID] = true
|
||||||
@@ -2138,8 +2153,9 @@ func TestBatcherFullPeerUpdates(t *testing.T) {
|
|||||||
t.Logf("Created %d nodes in database", len(allNodes))
|
t.Logf("Created %d nodes in database", len(allNodes))
|
||||||
|
|
||||||
// Connect nodes one at a time and wait for each to be connected
|
// Connect nodes one at a time and wait for each to be connected
|
||||||
for i, node := range allNodes {
|
for i := range allNodes {
|
||||||
batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
|
node := &allNodes[i]
|
||||||
|
_ = batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
|
||||||
t.Logf("Connected node %d (ID: %d)", i, node.n.ID)
|
t.Logf("Connected node %d (ID: %d)", i, node.n.ID)
|
||||||
|
|
||||||
// Wait for node to be connected
|
// Wait for node to be connected
|
||||||
@@ -2157,7 +2173,8 @@ func TestBatcherFullPeerUpdates(t *testing.T) {
|
|||||||
}, 5*time.Second, 50*time.Millisecond, "waiting for all nodes to connect")
|
}, 5*time.Second, 50*time.Millisecond, "waiting for all nodes to connect")
|
||||||
|
|
||||||
// Check how many peers each node should see
|
// Check how many peers each node should see
|
||||||
for i, node := range allNodes {
|
for i := range allNodes {
|
||||||
|
node := &allNodes[i]
|
||||||
peers := testData.State.ListPeers(node.n.ID)
|
peers := testData.State.ListPeers(node.n.ID)
|
||||||
t.Logf("Node %d should see %d peers from state", i, peers.Len())
|
t.Logf("Node %d should see %d peers from state", i, peers.Len())
|
||||||
}
|
}
|
||||||
@@ -2286,7 +2303,10 @@ func TestBatcherRapidReconnection(t *testing.T) {
|
|||||||
|
|
||||||
// Phase 1: Connect all nodes initially
|
// Phase 1: Connect all nodes initially
|
||||||
t.Logf("Phase 1: Connecting all nodes...")
|
t.Logf("Phase 1: Connecting all nodes...")
|
||||||
for i, node := range allNodes {
|
|
||||||
|
for i := range allNodes {
|
||||||
|
node := &allNodes[i]
|
||||||
|
|
||||||
err := batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
|
err := batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to add node %d: %v", i, err)
|
t.Fatalf("Failed to add node %d: %v", i, err)
|
||||||
@@ -2302,16 +2322,21 @@ func TestBatcherRapidReconnection(t *testing.T) {
|
|||||||
|
|
||||||
// Phase 2: Rapid disconnect ALL nodes (simulating nodes going down)
|
// Phase 2: Rapid disconnect ALL nodes (simulating nodes going down)
|
||||||
t.Logf("Phase 2: Rapid disconnect all nodes...")
|
t.Logf("Phase 2: Rapid disconnect all nodes...")
|
||||||
for i, node := range allNodes {
|
|
||||||
|
for i := range allNodes {
|
||||||
|
node := &allNodes[i]
|
||||||
removed := batcher.RemoveNode(node.n.ID, node.ch)
|
removed := batcher.RemoveNode(node.n.ID, node.ch)
|
||||||
t.Logf("Node %d RemoveNode result: %t", i, removed)
|
t.Logf("Node %d RemoveNode result: %t", i, removed)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Phase 3: Rapid reconnect with NEW channels (simulating nodes coming back up)
|
// Phase 3: Rapid reconnect with NEW channels (simulating nodes coming back up)
|
||||||
t.Logf("Phase 3: Rapid reconnect with new channels...")
|
t.Logf("Phase 3: Rapid reconnect with new channels...")
|
||||||
|
|
||||||
newChannels := make([]chan *tailcfg.MapResponse, len(allNodes))
|
newChannels := make([]chan *tailcfg.MapResponse, len(allNodes))
|
||||||
for i, node := range allNodes {
|
for i := range allNodes {
|
||||||
|
node := &allNodes[i]
|
||||||
newChannels[i] = make(chan *tailcfg.MapResponse, 10)
|
newChannels[i] = make(chan *tailcfg.MapResponse, 10)
|
||||||
|
|
||||||
err := batcher.AddNode(node.n.ID, newChannels[i], tailcfg.CapabilityVersion(100))
|
err := batcher.AddNode(node.n.ID, newChannels[i], tailcfg.CapabilityVersion(100))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Failed to reconnect node %d: %v", i, err)
|
t.Errorf("Failed to reconnect node %d: %v", i, err)
|
||||||
@@ -2334,7 +2359,8 @@ func TestBatcherRapidReconnection(t *testing.T) {
|
|||||||
debugInfo := debugBatcher.Debug()
|
debugInfo := debugBatcher.Debug()
|
||||||
disconnectedCount := 0
|
disconnectedCount := 0
|
||||||
|
|
||||||
for i, node := range allNodes {
|
for i := range allNodes {
|
||||||
|
node := &allNodes[i]
|
||||||
if info, exists := debugInfo[node.n.ID]; exists {
|
if info, exists := debugInfo[node.n.ID]; exists {
|
||||||
t.Logf("Node %d (ID %d): debug info = %+v", i, node.n.ID, info)
|
t.Logf("Node %d (ID %d): debug info = %+v", i, node.n.ID, info)
|
||||||
|
|
||||||
@@ -2342,11 +2368,13 @@ func TestBatcherRapidReconnection(t *testing.T) {
|
|||||||
if infoMap, ok := info.(map[string]any); ok {
|
if infoMap, ok := info.(map[string]any); ok {
|
||||||
if connected, ok := infoMap["connected"].(bool); ok && !connected {
|
if connected, ok := infoMap["connected"].(bool); ok && !connected {
|
||||||
disconnectedCount++
|
disconnectedCount++
|
||||||
|
|
||||||
t.Logf("BUG REPRODUCED: Node %d shows as disconnected in debug but should be connected", i)
|
t.Logf("BUG REPRODUCED: Node %d shows as disconnected in debug but should be connected", i)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
disconnectedCount++
|
disconnectedCount++
|
||||||
|
|
||||||
t.Logf("Node %d missing from debug info entirely", i)
|
t.Logf("Node %d missing from debug info entirely", i)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2381,6 +2409,7 @@ func TestBatcherRapidReconnection(t *testing.T) {
|
|||||||
case update := <-newChannels[i]:
|
case update := <-newChannels[i]:
|
||||||
if update != nil {
|
if update != nil {
|
||||||
receivedCount++
|
receivedCount++
|
||||||
|
|
||||||
t.Logf("Node %d received update successfully", i)
|
t.Logf("Node %d received update successfully", i)
|
||||||
}
|
}
|
||||||
case <-timeout:
|
case <-timeout:
|
||||||
@@ -2399,6 +2428,7 @@ func TestBatcherRapidReconnection(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//nolint:gocyclo // complex multi-connection test scenario
|
||||||
func TestBatcherMultiConnection(t *testing.T) {
|
func TestBatcherMultiConnection(t *testing.T) {
|
||||||
for _, batcherFunc := range allBatcherFunctions {
|
for _, batcherFunc := range allBatcherFunctions {
|
||||||
t.Run(batcherFunc.name, func(t *testing.T) {
|
t.Run(batcherFunc.name, func(t *testing.T) {
|
||||||
@@ -2406,13 +2436,14 @@ func TestBatcherMultiConnection(t *testing.T) {
|
|||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
batcher := testData.Batcher
|
batcher := testData.Batcher
|
||||||
node1 := testData.Nodes[0]
|
node1 := &testData.Nodes[0]
|
||||||
node2 := testData.Nodes[1]
|
node2 := &testData.Nodes[1]
|
||||||
|
|
||||||
t.Logf("=== MULTI-CONNECTION TEST ===")
|
t.Logf("=== MULTI-CONNECTION TEST ===")
|
||||||
|
|
||||||
// Phase 1: Connect first node with initial connection
|
// Phase 1: Connect first node with initial connection
|
||||||
t.Logf("Phase 1: Connecting node 1 with first connection...")
|
t.Logf("Phase 1: Connecting node 1 with first connection...")
|
||||||
|
|
||||||
err := batcher.AddNode(node1.n.ID, node1.ch, tailcfg.CapabilityVersion(100))
|
err := batcher.AddNode(node1.n.ID, node1.ch, tailcfg.CapabilityVersion(100))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to add node1: %v", err)
|
t.Fatalf("Failed to add node1: %v", err)
|
||||||
@@ -2432,7 +2463,9 @@ func TestBatcherMultiConnection(t *testing.T) {
|
|||||||
|
|
||||||
// Phase 2: Add second connection for node1 (multi-connection scenario)
|
// Phase 2: Add second connection for node1 (multi-connection scenario)
|
||||||
t.Logf("Phase 2: Adding second connection for node 1...")
|
t.Logf("Phase 2: Adding second connection for node 1...")
|
||||||
|
|
||||||
secondChannel := make(chan *tailcfg.MapResponse, 10)
|
secondChannel := make(chan *tailcfg.MapResponse, 10)
|
||||||
|
|
||||||
err = batcher.AddNode(node1.n.ID, secondChannel, tailcfg.CapabilityVersion(100))
|
err = batcher.AddNode(node1.n.ID, secondChannel, tailcfg.CapabilityVersion(100))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to add second connection for node1: %v", err)
|
t.Fatalf("Failed to add second connection for node1: %v", err)
|
||||||
@@ -2443,7 +2476,9 @@ func TestBatcherMultiConnection(t *testing.T) {
|
|||||||
|
|
||||||
// Phase 3: Add third connection for node1
|
// Phase 3: Add third connection for node1
|
||||||
t.Logf("Phase 3: Adding third connection for node 1...")
|
t.Logf("Phase 3: Adding third connection for node 1...")
|
||||||
|
|
||||||
thirdChannel := make(chan *tailcfg.MapResponse, 10)
|
thirdChannel := make(chan *tailcfg.MapResponse, 10)
|
||||||
|
|
||||||
err = batcher.AddNode(node1.n.ID, thirdChannel, tailcfg.CapabilityVersion(100))
|
err = batcher.AddNode(node1.n.ID, thirdChannel, tailcfg.CapabilityVersion(100))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to add third connection for node1: %v", err)
|
t.Fatalf("Failed to add third connection for node1: %v", err)
|
||||||
@@ -2454,6 +2489,7 @@ func TestBatcherMultiConnection(t *testing.T) {
|
|||||||
|
|
||||||
// Phase 4: Verify debug status shows correct connection count
|
// Phase 4: Verify debug status shows correct connection count
|
||||||
t.Logf("Phase 4: Verifying debug status shows multiple connections...")
|
t.Logf("Phase 4: Verifying debug status shows multiple connections...")
|
||||||
|
|
||||||
if debugBatcher, ok := batcher.(interface {
|
if debugBatcher, ok := batcher.(interface {
|
||||||
Debug() map[types.NodeID]any
|
Debug() map[types.NodeID]any
|
||||||
}); ok {
|
}); ok {
|
||||||
@@ -2461,6 +2497,7 @@ func TestBatcherMultiConnection(t *testing.T) {
|
|||||||
|
|
||||||
if info, exists := debugInfo[node1.n.ID]; exists {
|
if info, exists := debugInfo[node1.n.ID]; exists {
|
||||||
t.Logf("Node1 debug info: %+v", info)
|
t.Logf("Node1 debug info: %+v", info)
|
||||||
|
|
||||||
if infoMap, ok := info.(map[string]any); ok {
|
if infoMap, ok := info.(map[string]any); ok {
|
||||||
if activeConnections, ok := infoMap["active_connections"].(int); ok {
|
if activeConnections, ok := infoMap["active_connections"].(int); ok {
|
||||||
if activeConnections != 3 {
|
if activeConnections != 3 {
|
||||||
@@ -2469,6 +2506,7 @@ func TestBatcherMultiConnection(t *testing.T) {
|
|||||||
t.Logf("SUCCESS: Node1 correctly shows 3 active connections")
|
t.Logf("SUCCESS: Node1 correctly shows 3 active connections")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if connected, ok := infoMap["connected"].(bool); ok && !connected {
|
if connected, ok := infoMap["connected"].(bool); ok && !connected {
|
||||||
t.Errorf("Node1 should show as connected with 3 active connections")
|
t.Errorf("Node1 should show as connected with 3 active connections")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package mapper
|
package mapper
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"sort"
|
"sort"
|
||||||
"time"
|
"time"
|
||||||
@@ -36,6 +35,7 @@ const (
|
|||||||
// NewMapResponseBuilder creates a new builder with basic fields set.
|
// NewMapResponseBuilder creates a new builder with basic fields set.
|
||||||
func (m *mapper) NewMapResponseBuilder(nodeID types.NodeID) *MapResponseBuilder {
|
func (m *mapper) NewMapResponseBuilder(nodeID types.NodeID) *MapResponseBuilder {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
|
||||||
return &MapResponseBuilder{
|
return &MapResponseBuilder{
|
||||||
resp: &tailcfg.MapResponse{
|
resp: &tailcfg.MapResponse{
|
||||||
KeepAlive: false,
|
KeepAlive: false,
|
||||||
@@ -69,7 +69,7 @@ func (b *MapResponseBuilder) WithCapabilityVersion(capVer tailcfg.CapabilityVers
|
|||||||
func (b *MapResponseBuilder) WithSelfNode() *MapResponseBuilder {
|
func (b *MapResponseBuilder) WithSelfNode() *MapResponseBuilder {
|
||||||
nv, ok := b.mapper.state.GetNodeByID(b.nodeID)
|
nv, ok := b.mapper.state.GetNodeByID(b.nodeID)
|
||||||
if !ok {
|
if !ok {
|
||||||
b.addError(errors.New("node not found"))
|
b.addError(ErrNodeNotFoundMapper)
|
||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -123,6 +123,7 @@ func (b *MapResponseBuilder) WithDebugConfig() *MapResponseBuilder {
|
|||||||
b.resp.Debug = &tailcfg.Debug{
|
b.resp.Debug = &tailcfg.Debug{
|
||||||
DisableLogTail: !b.mapper.cfg.LogTail.Enabled,
|
DisableLogTail: !b.mapper.cfg.LogTail.Enabled,
|
||||||
}
|
}
|
||||||
|
|
||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -130,7 +131,7 @@ func (b *MapResponseBuilder) WithDebugConfig() *MapResponseBuilder {
|
|||||||
func (b *MapResponseBuilder) WithSSHPolicy() *MapResponseBuilder {
|
func (b *MapResponseBuilder) WithSSHPolicy() *MapResponseBuilder {
|
||||||
node, ok := b.mapper.state.GetNodeByID(b.nodeID)
|
node, ok := b.mapper.state.GetNodeByID(b.nodeID)
|
||||||
if !ok {
|
if !ok {
|
||||||
b.addError(errors.New("node not found"))
|
b.addError(ErrNodeNotFoundMapper)
|
||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -149,7 +150,7 @@ func (b *MapResponseBuilder) WithSSHPolicy() *MapResponseBuilder {
|
|||||||
func (b *MapResponseBuilder) WithDNSConfig() *MapResponseBuilder {
|
func (b *MapResponseBuilder) WithDNSConfig() *MapResponseBuilder {
|
||||||
node, ok := b.mapper.state.GetNodeByID(b.nodeID)
|
node, ok := b.mapper.state.GetNodeByID(b.nodeID)
|
||||||
if !ok {
|
if !ok {
|
||||||
b.addError(errors.New("node not found"))
|
b.addError(ErrNodeNotFoundMapper)
|
||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -162,7 +163,7 @@ func (b *MapResponseBuilder) WithDNSConfig() *MapResponseBuilder {
|
|||||||
func (b *MapResponseBuilder) WithUserProfiles(peers views.Slice[types.NodeView]) *MapResponseBuilder {
|
func (b *MapResponseBuilder) WithUserProfiles(peers views.Slice[types.NodeView]) *MapResponseBuilder {
|
||||||
node, ok := b.mapper.state.GetNodeByID(b.nodeID)
|
node, ok := b.mapper.state.GetNodeByID(b.nodeID)
|
||||||
if !ok {
|
if !ok {
|
||||||
b.addError(errors.New("node not found"))
|
b.addError(ErrNodeNotFoundMapper)
|
||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -175,7 +176,7 @@ func (b *MapResponseBuilder) WithUserProfiles(peers views.Slice[types.NodeView])
|
|||||||
func (b *MapResponseBuilder) WithPacketFilters() *MapResponseBuilder {
|
func (b *MapResponseBuilder) WithPacketFilters() *MapResponseBuilder {
|
||||||
node, ok := b.mapper.state.GetNodeByID(b.nodeID)
|
node, ok := b.mapper.state.GetNodeByID(b.nodeID)
|
||||||
if !ok {
|
if !ok {
|
||||||
b.addError(errors.New("node not found"))
|
b.addError(ErrNodeNotFoundMapper)
|
||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -229,7 +230,7 @@ func (b *MapResponseBuilder) WithPeerChanges(peers views.Slice[types.NodeView])
|
|||||||
func (b *MapResponseBuilder) buildTailPeers(peers views.Slice[types.NodeView]) ([]*tailcfg.Node, error) {
|
func (b *MapResponseBuilder) buildTailPeers(peers views.Slice[types.NodeView]) ([]*tailcfg.Node, error) {
|
||||||
node, ok := b.mapper.state.GetNodeByID(b.nodeID)
|
node, ok := b.mapper.state.GetNodeByID(b.nodeID)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, errors.New("node not found")
|
return nil, ErrNodeNotFoundMapper
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get unreduced matchers for peer relationship determination.
|
// Get unreduced matchers for peer relationship determination.
|
||||||
@@ -276,20 +277,22 @@ func (b *MapResponseBuilder) WithPeerChangedPatch(changes []*tailcfg.PeerChange)
|
|||||||
|
|
||||||
// WithPeersRemoved adds removed peer IDs.
|
// WithPeersRemoved adds removed peer IDs.
|
||||||
func (b *MapResponseBuilder) WithPeersRemoved(removedIDs ...types.NodeID) *MapResponseBuilder {
|
func (b *MapResponseBuilder) WithPeersRemoved(removedIDs ...types.NodeID) *MapResponseBuilder {
|
||||||
var tailscaleIDs []tailcfg.NodeID
|
tailscaleIDs := make([]tailcfg.NodeID, 0, len(removedIDs))
|
||||||
for _, id := range removedIDs {
|
for _, id := range removedIDs {
|
||||||
tailscaleIDs = append(tailscaleIDs, id.NodeID())
|
tailscaleIDs = append(tailscaleIDs, id.NodeID())
|
||||||
}
|
}
|
||||||
|
|
||||||
b.resp.PeersRemoved = tailscaleIDs
|
b.resp.PeersRemoved = tailscaleIDs
|
||||||
|
|
||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
// Build finalizes the response and returns marshaled bytes
|
// Build finalizes the response and returns marshaled bytes.
|
||||||
func (b *MapResponseBuilder) Build() (*tailcfg.MapResponse, error) {
|
func (b *MapResponseBuilder) Build() (*tailcfg.MapResponse, error) {
|
||||||
if len(b.errs) > 0 {
|
if len(b.errs) > 0 {
|
||||||
return nil, multierr.New(b.errs...)
|
return nil, multierr.New(b.errs...)
|
||||||
}
|
}
|
||||||
|
|
||||||
if debugDumpMapResponsePath != "" {
|
if debugDumpMapResponsePath != "" {
|
||||||
writeDebugMapResponse(b.resp, b.debugType, b.nodeID)
|
writeDebugMapResponse(b.resp, b.debugType, b.nodeID)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -339,8 +339,8 @@ func TestMapResponseBuilder_MultipleErrors(t *testing.T) {
|
|||||||
|
|
||||||
// Build should return a multierr
|
// Build should return a multierr
|
||||||
data, err := result.Build()
|
data, err := result.Build()
|
||||||
assert.Nil(t, data)
|
require.Nil(t, data)
|
||||||
assert.Error(t, err)
|
require.Error(t, err)
|
||||||
|
|
||||||
// The error should contain information about multiple errors
|
// The error should contain information about multiple errors
|
||||||
assert.Contains(t, err.Error(), "multiple errors")
|
assert.Contains(t, err.Error(), "multiple errors")
|
||||||
|
|||||||
@@ -24,7 +24,6 @@ import (
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
nextDNSDoHPrefix = "https://dns.nextdns.io"
|
nextDNSDoHPrefix = "https://dns.nextdns.io"
|
||||||
mapperIDLength = 8
|
|
||||||
debugMapResponsePerm = 0o755
|
debugMapResponsePerm = 0o755
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -50,6 +49,7 @@ type mapper struct {
|
|||||||
created time.Time
|
created time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//nolint:unused
|
||||||
type patch struct {
|
type patch struct {
|
||||||
timestamp time.Time
|
timestamp time.Time
|
||||||
change *tailcfg.PeerChange
|
change *tailcfg.PeerChange
|
||||||
@@ -60,7 +60,6 @@ func newMapper(
|
|||||||
state *state.State,
|
state *state.State,
|
||||||
) *mapper {
|
) *mapper {
|
||||||
// uid, _ := util.GenerateRandomStringDNSSafe(mapperIDLength)
|
// uid, _ := util.GenerateRandomStringDNSSafe(mapperIDLength)
|
||||||
|
|
||||||
return &mapper{
|
return &mapper{
|
||||||
state: state,
|
state: state,
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
@@ -76,23 +75,26 @@ func generateUserProfiles(
|
|||||||
) []tailcfg.UserProfile {
|
) []tailcfg.UserProfile {
|
||||||
userMap := make(map[uint]*types.UserView)
|
userMap := make(map[uint]*types.UserView)
|
||||||
ids := make([]uint, 0, len(userMap))
|
ids := make([]uint, 0, len(userMap))
|
||||||
|
|
||||||
user := node.Owner()
|
user := node.Owner()
|
||||||
if !user.Valid() {
|
if !user.Valid() {
|
||||||
log.Error().
|
log.Error().
|
||||||
Uint64("node.id", node.ID().Uint64()).
|
EmbedObject(node).
|
||||||
Str("node.name", node.Hostname()).
|
|
||||||
Msg("node has no valid owner, skipping user profile generation")
|
Msg("node has no valid owner, skipping user profile generation")
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
userID := user.Model().ID
|
userID := user.Model().ID
|
||||||
userMap[userID] = &user
|
userMap[userID] = &user
|
||||||
ids = append(ids, userID)
|
ids = append(ids, userID)
|
||||||
|
|
||||||
for _, peer := range peers.All() {
|
for _, peer := range peers.All() {
|
||||||
peerUser := peer.Owner()
|
peerUser := peer.Owner()
|
||||||
if !peerUser.Valid() {
|
if !peerUser.Valid() {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
peerUserID := peerUser.Model().ID
|
peerUserID := peerUser.Model().ID
|
||||||
userMap[peerUserID] = &peerUser
|
userMap[peerUserID] = &peerUser
|
||||||
ids = append(ids, peerUserID)
|
ids = append(ids, peerUserID)
|
||||||
@@ -100,7 +102,9 @@ func generateUserProfiles(
|
|||||||
|
|
||||||
slices.Sort(ids)
|
slices.Sort(ids)
|
||||||
ids = slices.Compact(ids)
|
ids = slices.Compact(ids)
|
||||||
|
|
||||||
var profiles []tailcfg.UserProfile
|
var profiles []tailcfg.UserProfile
|
||||||
|
|
||||||
for _, id := range ids {
|
for _, id := range ids {
|
||||||
if userMap[id] != nil {
|
if userMap[id] != nil {
|
||||||
profiles = append(profiles, userMap[id].TailscaleUserProfile())
|
profiles = append(profiles, userMap[id].TailscaleUserProfile())
|
||||||
@@ -150,6 +154,8 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, node types.NodeView) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// fullMapResponse returns a MapResponse for the given node.
|
// fullMapResponse returns a MapResponse for the given node.
|
||||||
|
//
|
||||||
|
//nolint:unused
|
||||||
func (m *mapper) fullMapResponse(
|
func (m *mapper) fullMapResponse(
|
||||||
nodeID types.NodeID,
|
nodeID types.NodeID,
|
||||||
capVer tailcfg.CapabilityVersion,
|
capVer tailcfg.CapabilityVersion,
|
||||||
@@ -317,6 +323,7 @@ func writeDebugMapResponse(
|
|||||||
|
|
||||||
perms := fs.FileMode(debugMapResponsePerm)
|
perms := fs.FileMode(debugMapResponsePerm)
|
||||||
mPath := path.Join(debugDumpMapResponsePath, fmt.Sprintf("%d", nodeID))
|
mPath := path.Join(debugDumpMapResponsePath, fmt.Sprintf("%d", nodeID))
|
||||||
|
|
||||||
err = os.MkdirAll(mPath, perms)
|
err = os.MkdirAll(mPath, perms)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
@@ -329,7 +336,8 @@ func writeDebugMapResponse(
|
|||||||
fmt.Sprintf("%s-%s.json", now, t),
|
fmt.Sprintf("%s-%s.json", now, t),
|
||||||
)
|
)
|
||||||
|
|
||||||
log.Trace().Msgf("Writing MapResponse to %s", mapResponsePath)
|
log.Trace().Msgf("writing MapResponse to %s", mapResponsePath)
|
||||||
|
|
||||||
err = os.WriteFile(mapResponsePath, body, perms)
|
err = os.WriteFile(mapResponsePath, body, perms)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
@@ -338,7 +346,7 @@ func writeDebugMapResponse(
|
|||||||
|
|
||||||
func (m *mapper) debugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, error) {
|
func (m *mapper) debugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, error) {
|
||||||
if debugDumpMapResponsePath == "" {
|
if debugDumpMapResponsePath == "" {
|
||||||
return nil, nil
|
return nil, nil //nolint:nilnil // intentional: no data when debug path not set
|
||||||
}
|
}
|
||||||
|
|
||||||
return ReadMapResponsesFromDirectory(debugDumpMapResponsePath)
|
return ReadMapResponsesFromDirectory(debugDumpMapResponsePath)
|
||||||
@@ -351,6 +359,7 @@ func ReadMapResponsesFromDirectory(dir string) (map[types.NodeID][]tailcfg.MapRe
|
|||||||
}
|
}
|
||||||
|
|
||||||
result := make(map[types.NodeID][]tailcfg.MapResponse)
|
result := make(map[types.NodeID][]tailcfg.MapResponse)
|
||||||
|
|
||||||
for _, node := range nodes {
|
for _, node := range nodes {
|
||||||
if !node.IsDir() {
|
if !node.IsDir() {
|
||||||
continue
|
continue
|
||||||
@@ -358,7 +367,7 @@ func ReadMapResponsesFromDirectory(dir string) (map[types.NodeID][]tailcfg.MapRe
|
|||||||
|
|
||||||
nodeIDu, err := strconv.ParseUint(node.Name(), 10, 64)
|
nodeIDu, err := strconv.ParseUint(node.Name(), 10, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Msgf("Parsing node ID from dir %s", node.Name())
|
log.Error().Err(err).Msgf("parsing node ID from dir %s", node.Name())
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -366,7 +375,7 @@ func ReadMapResponsesFromDirectory(dir string) (map[types.NodeID][]tailcfg.MapRe
|
|||||||
|
|
||||||
files, err := os.ReadDir(path.Join(dir, node.Name()))
|
files, err := os.ReadDir(path.Join(dir, node.Name()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Msgf("Reading dir %s", node.Name())
|
log.Error().Err(err).Msgf("reading dir %s", node.Name())
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -381,14 +390,15 @@ func ReadMapResponsesFromDirectory(dir string) (map[types.NodeID][]tailcfg.MapRe
|
|||||||
|
|
||||||
body, err := os.ReadFile(path.Join(dir, node.Name(), file.Name()))
|
body, err := os.ReadFile(path.Join(dir, node.Name(), file.Name()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Msgf("Reading file %s", file.Name())
|
log.Error().Err(err).Msgf("reading file %s", file.Name())
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
var resp tailcfg.MapResponse
|
var resp tailcfg.MapResponse
|
||||||
|
|
||||||
err = json.Unmarshal(body, &resp)
|
err = json.Unmarshal(body, &resp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Msgf("Unmarshalling file %s", file.Name())
|
log.Error().Err(err).Msgf("unmarshalling file %s", file.Name())
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -3,18 +3,13 @@ package mapper
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"slices"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
"github.com/google/go-cmp/cmp/cmpopts"
|
"github.com/google/go-cmp/cmp/cmpopts"
|
||||||
"github.com/juanfont/headscale/hscontrol/policy"
|
|
||||||
"github.com/juanfont/headscale/hscontrol/policy/matcher"
|
|
||||||
"github.com/juanfont/headscale/hscontrol/routes"
|
|
||||||
"github.com/juanfont/headscale/hscontrol/types"
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
"tailscale.com/types/dnstype"
|
"tailscale.com/types/dnstype"
|
||||||
"tailscale.com/types/ptr"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var iap = func(ipStr string) *netip.Addr {
|
var iap = func(ipStr string) *netip.Addr {
|
||||||
@@ -51,7 +46,7 @@ func TestDNSConfigMapResponse(t *testing.T) {
|
|||||||
mach := func(hostname, username string, userid uint) *types.Node {
|
mach := func(hostname, username string, userid uint) *types.Node {
|
||||||
return &types.Node{
|
return &types.Node{
|
||||||
Hostname: hostname,
|
Hostname: hostname,
|
||||||
UserID: ptr.To(userid),
|
UserID: new(userid),
|
||||||
User: &types.User{
|
User: &types.User{
|
||||||
Name: username,
|
Name: username,
|
||||||
},
|
},
|
||||||
@@ -81,90 +76,3 @@ func TestDNSConfigMapResponse(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// mockState is a mock implementation that provides the required methods.
|
|
||||||
type mockState struct {
|
|
||||||
polMan policy.PolicyManager
|
|
||||||
derpMap *tailcfg.DERPMap
|
|
||||||
primary *routes.PrimaryRoutes
|
|
||||||
nodes types.Nodes
|
|
||||||
peers types.Nodes
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockState) DERPMap() *tailcfg.DERPMap {
|
|
||||||
return m.derpMap
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockState) Filter() ([]tailcfg.FilterRule, []matcher.Match) {
|
|
||||||
if m.polMan == nil {
|
|
||||||
return tailcfg.FilterAllowAll, nil
|
|
||||||
}
|
|
||||||
return m.polMan.Filter()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockState) SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, error) {
|
|
||||||
if m.polMan == nil {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
return m.polMan.SSHPolicy(node)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockState) NodeCanHaveTag(node types.NodeView, tag string) bool {
|
|
||||||
if m.polMan == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return m.polMan.NodeCanHaveTag(node, tag)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockState) GetNodePrimaryRoutes(nodeID types.NodeID) []netip.Prefix {
|
|
||||||
if m.primary == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return m.primary.PrimaryRoutes(nodeID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockState) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) {
|
|
||||||
if len(peerIDs) > 0 {
|
|
||||||
// Filter peers by the provided IDs
|
|
||||||
var filtered types.Nodes
|
|
||||||
for _, peer := range m.peers {
|
|
||||||
if slices.Contains(peerIDs, peer.ID) {
|
|
||||||
filtered = append(filtered, peer)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return filtered, nil
|
|
||||||
}
|
|
||||||
// Return all peers except the node itself
|
|
||||||
var filtered types.Nodes
|
|
||||||
for _, peer := range m.peers {
|
|
||||||
if peer.ID != nodeID {
|
|
||||||
filtered = append(filtered, peer)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return filtered, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockState) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) {
|
|
||||||
if len(nodeIDs) > 0 {
|
|
||||||
// Filter nodes by the provided IDs
|
|
||||||
var filtered types.Nodes
|
|
||||||
for _, node := range m.nodes {
|
|
||||||
if slices.Contains(nodeIDs, node.ID) {
|
|
||||||
filtered = append(filtered, node)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return filtered, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return m.nodes, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_fullMapResponse(t *testing.T) {
|
|
||||||
t.Skip("Test needs to be refactored for new state-based architecture")
|
|
||||||
// TODO: Refactor this test to work with the new state-based mapper
|
|
||||||
// The test architecture needs to be updated to work with the state interface
|
|
||||||
// instead of the old direct dependency injection pattern
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -13,12 +13,12 @@ import (
|
|||||||
"tailscale.com/net/tsaddr"
|
"tailscale.com/net/tsaddr"
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
"tailscale.com/types/key"
|
"tailscale.com/types/key"
|
||||||
"tailscale.com/types/ptr"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestTailNode(t *testing.T) {
|
func TestTailNode(t *testing.T) {
|
||||||
mustNK := func(str string) key.NodePublic {
|
mustNK := func(str string) key.NodePublic {
|
||||||
var k key.NodePublic
|
var k key.NodePublic
|
||||||
|
|
||||||
_ = k.UnmarshalText([]byte(str))
|
_ = k.UnmarshalText([]byte(str))
|
||||||
|
|
||||||
return k
|
return k
|
||||||
@@ -26,6 +26,7 @@ func TestTailNode(t *testing.T) {
|
|||||||
|
|
||||||
mustDK := func(str string) key.DiscoPublic {
|
mustDK := func(str string) key.DiscoPublic {
|
||||||
var k key.DiscoPublic
|
var k key.DiscoPublic
|
||||||
|
|
||||||
_ = k.UnmarshalText([]byte(str))
|
_ = k.UnmarshalText([]byte(str))
|
||||||
|
|
||||||
return k
|
return k
|
||||||
@@ -33,6 +34,7 @@ func TestTailNode(t *testing.T) {
|
|||||||
|
|
||||||
mustMK := func(str string) key.MachinePublic {
|
mustMK := func(str string) key.MachinePublic {
|
||||||
var k key.MachinePublic
|
var k key.MachinePublic
|
||||||
|
|
||||||
_ = k.UnmarshalText([]byte(str))
|
_ = k.UnmarshalText([]byte(str))
|
||||||
|
|
||||||
return k
|
return k
|
||||||
@@ -95,7 +97,7 @@ func TestTailNode(t *testing.T) {
|
|||||||
IPv4: iap("100.64.0.1"),
|
IPv4: iap("100.64.0.1"),
|
||||||
Hostname: "mini",
|
Hostname: "mini",
|
||||||
GivenName: "mini",
|
GivenName: "mini",
|
||||||
UserID: ptr.To(uint(0)),
|
UserID: new(uint(0)),
|
||||||
User: &types.User{
|
User: &types.User{
|
||||||
Name: "mini",
|
Name: "mini",
|
||||||
},
|
},
|
||||||
@@ -137,8 +139,8 @@ func TestTailNode(t *testing.T) {
|
|||||||
Addresses: []netip.Prefix{netip.MustParsePrefix("100.64.0.1/32")},
|
Addresses: []netip.Prefix{netip.MustParsePrefix("100.64.0.1/32")},
|
||||||
AllowedIPs: []netip.Prefix{
|
AllowedIPs: []netip.Prefix{
|
||||||
tsaddr.AllIPv4(),
|
tsaddr.AllIPv4(),
|
||||||
netip.MustParsePrefix("192.168.0.0/24"),
|
|
||||||
netip.MustParsePrefix("100.64.0.1/32"),
|
netip.MustParsePrefix("100.64.0.1/32"),
|
||||||
|
netip.MustParsePrefix("192.168.0.0/24"),
|
||||||
tsaddr.AllIPv6(),
|
tsaddr.AllIPv6(),
|
||||||
},
|
},
|
||||||
PrimaryRoutes: []netip.Prefix{
|
PrimaryRoutes: []netip.Prefix{
|
||||||
@@ -255,7 +257,7 @@ func TestNodeExpiry(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "localtime",
|
name: "localtime",
|
||||||
exp: tp(time.Time{}.Local()),
|
exp: tp(time.Time{}.Local()), //nolint:gosmopolitan
|
||||||
wantTimeZero: true,
|
wantTimeZero: true,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -284,7 +286,9 @@ func TestNodeExpiry(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("nodeExpiry() error = %v", err)
|
t.Fatalf("nodeExpiry() error = %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var deseri tailcfg.Node
|
var deseri tailcfg.Node
|
||||||
|
|
||||||
err = json.Unmarshal(seri, &deseri)
|
err = json.Unmarshal(seri, &deseri)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("nodeExpiry() error = %v", err)
|
t.Fatalf("nodeExpiry() error = %v", err)
|
||||||
|
|||||||
@@ -71,6 +71,7 @@ func prometheusMiddleware(next http.Handler) http.Handler {
|
|||||||
rw := &respWriterProm{ResponseWriter: w}
|
rw := &respWriterProm{ResponseWriter: w}
|
||||||
|
|
||||||
timer := prometheus.NewTimer(httpDuration.WithLabelValues(path))
|
timer := prometheus.NewTimer(httpDuration.WithLabelValues(path))
|
||||||
|
|
||||||
next.ServeHTTP(rw, r)
|
next.ServeHTTP(rw, r)
|
||||||
timer.ObserveDuration()
|
timer.ObserveDuration()
|
||||||
httpCounter.WithLabelValues(strconv.Itoa(rw.status), r.Method, path).Inc()
|
httpCounter.WithLabelValues(strconv.Itoa(rw.status), r.Method, path).Inc()
|
||||||
@@ -79,6 +80,7 @@ func prometheusMiddleware(next http.Handler) http.Handler {
|
|||||||
|
|
||||||
type respWriterProm struct {
|
type respWriterProm struct {
|
||||||
http.ResponseWriter
|
http.ResponseWriter
|
||||||
|
|
||||||
status int
|
status int
|
||||||
written int64
|
written int64
|
||||||
wroteHeader bool
|
wroteHeader bool
|
||||||
@@ -94,6 +96,7 @@ func (r *respWriterProm) Write(b []byte) (int, error) {
|
|||||||
if !r.wroteHeader {
|
if !r.wroteHeader {
|
||||||
r.WriteHeader(http.StatusOK)
|
r.WriteHeader(http.StatusOK)
|
||||||
}
|
}
|
||||||
|
|
||||||
n, err := r.ResponseWriter.Write(b)
|
n, err := r.ResponseWriter.Write(b)
|
||||||
r.written += int64(n)
|
r.written += int64(n)
|
||||||
|
|
||||||
|
|||||||
@@ -19,6 +19,9 @@ import (
|
|||||||
"tailscale.com/types/key"
|
"tailscale.com/types/key"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ErrUnsupportedClientVersion is returned when a client connects with an unsupported protocol version.
|
||||||
|
var ErrUnsupportedClientVersion = errors.New("unsupported client version")
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// ts2021UpgradePath is the path that the server listens on for the WebSockets upgrade.
|
// ts2021UpgradePath is the path that the server listens on for the WebSockets upgrade.
|
||||||
ts2021UpgradePath = "/ts2021"
|
ts2021UpgradePath = "/ts2021"
|
||||||
@@ -51,7 +54,7 @@ func (h *Headscale) NoiseUpgradeHandler(
|
|||||||
writer http.ResponseWriter,
|
writer http.ResponseWriter,
|
||||||
req *http.Request,
|
req *http.Request,
|
||||||
) {
|
) {
|
||||||
log.Trace().Caller().Msgf("Noise upgrade handler for client %s", req.RemoteAddr)
|
log.Trace().Caller().Msgf("noise upgrade handler for client %s", req.RemoteAddr)
|
||||||
|
|
||||||
upgrade := req.Header.Get("Upgrade")
|
upgrade := req.Header.Get("Upgrade")
|
||||||
if upgrade == "" {
|
if upgrade == "" {
|
||||||
@@ -60,7 +63,7 @@ func (h *Headscale) NoiseUpgradeHandler(
|
|||||||
// be passed to Headscale. Let's give them a hint.
|
// be passed to Headscale. Let's give them a hint.
|
||||||
log.Warn().
|
log.Warn().
|
||||||
Caller().
|
Caller().
|
||||||
Msg("No Upgrade header in TS2021 request. If headscale is behind a reverse proxy, make sure it is configured to pass WebSockets through.")
|
Msg("no upgrade header in TS2021 request. If headscale is behind a reverse proxy, make sure it is configured to pass WebSockets through.")
|
||||||
http.Error(writer, "Internal error", http.StatusInternalServerError)
|
http.Error(writer, "Internal error", http.StatusInternalServerError)
|
||||||
|
|
||||||
return
|
return
|
||||||
@@ -79,7 +82,7 @@ func (h *Headscale) NoiseUpgradeHandler(
|
|||||||
noiseServer.earlyNoise,
|
noiseServer.earlyNoise,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpError(writer, fmt.Errorf("noise upgrade failed: %w", err))
|
httpError(writer, fmt.Errorf("upgrading noise connection: %w", err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -117,7 +120,7 @@ func (h *Headscale) NoiseUpgradeHandler(
|
|||||||
}
|
}
|
||||||
|
|
||||||
func unsupportedClientError(version tailcfg.CapabilityVersion) error {
|
func unsupportedClientError(version tailcfg.CapabilityVersion) error {
|
||||||
return fmt.Errorf("unsupported client version: %s (%d)", capver.TailscaleVersion(version), version)
|
return fmt.Errorf("%w: %s (%d)", ErrUnsupportedClientVersion, capver.TailscaleVersion(version), version)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ns *noiseServer) earlyNoise(protocolVersion int, writer io.Writer) error {
|
func (ns *noiseServer) earlyNoise(protocolVersion int, writer io.Writer) error {
|
||||||
@@ -137,17 +140,20 @@ func (ns *noiseServer) earlyNoise(protocolVersion int, writer io.Writer) error {
|
|||||||
// an HTTP/2 settings frame, which isn't of type 'T')
|
// an HTTP/2 settings frame, which isn't of type 'T')
|
||||||
var notH2Frame [5]byte
|
var notH2Frame [5]byte
|
||||||
copy(notH2Frame[:], earlyPayloadMagic)
|
copy(notH2Frame[:], earlyPayloadMagic)
|
||||||
|
|
||||||
var lenBuf [4]byte
|
var lenBuf [4]byte
|
||||||
binary.BigEndian.PutUint32(lenBuf[:], uint32(len(earlyJSON)))
|
binary.BigEndian.PutUint32(lenBuf[:], uint32(len(earlyJSON))) //nolint:gosec // JSON length is bounded
|
||||||
// These writes are all buffered by caller, so fine to do them
|
// These writes are all buffered by caller, so fine to do them
|
||||||
// separately:
|
// separately:
|
||||||
if _, err := writer.Write(notH2Frame[:]); err != nil {
|
if _, err := writer.Write(notH2Frame[:]); err != nil { //nolint:noinlineerr
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if _, err := writer.Write(lenBuf[:]); err != nil {
|
|
||||||
|
if _, err := writer.Write(lenBuf[:]); err != nil { //nolint:noinlineerr
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if _, err := writer.Write(earlyJSON); err != nil {
|
|
||||||
|
if _, err := writer.Write(earlyJSON); err != nil { //nolint:noinlineerr
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -199,7 +205,7 @@ func (ns *noiseServer) NoisePollNetMapHandler(
|
|||||||
body, _ := io.ReadAll(req.Body)
|
body, _ := io.ReadAll(req.Body)
|
||||||
|
|
||||||
var mapRequest tailcfg.MapRequest
|
var mapRequest tailcfg.MapRequest
|
||||||
if err := json.Unmarshal(body, &mapRequest); err != nil {
|
if err := json.Unmarshal(body, &mapRequest); err != nil { //nolint:noinlineerr
|
||||||
httpError(writer, err)
|
httpError(writer, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -218,7 +224,8 @@ func (ns *noiseServer) NoisePollNetMapHandler(
|
|||||||
ns.nodeKey = nv.NodeKey()
|
ns.nodeKey = nv.NodeKey()
|
||||||
|
|
||||||
sess := ns.headscale.newMapSession(req.Context(), mapRequest, writer, nv.AsStruct())
|
sess := ns.headscale.newMapSession(req.Context(), mapRequest, writer, nv.AsStruct())
|
||||||
sess.tracef("a node sending a MapRequest with Noise protocol")
|
sess.log.Trace().Caller().Msg("a node sending a MapRequest with Noise protocol")
|
||||||
|
|
||||||
if !sess.isStreaming() {
|
if !sess.isStreaming() {
|
||||||
sess.serve()
|
sess.serve()
|
||||||
} else {
|
} else {
|
||||||
@@ -241,14 +248,16 @@ func (ns *noiseServer) NoiseRegistrationHandler(
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
registerRequest, registerResponse := func() (*tailcfg.RegisterRequest, *tailcfg.RegisterResponse) {
|
registerRequest, registerResponse := func() (*tailcfg.RegisterRequest, *tailcfg.RegisterResponse) { //nolint:contextcheck
|
||||||
var resp *tailcfg.RegisterResponse
|
var resp *tailcfg.RegisterResponse
|
||||||
|
|
||||||
body, err := io.ReadAll(req.Body)
|
body, err := io.ReadAll(req.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &tailcfg.RegisterRequest{}, regErr(err)
|
return &tailcfg.RegisterRequest{}, regErr(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var regReq tailcfg.RegisterRequest
|
var regReq tailcfg.RegisterRequest
|
||||||
if err := json.Unmarshal(body, ®Req); err != nil {
|
if err := json.Unmarshal(body, ®Req); err != nil { //nolint:noinlineerr
|
||||||
return ®Req, regErr(err)
|
return ®Req, regErr(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -256,11 +265,11 @@ func (ns *noiseServer) NoiseRegistrationHandler(
|
|||||||
|
|
||||||
resp, err = ns.headscale.handleRegister(req.Context(), regReq, ns.conn.Peer())
|
resp, err = ns.headscale.handleRegister(req.Context(), regReq, ns.conn.Peer())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
var httpErr HTTPError
|
if httpErr, ok := errors.AsType[HTTPError](err); ok {
|
||||||
if errors.As(err, &httpErr) {
|
|
||||||
resp = &tailcfg.RegisterResponse{
|
resp = &tailcfg.RegisterResponse{
|
||||||
Error: httpErr.Msg,
|
Error: httpErr.Msg,
|
||||||
}
|
}
|
||||||
|
|
||||||
return ®Req, resp
|
return ®Req, resp
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -278,8 +287,9 @@ func (ns *noiseServer) NoiseRegistrationHandler(
|
|||||||
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
|
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||||
writer.WriteHeader(http.StatusOK)
|
writer.WriteHeader(http.StatusOK)
|
||||||
|
|
||||||
if err := json.NewEncoder(writer).Encode(registerResponse); err != nil {
|
err := json.NewEncoder(writer).Encode(registerResponse)
|
||||||
log.Error().Caller().Err(err).Msg("NoiseRegistrationHandler: failed to encode RegisterResponse")
|
if err != nil {
|
||||||
|
log.Error().Caller().Err(err).Msg("noise registration handler: failed to encode RegisterResponse")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -32,8 +32,8 @@ const (
|
|||||||
|
|
||||||
var (
|
var (
|
||||||
errEmptyOIDCCallbackParams = errors.New("empty OIDC callback params")
|
errEmptyOIDCCallbackParams = errors.New("empty OIDC callback params")
|
||||||
errNoOIDCIDToken = errors.New("could not extract ID Token for OIDC callback")
|
errNoOIDCIDToken = errors.New("extracting ID token")
|
||||||
errNoOIDCRegistrationInfo = errors.New("could not get registration info from cache")
|
errNoOIDCRegistrationInfo = errors.New("registration info not in cache")
|
||||||
errOIDCAllowedDomains = errors.New(
|
errOIDCAllowedDomains = errors.New(
|
||||||
"authenticated principal does not match any allowed domain",
|
"authenticated principal does not match any allowed domain",
|
||||||
)
|
)
|
||||||
@@ -68,7 +68,7 @@ func NewAuthProviderOIDC(
|
|||||||
) (*AuthProviderOIDC, error) {
|
) (*AuthProviderOIDC, error) {
|
||||||
var err error
|
var err error
|
||||||
// grab oidc config if it hasn't been already
|
// grab oidc config if it hasn't been already
|
||||||
oidcProvider, err := oidc.NewProvider(context.Background(), cfg.Issuer)
|
oidcProvider, err := oidc.NewProvider(context.Background(), cfg.Issuer) //nolint:contextcheck
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("creating OIDC provider from issuer config: %w", err)
|
return nil, fmt.Errorf("creating OIDC provider from issuer config: %w", err)
|
||||||
}
|
}
|
||||||
@@ -163,13 +163,14 @@ func (a *AuthProviderOIDC) RegisterHandler(
|
|||||||
for k, v := range a.cfg.ExtraParams {
|
for k, v := range a.cfg.ExtraParams {
|
||||||
extras = append(extras, oauth2.SetAuthURLParam(k, v))
|
extras = append(extras, oauth2.SetAuthURLParam(k, v))
|
||||||
}
|
}
|
||||||
|
|
||||||
extras = append(extras, oidc.Nonce(nonce))
|
extras = append(extras, oidc.Nonce(nonce))
|
||||||
|
|
||||||
// Cache the registration info
|
// Cache the registration info
|
||||||
a.registrationCache.Set(state, registrationInfo)
|
a.registrationCache.Set(state, registrationInfo)
|
||||||
|
|
||||||
authURL := a.oauth2Config.AuthCodeURL(state, extras...)
|
authURL := a.oauth2Config.AuthCodeURL(state, extras...)
|
||||||
log.Debug().Caller().Msgf("Redirecting to %s for authentication", authURL)
|
log.Debug().Caller().Msgf("redirecting to %s for authentication", authURL)
|
||||||
|
|
||||||
http.Redirect(writer, req, authURL, http.StatusFound)
|
http.Redirect(writer, req, authURL, http.StatusFound)
|
||||||
}
|
}
|
||||||
@@ -190,6 +191,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
|||||||
}
|
}
|
||||||
|
|
||||||
stateCookieName := getCookieName("state", state)
|
stateCookieName := getCookieName("state", state)
|
||||||
|
|
||||||
cookieState, err := req.Cookie(stateCookieName)
|
cookieState, err := req.Cookie(stateCookieName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpError(writer, NewHTTPError(http.StatusBadRequest, "state not found", err))
|
httpError(writer, NewHTTPError(http.StatusBadRequest, "state not found", err))
|
||||||
@@ -212,17 +214,20 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
|||||||
httpError(writer, err)
|
httpError(writer, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if idToken.Nonce == "" {
|
if idToken.Nonce == "" {
|
||||||
httpError(writer, NewHTTPError(http.StatusBadRequest, "nonce not found in IDToken", err))
|
httpError(writer, NewHTTPError(http.StatusBadRequest, "nonce not found in IDToken", err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
nonceCookieName := getCookieName("nonce", idToken.Nonce)
|
nonceCookieName := getCookieName("nonce", idToken.Nonce)
|
||||||
|
|
||||||
nonce, err := req.Cookie(nonceCookieName)
|
nonce, err := req.Cookie(nonceCookieName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpError(writer, NewHTTPError(http.StatusBadRequest, "nonce not found", err))
|
httpError(writer, NewHTTPError(http.StatusBadRequest, "nonce not found", err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if idToken.Nonce != nonce.Value {
|
if idToken.Nonce != nonce.Value {
|
||||||
httpError(writer, NewHTTPError(http.StatusForbidden, "nonce did not match", nil))
|
httpError(writer, NewHTTPError(http.StatusForbidden, "nonce did not match", nil))
|
||||||
return
|
return
|
||||||
@@ -231,7 +236,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
|||||||
nodeExpiry := a.determineNodeExpiry(idToken.Expiry)
|
nodeExpiry := a.determineNodeExpiry(idToken.Expiry)
|
||||||
|
|
||||||
var claims types.OIDCClaims
|
var claims types.OIDCClaims
|
||||||
if err := idToken.Claims(&claims); err != nil {
|
if err := idToken.Claims(&claims); err != nil { //nolint:noinlineerr
|
||||||
httpError(writer, fmt.Errorf("decoding ID token claims: %w", err))
|
httpError(writer, fmt.Errorf("decoding ID token claims: %w", err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -239,6 +244,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
|||||||
// Fetch user information (email, groups, name, etc) from the userinfo endpoint
|
// Fetch user information (email, groups, name, etc) from the userinfo endpoint
|
||||||
// https://openid.net/specs/openid-connect-core-1_0.html#UserInfo
|
// https://openid.net/specs/openid-connect-core-1_0.html#UserInfo
|
||||||
var userinfo *oidc.UserInfo
|
var userinfo *oidc.UserInfo
|
||||||
|
|
||||||
userinfo, err = a.oidcProvider.UserInfo(req.Context(), oauth2.StaticTokenSource(oauth2Token))
|
userinfo, err = a.oidcProvider.UserInfo(req.Context(), oauth2.StaticTokenSource(oauth2Token))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.LogErr(err, "could not get userinfo; only using claims from id token")
|
util.LogErr(err, "could not get userinfo; only using claims from id token")
|
||||||
@@ -255,6 +261,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
|||||||
claims.EmailVerified = cmp.Or(userinfo2.EmailVerified, claims.EmailVerified)
|
claims.EmailVerified = cmp.Or(userinfo2.EmailVerified, claims.EmailVerified)
|
||||||
claims.Username = cmp.Or(userinfo2.PreferredUsername, claims.Username)
|
claims.Username = cmp.Or(userinfo2.PreferredUsername, claims.Username)
|
||||||
claims.Name = cmp.Or(userinfo2.Name, claims.Name)
|
claims.Name = cmp.Or(userinfo2.Name, claims.Name)
|
||||||
|
|
||||||
claims.ProfilePictureURL = cmp.Or(userinfo2.Picture, claims.ProfilePictureURL)
|
claims.ProfilePictureURL = cmp.Or(userinfo2.Picture, claims.ProfilePictureURL)
|
||||||
if userinfo2.Groups != nil {
|
if userinfo2.Groups != nil {
|
||||||
claims.Groups = userinfo2.Groups
|
claims.Groups = userinfo2.Groups
|
||||||
@@ -279,6 +286,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
|||||||
Msgf("could not create or update user")
|
Msgf("could not create or update user")
|
||||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||||
writer.WriteHeader(http.StatusInternalServerError)
|
writer.WriteHeader(http.StatusInternalServerError)
|
||||||
|
|
||||||
_, werr := writer.Write([]byte("Could not create or update user"))
|
_, werr := writer.Write([]byte("Could not create or update user"))
|
||||||
if werr != nil {
|
if werr != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
@@ -299,6 +307,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
|||||||
// Register the node if it does not exist.
|
// Register the node if it does not exist.
|
||||||
if registrationId != nil {
|
if registrationId != nil {
|
||||||
verb := "Reauthenticated"
|
verb := "Reauthenticated"
|
||||||
|
|
||||||
newNode, err := a.handleRegistration(user, *registrationId, nodeExpiry)
|
newNode, err := a.handleRegistration(user, *registrationId, nodeExpiry)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, db.ErrNodeNotFoundRegistrationCache) {
|
if errors.Is(err, db.ErrNodeNotFoundRegistrationCache) {
|
||||||
@@ -307,7 +316,9 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
|||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
httpError(writer, err)
|
httpError(writer, err)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -316,15 +327,12 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TODO(kradalby): replace with go-elem
|
// TODO(kradalby): replace with go-elem
|
||||||
content, err := renderOIDCCallbackTemplate(user, verb)
|
content := renderOIDCCallbackTemplate(user, verb)
|
||||||
if err != nil {
|
|
||||||
httpError(writer, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
|
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||||
writer.WriteHeader(http.StatusOK)
|
writer.WriteHeader(http.StatusOK)
|
||||||
if _, err := writer.Write(content.Bytes()); err != nil {
|
|
||||||
|
if _, err := writer.Write(content.Bytes()); err != nil { //nolint:noinlineerr
|
||||||
util.LogErr(err, "Failed to write HTTP response")
|
util.LogErr(err, "Failed to write HTTP response")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -370,6 +378,7 @@ func (a *AuthProviderOIDC) getOauth2Token(
|
|||||||
if !ok {
|
if !ok {
|
||||||
return nil, NewHTTPError(http.StatusNotFound, "registration not found", errNoOIDCRegistrationInfo)
|
return nil, NewHTTPError(http.StatusNotFound, "registration not found", errNoOIDCRegistrationInfo)
|
||||||
}
|
}
|
||||||
|
|
||||||
if regInfo.Verifier != nil {
|
if regInfo.Verifier != nil {
|
||||||
exchangeOpts = []oauth2.AuthCodeOption{oauth2.VerifierOption(*regInfo.Verifier)}
|
exchangeOpts = []oauth2.AuthCodeOption{oauth2.VerifierOption(*regInfo.Verifier)}
|
||||||
}
|
}
|
||||||
@@ -377,7 +386,7 @@ func (a *AuthProviderOIDC) getOauth2Token(
|
|||||||
|
|
||||||
oauth2Token, err := a.oauth2Config.Exchange(ctx, code, exchangeOpts...)
|
oauth2Token, err := a.oauth2Config.Exchange(ctx, code, exchangeOpts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, NewHTTPError(http.StatusForbidden, "invalid code", fmt.Errorf("could not exchange code for token: %w", err))
|
return nil, NewHTTPError(http.StatusForbidden, "invalid code", fmt.Errorf("exchanging code for token: %w", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
return oauth2Token, err
|
return oauth2Token, err
|
||||||
@@ -394,9 +403,10 @@ func (a *AuthProviderOIDC) extractIDToken(
|
|||||||
}
|
}
|
||||||
|
|
||||||
verifier := a.oidcProvider.Verifier(&oidc.Config{ClientID: a.cfg.ClientID})
|
verifier := a.oidcProvider.Verifier(&oidc.Config{ClientID: a.cfg.ClientID})
|
||||||
|
|
||||||
idToken, err := verifier.Verify(ctx, rawIDToken)
|
idToken, err := verifier.Verify(ctx, rawIDToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, NewHTTPError(http.StatusForbidden, "failed to verify id_token", fmt.Errorf("failed to verify ID token: %w", err))
|
return nil, NewHTTPError(http.StatusForbidden, "failed to verify id_token", fmt.Errorf("verifying ID token: %w", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
return idToken, nil
|
return idToken, nil
|
||||||
@@ -516,6 +526,7 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
|
|||||||
newUser bool
|
newUser bool
|
||||||
c change.Change
|
c change.Change
|
||||||
)
|
)
|
||||||
|
|
||||||
user, err = a.h.state.GetUserByOIDCIdentifier(claims.Identifier())
|
user, err = a.h.state.GetUserByOIDCIdentifier(claims.Identifier())
|
||||||
if err != nil && !errors.Is(err, db.ErrUserNotFound) {
|
if err != nil && !errors.Is(err, db.ErrUserNotFound) {
|
||||||
return nil, change.Change{}, fmt.Errorf("creating or updating user: %w", err)
|
return nil, change.Change{}, fmt.Errorf("creating or updating user: %w", err)
|
||||||
@@ -561,7 +572,7 @@ func (a *AuthProviderOIDC) handleRegistration(
|
|||||||
util.RegisterMethodOIDC,
|
util.RegisterMethodOIDC,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, fmt.Errorf("could not register node: %w", err)
|
return false, fmt.Errorf("registering node: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// This is a bit of a back and forth, but we have a bit of a chicken and egg
|
// This is a bit of a back and forth, but we have a bit of a chicken and egg
|
||||||
@@ -589,9 +600,9 @@ func (a *AuthProviderOIDC) handleRegistration(
|
|||||||
func renderOIDCCallbackTemplate(
|
func renderOIDCCallbackTemplate(
|
||||||
user *types.User,
|
user *types.User,
|
||||||
verb string,
|
verb string,
|
||||||
) (*bytes.Buffer, error) {
|
) *bytes.Buffer {
|
||||||
html := templates.OIDCCallback(user.Display(), verb).Render()
|
html := templates.OIDCCallback(user.Display(), verb).Render()
|
||||||
return bytes.NewBufferString(html), nil
|
return bytes.NewBufferString(html)
|
||||||
}
|
}
|
||||||
|
|
||||||
// getCookieName generates a unique cookie name based on a cookie value.
|
// getCookieName generates a unique cookie name based on a cookie value.
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ func (h *Headscale) WindowsConfigMessage(
|
|||||||
) {
|
) {
|
||||||
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
|
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||||
writer.WriteHeader(http.StatusOK)
|
writer.WriteHeader(http.StatusOK)
|
||||||
writer.Write([]byte(templates.Windows(h.cfg.ServerURL).Render()))
|
_, _ = writer.Write([]byte(templates.Windows(h.cfg.ServerURL).Render()))
|
||||||
}
|
}
|
||||||
|
|
||||||
// AppleConfigMessage shows a simple message in the browser to point the user to the iOS/MacOS profile and instructions for how to install it.
|
// AppleConfigMessage shows a simple message in the browser to point the user to the iOS/MacOS profile and instructions for how to install it.
|
||||||
@@ -29,7 +29,7 @@ func (h *Headscale) AppleConfigMessage(
|
|||||||
) {
|
) {
|
||||||
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
|
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||||
writer.WriteHeader(http.StatusOK)
|
writer.WriteHeader(http.StatusOK)
|
||||||
writer.Write([]byte(templates.Apple(h.cfg.ServerURL).Render()))
|
_, _ = writer.Write([]byte(templates.Apple(h.cfg.ServerURL).Render()))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) ApplePlatformConfig(
|
func (h *Headscale) ApplePlatformConfig(
|
||||||
@@ -37,6 +37,7 @@ func (h *Headscale) ApplePlatformConfig(
|
|||||||
req *http.Request,
|
req *http.Request,
|
||||||
) {
|
) {
|
||||||
vars := mux.Vars(req)
|
vars := mux.Vars(req)
|
||||||
|
|
||||||
platform, ok := vars["platform"]
|
platform, ok := vars["platform"]
|
||||||
if !ok {
|
if !ok {
|
||||||
httpError(writer, NewHTTPError(http.StatusBadRequest, "no platform specified", nil))
|
httpError(writer, NewHTTPError(http.StatusBadRequest, "no platform specified", nil))
|
||||||
@@ -64,17 +65,20 @@ func (h *Headscale) ApplePlatformConfig(
|
|||||||
|
|
||||||
switch platform {
|
switch platform {
|
||||||
case "macos-standalone":
|
case "macos-standalone":
|
||||||
if err := macosStandaloneTemplate.Execute(&payload, platformConfig); err != nil {
|
err := macosStandaloneTemplate.Execute(&payload, platformConfig)
|
||||||
|
if err != nil {
|
||||||
httpError(writer, err)
|
httpError(writer, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
case "macos-app-store":
|
case "macos-app-store":
|
||||||
if err := macosAppStoreTemplate.Execute(&payload, platformConfig); err != nil {
|
err := macosAppStoreTemplate.Execute(&payload, platformConfig)
|
||||||
|
if err != nil {
|
||||||
httpError(writer, err)
|
httpError(writer, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
case "ios":
|
case "ios":
|
||||||
if err := iosTemplate.Execute(&payload, platformConfig); err != nil {
|
err := iosTemplate.Execute(&payload, platformConfig)
|
||||||
|
if err != nil {
|
||||||
httpError(writer, err)
|
httpError(writer, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -90,7 +94,7 @@ func (h *Headscale) ApplePlatformConfig(
|
|||||||
}
|
}
|
||||||
|
|
||||||
var content bytes.Buffer
|
var content bytes.Buffer
|
||||||
if err := commonTemplate.Execute(&content, config); err != nil {
|
if err := commonTemplate.Execute(&content, config); err != nil { //nolint:noinlineerr
|
||||||
httpError(writer, err)
|
httpError(writer, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -98,7 +102,7 @@ func (h *Headscale) ApplePlatformConfig(
|
|||||||
writer.Header().
|
writer.Header().
|
||||||
Set("Content-Type", "application/x-apple-aspen-config; charset=utf-8")
|
Set("Content-Type", "application/x-apple-aspen-config; charset=utf-8")
|
||||||
writer.WriteHeader(http.StatusOK)
|
writer.WriteHeader(http.StatusOK)
|
||||||
writer.Write(content.Bytes())
|
_, _ = writer.Write(content.Bytes())
|
||||||
}
|
}
|
||||||
|
|
||||||
type AppleMobileConfig struct {
|
type AppleMobileConfig struct {
|
||||||
|
|||||||
@@ -16,15 +16,18 @@ type Match struct {
|
|||||||
dests *netipx.IPSet
|
dests *netipx.IPSet
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m Match) DebugString() string {
|
func (m *Match) DebugString() string {
|
||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
|
|
||||||
sb.WriteString("Match:\n")
|
sb.WriteString("Match:\n")
|
||||||
sb.WriteString(" Sources:\n")
|
sb.WriteString(" Sources:\n")
|
||||||
|
|
||||||
for _, prefix := range m.srcs.Prefixes() {
|
for _, prefix := range m.srcs.Prefixes() {
|
||||||
sb.WriteString(" " + prefix.String() + "\n")
|
sb.WriteString(" " + prefix.String() + "\n")
|
||||||
}
|
}
|
||||||
|
|
||||||
sb.WriteString(" Destinations:\n")
|
sb.WriteString(" Destinations:\n")
|
||||||
|
|
||||||
for _, prefix := range m.dests.Prefixes() {
|
for _, prefix := range m.dests.Prefixes() {
|
||||||
sb.WriteString(" " + prefix.String() + "\n")
|
sb.WriteString(" " + prefix.String() + "\n")
|
||||||
}
|
}
|
||||||
@@ -42,7 +45,7 @@ func MatchesFromFilterRules(rules []tailcfg.FilterRule) []Match {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func MatchFromFilterRule(rule tailcfg.FilterRule) Match {
|
func MatchFromFilterRule(rule tailcfg.FilterRule) Match {
|
||||||
dests := []string{}
|
dests := make([]string, 0, len(rule.DstPorts))
|
||||||
for _, dest := range rule.DstPorts {
|
for _, dest := range rule.DstPorts {
|
||||||
dests = append(dests, dest.IP)
|
dests = append(dests, dest.IP)
|
||||||
}
|
}
|
||||||
@@ -93,11 +96,24 @@ func (m *Match) DestsOverlapsPrefixes(prefixes ...netip.Prefix) bool {
|
|||||||
return slices.ContainsFunc(prefixes, m.dests.OverlapsPrefix)
|
return slices.ContainsFunc(prefixes, m.dests.OverlapsPrefix)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DestsIsTheInternet reports if the destination is equal to "the internet"
|
// DestsIsTheInternet reports if the destination contains "the internet"
|
||||||
// which is a IPSet that represents "autogroup:internet" and is special
|
// which is a IPSet that represents "autogroup:internet" and is special
|
||||||
// cased for exit nodes.
|
// cased for exit nodes.
|
||||||
func (m Match) DestsIsTheInternet() bool {
|
// This checks if dests is a superset of TheInternet(), which handles
|
||||||
return m.dests.Equal(util.TheInternet()) ||
|
// merged filter rules where TheInternet is combined with other destinations.
|
||||||
m.dests.ContainsPrefix(tsaddr.AllIPv4()) ||
|
func (m *Match) DestsIsTheInternet() bool {
|
||||||
m.dests.ContainsPrefix(tsaddr.AllIPv6())
|
if m.dests.ContainsPrefix(tsaddr.AllIPv4()) ||
|
||||||
|
m.dests.ContainsPrefix(tsaddr.AllIPv6()) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if dests contains all prefixes of TheInternet (superset check)
|
||||||
|
theInternet := util.TheInternet()
|
||||||
|
for _, prefix := range theInternet.Prefixes() {
|
||||||
|
if !m.dests.ContainsPrefix(prefix) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -19,18 +19,18 @@ type PolicyManager interface {
|
|||||||
MatchersForNode(node types.NodeView) ([]matcher.Match, error)
|
MatchersForNode(node types.NodeView) ([]matcher.Match, error)
|
||||||
// BuildPeerMap constructs peer relationship maps for the given nodes
|
// BuildPeerMap constructs peer relationship maps for the given nodes
|
||||||
BuildPeerMap(nodes views.Slice[types.NodeView]) map[types.NodeID][]types.NodeView
|
BuildPeerMap(nodes views.Slice[types.NodeView]) map[types.NodeID][]types.NodeView
|
||||||
SSHPolicy(types.NodeView) (*tailcfg.SSHPolicy, error)
|
SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, error)
|
||||||
SetPolicy([]byte) (bool, error)
|
SetPolicy(pol []byte) (bool, error)
|
||||||
SetUsers(users []types.User) (bool, error)
|
SetUsers(users []types.User) (bool, error)
|
||||||
SetNodes(nodes views.Slice[types.NodeView]) (bool, error)
|
SetNodes(nodes views.Slice[types.NodeView]) (bool, error)
|
||||||
// NodeCanHaveTag reports whether the given node can have the given tag.
|
// NodeCanHaveTag reports whether the given node can have the given tag.
|
||||||
NodeCanHaveTag(types.NodeView, string) bool
|
NodeCanHaveTag(node types.NodeView, tag string) bool
|
||||||
|
|
||||||
// TagExists reports whether the given tag is defined in the policy.
|
// TagExists reports whether the given tag is defined in the policy.
|
||||||
TagExists(tag string) bool
|
TagExists(tag string) bool
|
||||||
|
|
||||||
// NodeCanApproveRoute reports whether the given node can approve the given route.
|
// NodeCanApproveRoute reports whether the given node can approve the given route.
|
||||||
NodeCanApproveRoute(types.NodeView, netip.Prefix) bool
|
NodeCanApproveRoute(node types.NodeView, route netip.Prefix) bool
|
||||||
|
|
||||||
Version() int
|
Version() int
|
||||||
DebugString() string
|
DebugString() string
|
||||||
@@ -38,8 +38,11 @@ type PolicyManager interface {
|
|||||||
|
|
||||||
// NewPolicyManager returns a new policy manager.
|
// NewPolicyManager returns a new policy manager.
|
||||||
func NewPolicyManager(pol []byte, users []types.User, nodes views.Slice[types.NodeView]) (PolicyManager, error) {
|
func NewPolicyManager(pol []byte, users []types.User, nodes views.Slice[types.NodeView]) (PolicyManager, error) {
|
||||||
var polMan PolicyManager
|
var (
|
||||||
var err error
|
polMan PolicyManager
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
polMan, err = policyv2.NewPolicyManager(pol, users, nodes)
|
polMan, err = policyv2.NewPolicyManager(pol, users, nodes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -59,6 +62,7 @@ func PolicyManagersForTest(pol []byte, users []types.User, nodes views.Slice[typ
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
polMans = append(polMans, pm)
|
polMans = append(polMans, pm)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -66,7 +70,7 @@ func PolicyManagersForTest(pol []byte, users []types.User, nodes views.Slice[typ
|
|||||||
}
|
}
|
||||||
|
|
||||||
func PolicyManagerFuncsForTest(pol []byte) []func([]types.User, views.Slice[types.NodeView]) (PolicyManager, error) {
|
func PolicyManagerFuncsForTest(pol []byte) []func([]types.User, views.Slice[types.NodeView]) (PolicyManager, error) {
|
||||||
var polmanFuncs []func([]types.User, views.Slice[types.NodeView]) (PolicyManager, error)
|
polmanFuncs := make([]func([]types.User, views.Slice[types.NodeView]) (PolicyManager, error), 0, 1)
|
||||||
|
|
||||||
polmanFuncs = append(polmanFuncs, func(u []types.User, n views.Slice[types.NodeView]) (PolicyManager, error) {
|
polmanFuncs = append(polmanFuncs, func(u []types.User, n views.Slice[types.NodeView]) (PolicyManager, error) {
|
||||||
return policyv2.NewPolicyManager(pol, u, n)
|
return policyv2.NewPolicyManager(pol, u, n)
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ import (
|
|||||||
"github.com/juanfont/headscale/hscontrol/util"
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/samber/lo"
|
"github.com/samber/lo"
|
||||||
"tailscale.com/net/tsaddr"
|
|
||||||
"tailscale.com/types/views"
|
"tailscale.com/types/views"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -111,7 +110,7 @@ func ApproveRoutesWithPolicy(pm PolicyManager, nv types.NodeView, currentApprove
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Sort and deduplicate
|
// Sort and deduplicate
|
||||||
tsaddr.SortPrefixes(newApproved)
|
slices.SortFunc(newApproved, netip.Prefix.Compare)
|
||||||
newApproved = slices.Compact(newApproved)
|
newApproved = slices.Compact(newApproved)
|
||||||
newApproved = lo.Filter(newApproved, func(route netip.Prefix, index int) bool {
|
newApproved = lo.Filter(newApproved, func(route netip.Prefix, index int) bool {
|
||||||
return route.IsValid()
|
return route.IsValid()
|
||||||
@@ -120,12 +119,13 @@ func ApproveRoutesWithPolicy(pm PolicyManager, nv types.NodeView, currentApprove
|
|||||||
// Sort the current approved for comparison
|
// Sort the current approved for comparison
|
||||||
sortedCurrent := make([]netip.Prefix, len(currentApproved))
|
sortedCurrent := make([]netip.Prefix, len(currentApproved))
|
||||||
copy(sortedCurrent, currentApproved)
|
copy(sortedCurrent, currentApproved)
|
||||||
tsaddr.SortPrefixes(sortedCurrent)
|
slices.SortFunc(sortedCurrent, netip.Prefix.Compare)
|
||||||
|
|
||||||
// Only update if the routes actually changed
|
// Only update if the routes actually changed
|
||||||
if !slices.Equal(sortedCurrent, newApproved) {
|
if !slices.Equal(sortedCurrent, newApproved) {
|
||||||
// Log what changed
|
// Log what changed
|
||||||
var added, kept []netip.Prefix
|
var added, kept []netip.Prefix
|
||||||
|
|
||||||
for _, route := range newApproved {
|
for _, route := range newApproved {
|
||||||
if !slices.Contains(sortedCurrent, route) {
|
if !slices.Contains(sortedCurrent, route) {
|
||||||
added = append(added, route)
|
added = append(added, route)
|
||||||
@@ -136,8 +136,7 @@ func ApproveRoutesWithPolicy(pm PolicyManager, nv types.NodeView, currentApprove
|
|||||||
|
|
||||||
if len(added) > 0 {
|
if len(added) > 0 {
|
||||||
log.Debug().
|
log.Debug().
|
||||||
Uint64("node.id", nv.ID().Uint64()).
|
EmbedObject(nv).
|
||||||
Str("node.name", nv.Hostname()).
|
|
||||||
Strs("routes.added", util.PrefixesToString(added)).
|
Strs("routes.added", util.PrefixesToString(added)).
|
||||||
Strs("routes.kept", util.PrefixesToString(kept)).
|
Strs("routes.kept", util.PrefixesToString(kept)).
|
||||||
Int("routes.total", len(newApproved)).
|
Int("routes.total", len(newApproved)).
|
||||||
|
|||||||
@@ -3,16 +3,16 @@ package policy
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"slices"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2"
|
policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2"
|
||||||
"github.com/juanfont/headscale/hscontrol/types"
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
"github.com/juanfont/headscale/hscontrol/util"
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"tailscale.com/net/tsaddr"
|
|
||||||
"tailscale.com/types/key"
|
"tailscale.com/types/key"
|
||||||
"tailscale.com/types/ptr"
|
|
||||||
"tailscale.com/types/views"
|
"tailscale.com/types/views"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -32,10 +32,10 @@ func TestApproveRoutesWithPolicy_NeverRemovesApprovedRoutes(t *testing.T) {
|
|||||||
MachineKey: key.NewMachine().Public(),
|
MachineKey: key.NewMachine().Public(),
|
||||||
NodeKey: key.NewNode().Public(),
|
NodeKey: key.NewNode().Public(),
|
||||||
Hostname: "test-node",
|
Hostname: "test-node",
|
||||||
UserID: ptr.To(user1.ID),
|
UserID: new(user1.ID),
|
||||||
User: ptr.To(user1),
|
User: new(user1),
|
||||||
RegisterMethod: util.RegisterMethodAuthKey,
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")),
|
IPv4: new(netip.MustParseAddr("100.64.0.1")),
|
||||||
Tags: []string{"tag:test"},
|
Tags: []string{"tag:test"},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -44,10 +44,10 @@ func TestApproveRoutesWithPolicy_NeverRemovesApprovedRoutes(t *testing.T) {
|
|||||||
MachineKey: key.NewMachine().Public(),
|
MachineKey: key.NewMachine().Public(),
|
||||||
NodeKey: key.NewNode().Public(),
|
NodeKey: key.NewNode().Public(),
|
||||||
Hostname: "other-node",
|
Hostname: "other-node",
|
||||||
UserID: ptr.To(user2.ID),
|
UserID: new(user2.ID),
|
||||||
User: ptr.To(user2),
|
User: new(user2),
|
||||||
RegisterMethod: util.RegisterMethodAuthKey,
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
IPv4: ptr.To(netip.MustParseAddr("100.64.0.2")),
|
IPv4: new(netip.MustParseAddr("100.64.0.2")),
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a policy that auto-approves specific routes
|
// Create a policy that auto-approves specific routes
|
||||||
@@ -76,7 +76,7 @@ func TestApproveRoutesWithPolicy_NeverRemovesApprovedRoutes(t *testing.T) {
|
|||||||
}`
|
}`
|
||||||
|
|
||||||
pm, err := policyv2.NewPolicyManager([]byte(policyJSON), users, views.SliceOf([]types.NodeView{node1.View(), node2.View()}))
|
pm, err := policyv2.NewPolicyManager([]byte(policyJSON), users, views.SliceOf([]types.NodeView{node1.View(), node2.View()}))
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -194,7 +194,7 @@ func TestApproveRoutesWithPolicy_NeverRemovesApprovedRoutes(t *testing.T) {
|
|||||||
assert.Equal(t, tt.wantChanged, gotChanged, "changed flag mismatch: %s", tt.description)
|
assert.Equal(t, tt.wantChanged, gotChanged, "changed flag mismatch: %s", tt.description)
|
||||||
|
|
||||||
// Sort for comparison since ApproveRoutesWithPolicy sorts the results
|
// Sort for comparison since ApproveRoutesWithPolicy sorts the results
|
||||||
tsaddr.SortPrefixes(tt.wantApproved)
|
slices.SortFunc(tt.wantApproved, netip.Prefix.Compare)
|
||||||
assert.Equal(t, tt.wantApproved, gotApproved, "approved routes mismatch: %s", tt.description)
|
assert.Equal(t, tt.wantApproved, gotApproved, "approved routes mismatch: %s", tt.description)
|
||||||
|
|
||||||
// Verify that all previously approved routes are still present
|
// Verify that all previously approved routes are still present
|
||||||
@@ -304,20 +304,23 @@ func TestApproveRoutesWithPolicy_NilAndEmptyCases(t *testing.T) {
|
|||||||
MachineKey: key.NewMachine().Public(),
|
MachineKey: key.NewMachine().Public(),
|
||||||
NodeKey: key.NewNode().Public(),
|
NodeKey: key.NewNode().Public(),
|
||||||
Hostname: "testnode",
|
Hostname: "testnode",
|
||||||
UserID: ptr.To(user.ID),
|
UserID: new(user.ID),
|
||||||
User: ptr.To(user),
|
User: new(user),
|
||||||
RegisterMethod: util.RegisterMethodAuthKey,
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")),
|
IPv4: new(netip.MustParseAddr("100.64.0.1")),
|
||||||
ApprovedRoutes: tt.currentApproved,
|
ApprovedRoutes: tt.currentApproved,
|
||||||
}
|
}
|
||||||
nodes := types.Nodes{&node}
|
nodes := types.Nodes{&node}
|
||||||
|
|
||||||
// Create policy manager or use nil if specified
|
// Create policy manager or use nil if specified
|
||||||
var pm PolicyManager
|
var (
|
||||||
var err error
|
pm PolicyManager
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
if tt.name != "nil_policy_manager" {
|
if tt.name != "nil_policy_manager" {
|
||||||
pm, err = pmf(users, nodes.ViewSlice())
|
pm, err = pmf(users, nodes.ViewSlice())
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
} else {
|
} else {
|
||||||
pm = nil
|
pm = nil
|
||||||
}
|
}
|
||||||
@@ -330,7 +333,7 @@ func TestApproveRoutesWithPolicy_NilAndEmptyCases(t *testing.T) {
|
|||||||
if tt.wantApproved == nil {
|
if tt.wantApproved == nil {
|
||||||
assert.Nil(t, gotApproved, "expected nil approved routes")
|
assert.Nil(t, gotApproved, "expected nil approved routes")
|
||||||
} else {
|
} else {
|
||||||
tsaddr.SortPrefixes(tt.wantApproved)
|
slices.SortFunc(tt.wantApproved, netip.Prefix.Compare)
|
||||||
assert.Equal(t, tt.wantApproved, gotApproved, "approved routes mismatch")
|
assert.Equal(t, tt.wantApproved, gotApproved, "approved routes mismatch")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ import (
|
|||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
"tailscale.com/types/key"
|
"tailscale.com/types/key"
|
||||||
"tailscale.com/types/ptr"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestApproveRoutesWithPolicy_NeverRemovesRoutes(t *testing.T) {
|
func TestApproveRoutesWithPolicy_NeverRemovesRoutes(t *testing.T) {
|
||||||
@@ -92,8 +91,8 @@ func TestApproveRoutesWithPolicy_NeverRemovesRoutes(t *testing.T) {
|
|||||||
announcedRoutes: []netip.Prefix{}, // No routes announced anymore
|
announcedRoutes: []netip.Prefix{}, // No routes announced anymore
|
||||||
nodeUser: "test",
|
nodeUser: "test",
|
||||||
wantApproved: []netip.Prefix{
|
wantApproved: []netip.Prefix{
|
||||||
netip.MustParsePrefix("172.16.0.0/16"),
|
|
||||||
netip.MustParsePrefix("10.0.0.0/24"),
|
netip.MustParsePrefix("10.0.0.0/24"),
|
||||||
|
netip.MustParsePrefix("172.16.0.0/16"),
|
||||||
netip.MustParsePrefix("192.168.0.0/24"),
|
netip.MustParsePrefix("192.168.0.0/24"),
|
||||||
},
|
},
|
||||||
wantChanged: false,
|
wantChanged: false,
|
||||||
@@ -124,8 +123,8 @@ func TestApproveRoutesWithPolicy_NeverRemovesRoutes(t *testing.T) {
|
|||||||
nodeUser: "test",
|
nodeUser: "test",
|
||||||
nodeTags: []string{"tag:approved"},
|
nodeTags: []string{"tag:approved"},
|
||||||
wantApproved: []netip.Prefix{
|
wantApproved: []netip.Prefix{
|
||||||
netip.MustParsePrefix("172.16.0.0/16"), // New tag-approved
|
|
||||||
netip.MustParsePrefix("10.0.0.0/24"), // Previous approval preserved
|
netip.MustParsePrefix("10.0.0.0/24"), // Previous approval preserved
|
||||||
|
netip.MustParsePrefix("172.16.0.0/16"), // New tag-approved
|
||||||
},
|
},
|
||||||
wantChanged: true,
|
wantChanged: true,
|
||||||
},
|
},
|
||||||
@@ -168,13 +167,13 @@ func TestApproveRoutesWithPolicy_NeverRemovesRoutes(t *testing.T) {
|
|||||||
MachineKey: key.NewMachine().Public(),
|
MachineKey: key.NewMachine().Public(),
|
||||||
NodeKey: key.NewNode().Public(),
|
NodeKey: key.NewNode().Public(),
|
||||||
Hostname: tt.nodeHostname,
|
Hostname: tt.nodeHostname,
|
||||||
UserID: ptr.To(user.ID),
|
UserID: new(user.ID),
|
||||||
User: ptr.To(user),
|
User: new(user),
|
||||||
RegisterMethod: util.RegisterMethodAuthKey,
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
Hostinfo: &tailcfg.Hostinfo{
|
Hostinfo: &tailcfg.Hostinfo{
|
||||||
RoutableIPs: tt.announcedRoutes,
|
RoutableIPs: tt.announcedRoutes,
|
||||||
},
|
},
|
||||||
IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")),
|
IPv4: new(netip.MustParseAddr("100.64.0.1")),
|
||||||
ApprovedRoutes: tt.currentApproved,
|
ApprovedRoutes: tt.currentApproved,
|
||||||
Tags: tt.nodeTags,
|
Tags: tt.nodeTags,
|
||||||
}
|
}
|
||||||
@@ -294,13 +293,13 @@ func TestApproveRoutesWithPolicy_EdgeCases(t *testing.T) {
|
|||||||
MachineKey: key.NewMachine().Public(),
|
MachineKey: key.NewMachine().Public(),
|
||||||
NodeKey: key.NewNode().Public(),
|
NodeKey: key.NewNode().Public(),
|
||||||
Hostname: "testnode",
|
Hostname: "testnode",
|
||||||
UserID: ptr.To(user.ID),
|
UserID: new(user.ID),
|
||||||
User: ptr.To(user),
|
User: new(user),
|
||||||
RegisterMethod: util.RegisterMethodAuthKey,
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
Hostinfo: &tailcfg.Hostinfo{
|
Hostinfo: &tailcfg.Hostinfo{
|
||||||
RoutableIPs: tt.announcedRoutes,
|
RoutableIPs: tt.announcedRoutes,
|
||||||
},
|
},
|
||||||
IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")),
|
IPv4: new(netip.MustParseAddr("100.64.0.1")),
|
||||||
ApprovedRoutes: tt.currentApproved,
|
ApprovedRoutes: tt.currentApproved,
|
||||||
}
|
}
|
||||||
nodes := types.Nodes{&node}
|
nodes := types.Nodes{&node}
|
||||||
@@ -331,6 +330,8 @@ func TestApproveRoutesWithPolicy_NilPolicyManagerCase(t *testing.T) {
|
|||||||
Name: "test",
|
Name: "test",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
userID := user.ID
|
||||||
|
|
||||||
currentApproved := []netip.Prefix{
|
currentApproved := []netip.Prefix{
|
||||||
netip.MustParsePrefix("10.0.0.0/24"),
|
netip.MustParsePrefix("10.0.0.0/24"),
|
||||||
}
|
}
|
||||||
@@ -343,13 +344,13 @@ func TestApproveRoutesWithPolicy_NilPolicyManagerCase(t *testing.T) {
|
|||||||
MachineKey: key.NewMachine().Public(),
|
MachineKey: key.NewMachine().Public(),
|
||||||
NodeKey: key.NewNode().Public(),
|
NodeKey: key.NewNode().Public(),
|
||||||
Hostname: "testnode",
|
Hostname: "testnode",
|
||||||
UserID: ptr.To(user.ID),
|
UserID: &userID,
|
||||||
User: ptr.To(user),
|
User: &user,
|
||||||
RegisterMethod: util.RegisterMethodAuthKey,
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
Hostinfo: &tailcfg.Hostinfo{
|
Hostinfo: &tailcfg.Hostinfo{
|
||||||
RoutableIPs: announcedRoutes,
|
RoutableIPs: announcedRoutes,
|
||||||
},
|
},
|
||||||
IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")),
|
IPv4: new(netip.MustParseAddr("100.64.0.1")),
|
||||||
ApprovedRoutes: currentApproved,
|
ApprovedRoutes: currentApproved,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
"tailscale.com/types/ptr"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var ap = func(ipStr string) *netip.Addr {
|
var ap = func(ipStr string) *netip.Addr {
|
||||||
@@ -33,6 +32,7 @@ func TestReduceNodes(t *testing.T) {
|
|||||||
rules []tailcfg.FilterRule
|
rules []tailcfg.FilterRule
|
||||||
node *types.Node
|
node *types.Node
|
||||||
}
|
}
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
args args
|
args args
|
||||||
@@ -783,9 +783,11 @@ func TestReduceNodes(t *testing.T) {
|
|||||||
for _, v := range gotViews.All() {
|
for _, v := range gotViews.All() {
|
||||||
got = append(got, v.AsStruct())
|
got = append(got, v.AsStruct())
|
||||||
}
|
}
|
||||||
|
|
||||||
if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" {
|
if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" {
|
||||||
t.Errorf("ReduceNodes() unexpected result (-want +got):\n%s", diff)
|
t.Errorf("ReduceNodes() unexpected result (-want +got):\n%s", diff)
|
||||||
t.Log("Matchers: ")
|
t.Log("Matchers: ")
|
||||||
|
|
||||||
for _, m := range matchers {
|
for _, m := range matchers {
|
||||||
t.Log("\t+", m.DebugString())
|
t.Log("\t+", m.DebugString())
|
||||||
}
|
}
|
||||||
@@ -796,7 +798,7 @@ func TestReduceNodes(t *testing.T) {
|
|||||||
|
|
||||||
func TestReduceNodesFromPolicy(t *testing.T) {
|
func TestReduceNodesFromPolicy(t *testing.T) {
|
||||||
n := func(id types.NodeID, ip, hostname, username string, routess ...string) *types.Node {
|
n := func(id types.NodeID, ip, hostname, username string, routess ...string) *types.Node {
|
||||||
var routes []netip.Prefix
|
routes := make([]netip.Prefix, 0, len(routess))
|
||||||
for _, route := range routess {
|
for _, route := range routess {
|
||||||
routes = append(routes, netip.MustParsePrefix(route))
|
routes = append(routes, netip.MustParsePrefix(route))
|
||||||
}
|
}
|
||||||
@@ -891,11 +893,13 @@ func TestReduceNodesFromPolicy(t *testing.T) {
|
|||||||
]
|
]
|
||||||
}`,
|
}`,
|
||||||
node: n(1, "100.64.0.1", "mobile", "mobile"),
|
node: n(1, "100.64.0.1", "mobile", "mobile"),
|
||||||
|
// autogroup:internet does not generate packet filters - it's handled
|
||||||
|
// by exit node routing via AllowedIPs, not by packet filtering.
|
||||||
|
// Only server is visible through the mobile -> server:80 rule.
|
||||||
want: types.Nodes{
|
want: types.Nodes{
|
||||||
n(2, "100.64.0.2", "server", "server"),
|
n(2, "100.64.0.2", "server", "server"),
|
||||||
n(3, "100.64.0.3", "exit", "server", "0.0.0.0/0", "::/0"),
|
|
||||||
},
|
},
|
||||||
wantMatchers: 2,
|
wantMatchers: 1,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "2788-exit-node-0000-route",
|
name: "2788-exit-node-0000-route",
|
||||||
@@ -938,7 +942,7 @@ func TestReduceNodesFromPolicy(t *testing.T) {
|
|||||||
n(2, "100.64.0.2", "server", "server"),
|
n(2, "100.64.0.2", "server", "server"),
|
||||||
n(3, "100.64.0.3", "exit", "server", "0.0.0.0/0", "::/0"),
|
n(3, "100.64.0.3", "exit", "server", "0.0.0.0/0", "::/0"),
|
||||||
},
|
},
|
||||||
wantMatchers: 2,
|
wantMatchers: 1,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "2788-exit-node-::0-route",
|
name: "2788-exit-node-::0-route",
|
||||||
@@ -981,7 +985,7 @@ func TestReduceNodesFromPolicy(t *testing.T) {
|
|||||||
n(2, "100.64.0.2", "server", "server"),
|
n(2, "100.64.0.2", "server", "server"),
|
||||||
n(3, "100.64.0.3", "exit", "server", "0.0.0.0/0", "::/0"),
|
n(3, "100.64.0.3", "exit", "server", "0.0.0.0/0", "::/0"),
|
||||||
},
|
},
|
||||||
wantMatchers: 2,
|
wantMatchers: 1,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "2784-split-exit-node-access",
|
name: "2784-split-exit-node-access",
|
||||||
@@ -1032,8 +1036,11 @@ func TestReduceNodesFromPolicy(t *testing.T) {
|
|||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
for idx, pmf := range PolicyManagerFuncsForTest([]byte(tt.policy)) {
|
for idx, pmf := range PolicyManagerFuncsForTest([]byte(tt.policy)) {
|
||||||
t.Run(fmt.Sprintf("%s-index%d", tt.name, idx), func(t *testing.T) {
|
t.Run(fmt.Sprintf("%s-index%d", tt.name, idx), func(t *testing.T) {
|
||||||
var pm PolicyManager
|
var (
|
||||||
var err error
|
pm PolicyManager
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
pm, err = pmf(nil, tt.nodes.ViewSlice())
|
pm, err = pmf(nil, tt.nodes.ViewSlice())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
@@ -1051,9 +1058,11 @@ func TestReduceNodesFromPolicy(t *testing.T) {
|
|||||||
for _, v := range gotViews.All() {
|
for _, v := range gotViews.All() {
|
||||||
got = append(got, v.AsStruct())
|
got = append(got, v.AsStruct())
|
||||||
}
|
}
|
||||||
|
|
||||||
if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" {
|
if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" {
|
||||||
t.Errorf("TestReduceNodesFromPolicy() unexpected result (-want +got):\n%s", diff)
|
t.Errorf("TestReduceNodesFromPolicy() unexpected result (-want +got):\n%s", diff)
|
||||||
t.Log("Matchers: ")
|
t.Log("Matchers: ")
|
||||||
|
|
||||||
for _, m := range matchers {
|
for _, m := range matchers {
|
||||||
t.Log("\t+", m.DebugString())
|
t.Log("\t+", m.DebugString())
|
||||||
}
|
}
|
||||||
@@ -1074,21 +1083,21 @@ func TestSSHPolicyRules(t *testing.T) {
|
|||||||
nodeUser1 := types.Node{
|
nodeUser1 := types.Node{
|
||||||
Hostname: "user1-device",
|
Hostname: "user1-device",
|
||||||
IPv4: ap("100.64.0.1"),
|
IPv4: ap("100.64.0.1"),
|
||||||
UserID: ptr.To(uint(1)),
|
UserID: new(uint(1)),
|
||||||
User: ptr.To(users[0]),
|
User: new(users[0]),
|
||||||
}
|
}
|
||||||
nodeUser2 := types.Node{
|
nodeUser2 := types.Node{
|
||||||
Hostname: "user2-device",
|
Hostname: "user2-device",
|
||||||
IPv4: ap("100.64.0.2"),
|
IPv4: ap("100.64.0.2"),
|
||||||
UserID: ptr.To(uint(2)),
|
UserID: new(uint(2)),
|
||||||
User: ptr.To(users[1]),
|
User: new(users[1]),
|
||||||
}
|
}
|
||||||
|
|
||||||
taggedClient := types.Node{
|
taggedClient := types.Node{
|
||||||
Hostname: "tagged-client",
|
Hostname: "tagged-client",
|
||||||
IPv4: ap("100.64.0.4"),
|
IPv4: ap("100.64.0.4"),
|
||||||
UserID: ptr.To(uint(2)),
|
UserID: new(uint(2)),
|
||||||
User: ptr.To(users[1]),
|
User: new(users[1]),
|
||||||
Tags: []string{"tag:client"},
|
Tags: []string{"tag:client"},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1096,8 +1105,8 @@ func TestSSHPolicyRules(t *testing.T) {
|
|||||||
nodeTaggedServer := types.Node{
|
nodeTaggedServer := types.Node{
|
||||||
Hostname: "tagged-server",
|
Hostname: "tagged-server",
|
||||||
IPv4: ap("100.64.0.5"),
|
IPv4: ap("100.64.0.5"),
|
||||||
UserID: ptr.To(uint(1)),
|
UserID: new(uint(1)),
|
||||||
User: ptr.To(users[0]),
|
User: new(users[0]),
|
||||||
Tags: []string{"tag:server"},
|
Tags: []string{"tag:server"},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1231,7 +1240,7 @@ func TestSSHPolicyRules(t *testing.T) {
|
|||||||
]
|
]
|
||||||
}`,
|
}`,
|
||||||
expectErr: true,
|
expectErr: true,
|
||||||
errorMessage: `invalid SSH action "invalid", must be one of: accept, check`,
|
errorMessage: `invalid SSH action: "invalid", must be one of: accept, check`,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "invalid-check-period",
|
name: "invalid-check-period",
|
||||||
@@ -1278,7 +1287,7 @@ func TestSSHPolicyRules(t *testing.T) {
|
|||||||
]
|
]
|
||||||
}`,
|
}`,
|
||||||
expectErr: true,
|
expectErr: true,
|
||||||
errorMessage: "autogroup \"autogroup:invalid\" is not supported",
|
errorMessage: "autogroup not supported for SSH user",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "autogroup-nonroot-should-use-wildcard-with-root-excluded",
|
name: "autogroup-nonroot-should-use-wildcard-with-root-excluded",
|
||||||
@@ -1451,13 +1460,17 @@ func TestSSHPolicyRules(t *testing.T) {
|
|||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
for idx, pmf := range PolicyManagerFuncsForTest([]byte(tt.policy)) {
|
for idx, pmf := range PolicyManagerFuncsForTest([]byte(tt.policy)) {
|
||||||
t.Run(fmt.Sprintf("%s-index%d", tt.name, idx), func(t *testing.T) {
|
t.Run(fmt.Sprintf("%s-index%d", tt.name, idx), func(t *testing.T) {
|
||||||
var pm PolicyManager
|
var (
|
||||||
var err error
|
pm PolicyManager
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
pm, err = pmf(users, append(tt.peers, &tt.targetNode).ViewSlice())
|
pm, err = pmf(users, append(tt.peers, &tt.targetNode).ViewSlice())
|
||||||
|
|
||||||
if tt.expectErr {
|
if tt.expectErr {
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Contains(t, err.Error(), tt.errorMessage)
|
require.Contains(t, err.Error(), tt.errorMessage)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1480,6 +1493,7 @@ func TestReduceRoutes(t *testing.T) {
|
|||||||
routes []netip.Prefix
|
routes []netip.Prefix
|
||||||
rules []tailcfg.FilterRule
|
rules []tailcfg.FilterRule
|
||||||
}
|
}
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
args args
|
args args
|
||||||
@@ -2101,6 +2115,7 @@ func TestReduceRoutes(t *testing.T) {
|
|||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
matchers := matcher.MatchesFromFilterRules(tt.args.rules)
|
matchers := matcher.MatchesFromFilterRules(tt.args.rules)
|
||||||
|
|
||||||
got := ReduceRoutes(
|
got := ReduceRoutes(
|
||||||
tt.args.node.View(),
|
tt.args.node.View(),
|
||||||
tt.args.routes,
|
tt.args.routes,
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ func ReduceFilterRules(node types.NodeView, rules []tailcfg.FilterRule) []tailcf
|
|||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
// record if the rule is actually relevant for the given node.
|
// record if the rule is actually relevant for the given node.
|
||||||
var dests []tailcfg.NetPortRange
|
var dests []tailcfg.NetPortRange
|
||||||
|
|
||||||
DEST_LOOP:
|
DEST_LOOP:
|
||||||
for _, dest := range rule.DstPorts {
|
for _, dest := range rule.DstPorts {
|
||||||
expanded, err := util.ParseIPSet(dest.IP, nil)
|
expanded, err := util.ParseIPSet(dest.IP, nil)
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user