Mirror of https://github.com/juanfont/headscale.git (synced 2026-02-15 04:07:40 +01:00)

Compare commits: 12 commits, kradalby/c ... dependabot
| Author | SHA1 | Date |
|---|---|---|
| | ab70b4e37e | |
| | a058bf3cd3 | |
| | b2a18830ed | |
| | 9779adc0b7 | |
| | e7fe645be5 | |
| | bcd80ee773 | |
| | c04e17d82e | |
| | 98fc0563ac | |
| | 3123d5286b | |
| | 7fce5065c4 | |
| | a98d9bd05f | |
| | 46c59a3fff | |
@@ -21,4 +21,3 @@ LICENSE
node_modules/
package-lock.json
package.json
.github/ISSUE_TEMPLATE/bug_report.yaml (vendored, 6 changed lines)
@@ -77,6 +77,10 @@ body:
    attributes:
      label: Debug information
      description: |
        Please have a look at our [Debugging and troubleshooting
        guide](https://headscale.net/development/ref/debug/) to learn about
        common debugging techniques.

        Links? References? Anything that will give us more context about the issue you are encountering.
        If **any** of these are omitted we will likely close your issue, do **not** ignore them.

@@ -92,7 +96,7 @@ body:
        `tailscale status --json > DESCRIPTIVE_NAME.json`

        Get the logs of a Tailscale client that is not working as expected.
        `tailscale daemon-logs`
        `tailscale debug daemon-logs`

        Tip: You can attach images or log files by clicking this area to highlight it and then dragging files in.
        **Ensure** you use formatting for files you attach.
.github/workflows/build.yml (vendored, 4 changed lines)
@@ -79,11 +79,7 @@ jobs:
    strategy:
      matrix:
        env:
          - "GOARCH=arm GOOS=linux GOARM=5"
          - "GOARCH=arm GOOS=linux GOARM=6"
          - "GOARCH=arm GOOS=linux GOARM=7"
          - "GOARCH=arm64 GOOS=linux"
          - "GOARCH=386 GOOS=linux"
          - "GOARCH=amd64 GOOS=linux"
          - "GOARCH=arm64 GOOS=darwin"
          - "GOARCH=amd64 GOOS=darwin"
.github/workflows/check-generated.yml (vendored, new file, 55 lines)
@@ -0,0 +1,55 @@
name: Check Generated Files

on:
  push:
    branches:
      - main
  pull_request:
    branches:
      - main

concurrency:
  group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }}
  cancel-in-progress: true

jobs:
  check-generated:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
        with:
          fetch-depth: 2
      - name: Get changed files
        id: changed-files
        uses: dorny/paths-filter@de90cc6fb38fc0963ad72b210f1f284cd68cea36 # v3.0.2
        with:
          filters: |
            files:
              - '*.nix'
              - 'go.*'
              - '**/*.go'
              - '**/*.proto'
              - 'buf.gen.yaml'
              - 'tools/**'
      - uses: nixbuild/nix-quick-install-action@889f3180bb5f064ee9e3201428d04ae9e41d54ad # v31
        if: steps.changed-files.outputs.files == 'true'
      - uses: nix-community/cache-nix-action@135667ec418502fa5a3598af6fb9eb733888ce6a # v6.1.3
        if: steps.changed-files.outputs.files == 'true'
        with:
          primary-key: nix-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('**/*.nix', '**/flake.lock') }}
          restore-prefixes-first-match: nix-${{ runner.os }}-${{ runner.arch }}

      - name: Run make generate
        if: steps.changed-files.outputs.files == 'true'
        run: nix develop --command -- make generate

      - name: Check for uncommitted changes
        if: steps.changed-files.outputs.files == 'true'
        run: |
          if ! git diff --exit-code; then
            echo "❌ Generated files are not up to date!"
            echo "Please run 'make generate' and commit the changes."
            exit 1
          else
            echo "✅ All generated files are up to date."
          fi
@@ -77,7 +77,7 @@ jobs:
          attempt_delay: 300000 # 5 min
          attempt_limit: 2
          command: |
            nix develop --command -- hi run "^${{ inputs.test }}$" \
            nix develop --command -- hi run --stats --ts-memory-limit=300 --hs-memory-limit=500 "^${{ inputs.test }}$" \
              --timeout=120m \
              ${{ inputs.postgres_flag }}
      - uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
.github/workflows/lint.yml (vendored, 5 changed lines)
@@ -38,7 +38,10 @@ jobs:
|
||||
if: steps.changed-files.outputs.files == 'true'
|
||||
run: nix develop --command -- golangci-lint run
|
||||
--new-from-rev=${{github.event.pull_request.base.sha}}
|
||||
--format=colored-line-number
|
||||
--output.text.path=stdout
|
||||
--output.text.print-linter-name
|
||||
--output.text.print-issued-lines
|
||||
--output.text.colors
|
||||
|
||||
prettier-lint:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
.gitignore (vendored, 5 changed lines)
@@ -1,6 +1,9 @@
ignored/
tailscale/
.vscode/
.claude/

*.prof

# Binaries for programs and plugins
*.exe
@@ -47,8 +50,6 @@ integration_test/etc/config.dump.yaml

__debug_bin


node_modules/
package-lock.json
package.json
@@ -19,12 +19,8 @@ builds:
      - darwin_amd64
      - darwin_arm64
      - freebsd_amd64
      - linux_386
      - linux_amd64
      - linux_arm64
      - linux_arm_5
      - linux_arm_6
      - linux_arm_7
    flags:
      - -mod=readonly
    ldflags:
@@ -113,9 +109,7 @@ kos:
      - CGO_ENABLED=0
    platforms:
      - linux/amd64
      - linux/386
      - linux/arm64
      - linux/arm/v7
    tags:
      - "{{ if not .Prerelease }}latest{{ end }}"
      - "{{ if not .Prerelease }}{{ .Major }}.{{ .Minor }}.{{ .Patch }}{{ end }}"
@@ -142,9 +136,7 @@ kos:
      - CGO_ENABLED=0
    platforms:
      - linux/amd64
      - linux/386
      - linux/arm64
      - linux/arm/v7
    tags:
      - "{{ if not .Prerelease }}latest-debug{{ end }}"
      - "{{ if not .Prerelease }}{{ .Major }}.{{ .Minor }}.{{ .Patch }}-debug{{ end }}"
CHANGELOG.md (45 changed lines)
@@ -2,6 +2,8 @@

## Next

**Minimum supported Tailscale client version: v1.64.0**

### Database integrity improvements

This release includes a significant database migration that addresses longstanding
@@ -41,47 +43,8 @@ systemctl start headscale

### BREAKING

- **CLI: Remove deprecated flags**
  - `--identifier` flag removed - use `--node` or `--user` instead
  - `--namespace` flag removed - use `--user` instead

  **Command changes:**
  ```bash
  # Before
  headscale nodes expire --identifier 123
  headscale nodes rename --identifier 123 new-name
  headscale nodes delete --identifier 123
  headscale nodes move --identifier 123 --user 456
  headscale nodes list-routes --identifier 123

  # After
  headscale nodes expire --node 123
  headscale nodes rename --node 123 new-name
  headscale nodes delete --node 123
  headscale nodes move --node 123 --user 456
  headscale nodes list-routes --node 123

  # Before
  headscale users destroy --identifier 123
  headscale users rename --identifier 123 --new-name john
  headscale users list --identifier 123

  # After
  headscale users destroy --user 123
  headscale users rename --user 123 --new-name john
  headscale users list --user 123

  # Before
  headscale nodes register --namespace myuser nodekey
  headscale nodes list --namespace myuser
  headscale preauthkeys create --namespace myuser

  # After
  headscale nodes register --user myuser nodekey
  headscale nodes list --user myuser
  headscale preauthkeys create --user myuser
  ```

- Remove support for 32-bit binaries
  [#2692](https://github.com/juanfont/headscale/pull/2692)
- Policy: Zero or empty destination port is no longer allowed
  [#2606](https://github.com/juanfont/headscale/pull/2606)
File diff suppressed because it is too large
@@ -1,201 +0,0 @@
# CLI Standardization Summary

## Changes Made

### 1. Command Naming Standardization

- **Fixed**: `backfillips` → `backfill-ips` (with a backward-compatibility alias)
- **Fixed**: `dumpConfig` → `dump-config` (with a backward-compatibility alias)
- **Result**: All commands now use kebab-case consistently

### 2. Flag Standardization

#### Node Commands

- **Added**: `--node` flag as the primary way to specify nodes
- **Deprecated**: `--identifier` flag (hidden, marked deprecated)
- **Backward Compatible**: Both flags work; `--identifier` shows a deprecation warning
- **Smart Lookup Ready**: `--node` accepts strings for future name/hostname/IP lookup

#### User Commands

- **Updated**: User identification flow prepared for the `--user` flag
- **Maintained**: Existing `--name` and `--identifier` flags for backward compatibility

### 3. Description Consistency

- **Fixed**: "Api" → "API" throughout
- **Fixed**: Capitalization consistency in short descriptions
- **Fixed**: Removed unnecessary periods from short descriptions
- **Standardized**: "Handle/Manage the X of Headscale" pattern

### 4. Type Consistency

- **Standardized**: Node IDs use `uint64` consistently
- **Maintained**: Backward compatibility with existing flag types

## Current Status

### ✅ Completed

- Command naming (kebab-case)
- Flag deprecation and aliasing
- Description standardization
- Backward compatibility preservation
- Helper functions for flag processing
- **SMART LOOKUP IMPLEMENTATION**:
  - Enhanced `ListNodesRequest` proto with ID, name, hostname, IP filters
  - Implemented smart filtering in the `ListNodes` gRPC method
  - Added CLI smart lookup functions for nodes and users
  - Single-match validation with helpful error messages
  - Automatic detection: ID (numeric) vs IP vs name/hostname/email

### ✅ Smart Lookup Features

- **Node Lookup**: By ID, hostname, or IP address
- **User Lookup**: By ID, username, or email address
- **Single Match Enforcement**: Errors if 0 or >1 matches are found
- **Helpful Error Messages**: Shows all matches when ambiguous
- **Full Backward Compatibility**: All existing flags still work
- **Enhanced List Commands**: Both `nodes list` and `users list` support all filter types

## Breaking Changes

**None.** All changes maintain full backward compatibility through flag aliases and deprecation warnings.
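
The backward compatibility leans on cobra/pflag's built-in deprecation support: the old flag stays registered but hidden, and pflag prints the deprecation message whenever it is used. A minimal sketch of that pattern, assuming a hypothetical node-scoped command; the flag wiring mirrors what the diff below does for the `--namespace` flag, but the command and message text here are illustrative:

```go
package cli

import (
	"fmt"

	"github.com/spf13/cobra"
)

// exampleCmd is a stand-in for any node-scoped command such as `nodes expire`.
var exampleCmd = &cobra.Command{
	Use: "expire",
	Run: func(cmd *cobra.Command, args []string) {
		// Prefer the new flag, but fall back to the deprecated one.
		node, _ := cmd.Flags().GetString("node")
		if node == "" {
			if id, _ := cmd.Flags().GetUint64("identifier"); id > 0 {
				node = fmt.Sprintf("%d", id)
			}
		}
		fmt.Println("resolved node specifier:", node)
	},
}

func init() {
	exampleCmd.Flags().StringP("node", "n", "", "Node identifier (ID, name, hostname, or IP)")
	exampleCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)")

	// Keep the old flag working, but hide it from help and warn on use.
	identifierFlag := exampleCmd.Flags().Lookup("identifier")
	identifierFlag.Deprecated = "use --node instead"
	identifierFlag.Hidden = true
}
```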

## Implementation Details

### Smart Lookup Algorithm

1. **Input Detection**:

   ```
   if numeric && > 0        -> treat as ID
   else if contains "@"     -> treat as email (users only)
   else if valid IP address -> treat as IP (nodes only)
   else                     -> treat as name/hostname
   ```

2. **gRPC Filtering**:
   - Uses the enhanced `ListNodes`/`ListUsers` with specific filters
   - Server-side filtering for optimal performance
   - Single transaction per lookup

3. **Match Validation** (a runnable sketch follows this list):
   - Exactly 1 match: return the ID
   - 0 matches: error with a "not found" message
   - >1 matches: error listing all matches for disambiguation
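
A minimal, self-contained sketch of the single-match rule in step 3. The helper name and the local `node` struct are illustrative stand-ins, not the repository's types; the error strings mirror the examples later in this document:

```go
package main

import (
	"fmt"
	"strings"
)

// node is a stand-in for the generated v1.Node message.
type node struct {
	ID   uint64
	Name string
}

// pickSingleNode enforces the "exactly one match" rule described above.
func pickSingleNode(spec string, matches []node) (uint64, error) {
	switch len(matches) {
	case 0:
		return 0, fmt.Errorf("no node found matching %q", spec)
	case 1:
		return matches[0].ID, nil
	default:
		descriptions := make([]string, 0, len(matches))
		for _, n := range matches {
			descriptions = append(descriptions, fmt.Sprintf("ID=%d name=%s", n.ID, n.Name))
		}
		return 0, fmt.Errorf("multiple nodes found matching %q: %s",
			spec, strings.Join(descriptions, ", "))
	}
}

func main() {
	_, err := pickSingleNode("laptop", []node{{1, "laptop-alice"}, {2, "laptop-bob"}})
	fmt.Println(err)
	// multiple nodes found matching "laptop": ID=1 name=laptop-alice, ID=2 name=laptop-bob
}
```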

### Enhanced Proto Definitions

```protobuf
message ListNodesRequest {
  string user = 1;                  // existing
  uint64 id = 2;                    // new: filter by ID
  string name = 3;                  // new: filter by hostname
  string hostname = 4;              // new: alias for name
  repeated string ip_addresses = 5; // new: filter by IPs
}
```
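
On the CLI side these proto fields surface as `Id`, `Name`, `IpAddresses`, and `User` on the generated request struct, as used in the `nodes list` command further down in this diff. A hedged sketch of how a single `--node` argument might be mapped onto them; the function name is illustrative, and the IP check here uses `net/netip` rather than the repository's own helper:

```go
package main

import (
	"fmt"
	"net/netip"
	"strconv"

	v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
)

// buildListNodesRequest maps a raw --node argument onto the new filter fields.
// Illustrative helper; it mirrors the filtering logic shown in the nodes list
// command later in this diff.
func buildListNodesRequest(nodeFlag string) *v1.ListNodesRequest {
	request := &v1.ListNodesRequest{}

	if id, err := strconv.ParseUint(nodeFlag, 10, 64); err == nil && id > 0 {
		request.Id = id // numeric -> filter by ID
	} else if _, err := netip.ParseAddr(nodeFlag); err == nil {
		request.IpAddresses = []string{nodeFlag} // valid IP -> filter by address
	} else {
		request.Name = nodeFlag // anything else -> name/hostname filter
	}

	return request
}

func main() {
	fmt.Println(buildListNodesRequest("100.64.0.1")) // e.g. ip_addresses:"100.64.0.1"
	fmt.Println(buildListNodesRequest("my-laptop"))  // e.g. name:"my-laptop"
}
```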

### Future Enhancements

- **Fuzzy Matching**: Partial name matching with confirmation
- **Recently Used**: Cache recently accessed nodes/users
- **Tab Completion**: Shell completion for names/hostnames
- **Bulk Operations**: Multi-select with pattern matching

## Migration Path for Users

### Now Available (Current Release)

```bash
# Old way (still works, shows deprecation warning)
headscale nodes expire --identifier 123

# New way with smart lookup:
headscale nodes expire --node 123             # by ID
headscale nodes expire --node "my-laptop"     # by hostname
headscale nodes expire --node "100.64.0.1"    # by Tailscale IP
headscale nodes expire --node "192.168.1.100" # by real IP

# User operations:
headscale users destroy --user 123                 # by ID
headscale users destroy --user "alice"             # by username
headscale users destroy --user "alice@company.com" # by email

# Enhanced list commands with filtering:
headscale nodes list --node "laptop"        # filter nodes by name
headscale nodes list --ip "100.64.0.1"      # filter nodes by IP
headscale nodes list --user "alice"         # filter nodes by user
headscale users list --user "alice"         # smart lookup user
headscale users list --email "@company.com" # filter by email domain
headscale users list --name "alice"         # filter by exact name

# Error handling examples:
headscale nodes expire --node "laptop"
# Error: multiple nodes found matching 'laptop': ID=1 name=laptop-alice, ID=2 name=laptop-bob

headscale nodes expire --node "nonexistent"
# Error: no node found matching 'nonexistent'
```

## Command Structure Overview

```
headscale [global-flags] <command> [command-flags] <subcommand> [subcommand-flags] [args]

Global Flags:
  --config, -c    config file path
  --output, -o    output format (json, yaml, json-line)
  --force         disable prompts

Commands:
├── serve
├── version
├── config-test
├── dump-config (alias: dumpConfig)
├── mockoidc
├── generate/
│   └── private-key
├── nodes/
│   ├── list (--user, --tags, --columns)
│   ├── register (--user, --key)
│   ├── list-routes (--node)
│   ├── expire (--node)
│   ├── rename (--node) <new-name>
│   ├── delete (--node)
│   ├── move (--node, --user)
│   ├── tag (--node, --tags)
│   ├── approve-routes (--node, --routes)
│   └── backfill-ips (alias: backfillips)
├── users/
│   ├── create <name> (--display-name, --email, --picture-url)
│   ├── list (--user, --name, --email, --columns)
│   ├── destroy (--user|--name|--identifier)
│   └── rename (--user|--name|--identifier, --new-name)
├── apikeys/
│   ├── list
│   ├── create (--expiration)
│   ├── expire (--prefix)
│   └── delete (--prefix)
├── preauthkeys/
│   ├── list (--user)
│   ├── create (--user, --reusable, --ephemeral, --expiration, --tags)
│   └── expire (--user) <key>
├── policy/
│   ├── get
│   ├── set (--file)
│   └── check (--file)
└── debug/
    └── create-node (--name, --user, --key, --route)
```

## Deprecated Flags

All deprecated flags continue to work but show warnings:

- `--identifier` → use `--node` (for node commands) or `--user` (for user commands)
- `--namespace` → use `--user` (already implemented)
- `dumpConfig` → use `dump-config`
- `backfillips` → use `backfill-ips`

## Error Handling

Improved error messages provide clear guidance:

```
Error: node specifier must be a numeric ID (smart lookup by name/hostname/IP not yet implemented)
Error: --node flag is required
Error: --user flag is required
```
@@ -13,14 +13,18 @@ RUN apt-get update \
    && apt-get clean
RUN mkdir -p /var/run/headscale

# Install delve debugger
RUN go install github.com/go-delve/delve/cmd/dlv@latest

COPY go.mod go.sum /go/src/headscale/
RUN go mod download

COPY . .

RUN CGO_ENABLED=0 GOOS=linux go install -a ./cmd/headscale && test -e /go/bin/headscale
# Build debug binary with debug symbols for delve
RUN CGO_ENABLED=0 GOOS=linux go build -gcflags="all=-N -l" -o /go/bin/headscale ./cmd/headscale

# Need to reset the entrypoint or everything will run as a busybox script
ENTRYPOINT []
EXPOSE 8080/tcp
CMD ["headscale"]
EXPOSE 8080/tcp 40000/tcp
CMD ["/go/bin/dlv", "--listen=0.0.0.0:40000", "--headless=true", "--api-version=2", "--accept-multiclient", "exec", "/go/bin/headscale", "--"]
Makefile (7 changed lines)
@@ -87,10 +87,9 @@ lint-proto: check-deps $(PROTO_SOURCES)

# Code generation
.PHONY: generate
generate: check-deps $(PROTO_SOURCES)
	@echo "Generating code from Protocol Buffers..."
	rm -rf gen
	buf generate proto
generate: check-deps
	@echo "Generating code..."
	go generate ./...

# Clean targets
.PHONY: clean
@@ -1,7 +1,6 @@
package cli

import (
	"context"
	"fmt"
	"strconv"
	"time"
@@ -15,6 +14,11 @@ import (
	"google.golang.org/protobuf/types/known/timestamppb"
)

const (
	// 90 days.
	DefaultAPIKeyExpiry = "90d"
)

func init() {
	rootCmd.AddCommand(apiKeysCmd)
	apiKeysCmd.AddCommand(listAPIKeys)
@@ -39,80 +43,75 @@ func init() {
|
||||
|
||||
var apiKeysCmd = &cobra.Command{
|
||||
Use: "apikeys",
|
||||
Short: "Handle the API keys in Headscale",
|
||||
Short: "Handle the Api keys in Headscale",
|
||||
Aliases: []string{"apikey", "api"},
|
||||
}
|
||||
|
||||
var listAPIKeys = &cobra.Command{
|
||||
Use: "list",
|
||||
Short: "List the API keys for Headscale",
|
||||
Short: "List the Api keys for headscale",
|
||||
Aliases: []string{"ls", "show"},
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
output := GetOutputFlag(cmd)
|
||||
output, _ := cmd.Flags().GetString("output")
|
||||
|
||||
err := WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error {
|
||||
request := &v1.ListApiKeysRequest{}
|
||||
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
|
||||
defer cancel()
|
||||
defer conn.Close()
|
||||
|
||||
response, err := client.ListApiKeys(ctx, request)
|
||||
if err != nil {
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf("Error getting the list of keys: %s", err),
|
||||
output,
|
||||
)
|
||||
return err
|
||||
}
|
||||
request := &v1.ListApiKeysRequest{}
|
||||
|
||||
if output != "" {
|
||||
SuccessOutput(response.GetApiKeys(), "", output)
|
||||
return nil
|
||||
}
|
||||
|
||||
tableData := pterm.TableData{
|
||||
{"ID", "Prefix", "Expiration", "Created"},
|
||||
}
|
||||
for _, key := range response.GetApiKeys() {
|
||||
expiration := "-"
|
||||
|
||||
if key.GetExpiration() != nil {
|
||||
expiration = ColourTime(key.GetExpiration().AsTime())
|
||||
}
|
||||
|
||||
tableData = append(tableData, []string{
|
||||
strconv.FormatUint(key.GetId(), util.Base10),
|
||||
key.GetPrefix(),
|
||||
expiration,
|
||||
key.GetCreatedAt().AsTime().Format(HeadscaleDateTimeFormat),
|
||||
})
|
||||
|
||||
}
|
||||
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
|
||||
if err != nil {
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf("Failed to render pterm table: %s", err),
|
||||
output,
|
||||
)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
response, err := client.ListApiKeys(ctx, request)
|
||||
if err != nil {
|
||||
return
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf("Error getting the list of keys: %s", err),
|
||||
output,
|
||||
)
|
||||
}
|
||||
|
||||
if output != "" {
|
||||
SuccessOutput(response.GetApiKeys(), "", output)
|
||||
}
|
||||
|
||||
tableData := pterm.TableData{
|
||||
{"ID", "Prefix", "Expiration", "Created"},
|
||||
}
|
||||
for _, key := range response.GetApiKeys() {
|
||||
expiration := "-"
|
||||
|
||||
if key.GetExpiration() != nil {
|
||||
expiration = ColourTime(key.GetExpiration().AsTime())
|
||||
}
|
||||
|
||||
tableData = append(tableData, []string{
|
||||
strconv.FormatUint(key.GetId(), util.Base10),
|
||||
key.GetPrefix(),
|
||||
expiration,
|
||||
key.GetCreatedAt().AsTime().Format(HeadscaleDateTimeFormat),
|
||||
})
|
||||
|
||||
}
|
||||
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
|
||||
if err != nil {
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf("Failed to render pterm table: %s", err),
|
||||
output,
|
||||
)
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
var createAPIKeyCmd = &cobra.Command{
|
||||
Use: "create",
|
||||
Short: "Create a new API key",
|
||||
Short: "Creates a new Api key",
|
||||
Long: `
|
||||
Creates a new Api key, the Api key is only visible on creation
|
||||
and cannot be retrieved again.
|
||||
If you loose a key, create a new one and revoke (expire) the old one.`,
|
||||
Aliases: []string{"c", "new"},
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
output := GetOutputFlag(cmd)
|
||||
output, _ := cmd.Flags().GetString("output")
|
||||
|
||||
request := &v1.CreateApiKeyRequest{}
|
||||
|
||||
@@ -125,101 +124,99 @@ If you loose a key, create a new one and revoke (expire) the old one.`,
|
||||
fmt.Sprintf("Could not parse duration: %s\n", err),
|
||||
output,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
expiration := time.Now().UTC().Add(time.Duration(duration))
|
||||
|
||||
request.Expiration = timestamppb.New(expiration)
|
||||
|
||||
err = WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error {
|
||||
response, err := client.CreateApiKey(ctx, request)
|
||||
if err != nil {
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf("Cannot create Api Key: %s\n", err),
|
||||
output,
|
||||
)
|
||||
return err
|
||||
}
|
||||
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
|
||||
defer cancel()
|
||||
defer conn.Close()
|
||||
|
||||
SuccessOutput(response.GetApiKey(), response.GetApiKey(), output)
|
||||
return nil
|
||||
})
|
||||
response, err := client.CreateApiKey(ctx, request)
|
||||
if err != nil {
|
||||
return
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf("Cannot create Api Key: %s\n", err),
|
||||
output,
|
||||
)
|
||||
}
|
||||
|
||||
SuccessOutput(response.GetApiKey(), response.GetApiKey(), output)
|
||||
},
|
||||
}
|
||||
|
||||
var expireAPIKeyCmd = &cobra.Command{
|
||||
Use: "expire",
|
||||
Short: "Expire an API key",
|
||||
Short: "Expire an ApiKey",
|
||||
Aliases: []string{"revoke", "exp", "e"},
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
output := GetOutputFlag(cmd)
|
||||
output, _ := cmd.Flags().GetString("output")
|
||||
|
||||
prefix, err := cmd.Flags().GetString("prefix")
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Error getting prefix from CLI flag: %s", err), output)
|
||||
return
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf("Error getting prefix from CLI flag: %s", err),
|
||||
output,
|
||||
)
|
||||
}
|
||||
|
||||
err = WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error {
|
||||
request := &v1.ExpireApiKeyRequest{
|
||||
Prefix: prefix,
|
||||
}
|
||||
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
|
||||
defer cancel()
|
||||
defer conn.Close()
|
||||
|
||||
response, err := client.ExpireApiKey(ctx, request)
|
||||
if err != nil {
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf("Cannot expire Api Key: %s\n", err),
|
||||
output,
|
||||
)
|
||||
return err
|
||||
}
|
||||
request := &v1.ExpireApiKeyRequest{
|
||||
Prefix: prefix,
|
||||
}
|
||||
|
||||
SuccessOutput(response, "Key expired", output)
|
||||
return nil
|
||||
})
|
||||
response, err := client.ExpireApiKey(ctx, request)
|
||||
if err != nil {
|
||||
return
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf("Cannot expire Api Key: %s\n", err),
|
||||
output,
|
||||
)
|
||||
}
|
||||
|
||||
SuccessOutput(response, "Key expired", output)
|
||||
},
|
||||
}
|
||||
|
||||
var deleteAPIKeyCmd = &cobra.Command{
|
||||
Use: "delete",
|
||||
Short: "Delete an API key",
|
||||
Short: "Delete an ApiKey",
|
||||
Aliases: []string{"remove", "del"},
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
output := GetOutputFlag(cmd)
|
||||
output, _ := cmd.Flags().GetString("output")
|
||||
|
||||
prefix, err := cmd.Flags().GetString("prefix")
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Error getting prefix from CLI flag: %s", err), output)
|
||||
return
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf("Error getting prefix from CLI flag: %s", err),
|
||||
output,
|
||||
)
|
||||
}
|
||||
|
||||
err = WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error {
|
||||
request := &v1.DeleteApiKeyRequest{
|
||||
Prefix: prefix,
|
||||
}
|
||||
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
|
||||
defer cancel()
|
||||
defer conn.Close()
|
||||
|
||||
response, err := client.DeleteApiKey(ctx, request)
|
||||
if err != nil {
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf("Cannot delete Api Key: %s\n", err),
|
||||
output,
|
||||
)
|
||||
return err
|
||||
}
|
||||
request := &v1.DeleteApiKeyRequest{
|
||||
Prefix: prefix,
|
||||
}
|
||||
|
||||
SuccessOutput(response, "Key deleted", output)
|
||||
return nil
|
||||
})
|
||||
response, err := client.DeleteApiKey(ctx, request)
|
||||
if err != nil {
|
||||
return
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf("Cannot delete Api Key: %s\n", err),
|
||||
output,
|
||||
)
|
||||
}
|
||||
|
||||
SuccessOutput(response, "Key deleted", output)
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1,16 +0,0 @@
package cli

import (
	"context"

	v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
)

// WithClient handles gRPC client setup and cleanup, and calls fn with the client and context.
func WithClient(fn func(context.Context, v1.HeadscaleServiceClient) error) error {
	ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
	defer cancel()
	defer conn.Close()

	return fn(ctx, client)
}
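
For reference, a self-contained sketch of the setup/cleanup pattern this helper enabled, using a stand-in client type instead of the generated `v1.HeadscaleServiceClient` and the unexported `newHeadscaleCLIWithConfig` helper:

```go
package main

import (
	"context"
	"fmt"
)

// fakeClient stands in for the generated gRPC client; the real helper passes
// the client created by newHeadscaleCLIWithConfig().
type fakeClient struct{}

// withClient mirrors the deleted helper: set up the client, guarantee cleanup,
// and hand both the context and the client to the callback.
func withClient(fn func(context.Context, fakeClient) error) error {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel() // in the real helper: cancel() and conn.Close()

	return fn(ctx, fakeClient{})
}

func main() {
	err := withClient(func(ctx context.Context, client fakeClient) error {
		// Command body goes here; returning an error propagates it to the caller.
		fmt.Println("doing RPCs with", client)
		return nil
	})
	if err != nil {
		fmt.Println("command failed:", err)
	}
}
```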
@@ -11,8 +11,8 @@ func init() {

var configTestCmd = &cobra.Command{
	Use:   "configtest",
	Short: "Test the configuration",
	Long:  "Run a test of the configuration and exit",
	Short: "Test the configuration.",
	Long:  "Run a test of the configuration and exit.",
	Run: func(cmd *cobra.Command, args []string) {
		_, err := newHeadscaleServerWithConfig()
		if err != nil {
@@ -1,46 +0,0 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestConfigTestCommand(t *testing.T) {
|
||||
// Test that the configtest command exists and is properly configured
|
||||
assert.NotNil(t, configTestCmd)
|
||||
assert.Equal(t, "configtest", configTestCmd.Use)
|
||||
assert.Equal(t, "Test the configuration.", configTestCmd.Short)
|
||||
assert.Equal(t, "Run a test of the configuration and exit.", configTestCmd.Long)
|
||||
assert.NotNil(t, configTestCmd.Run)
|
||||
}
|
||||
|
||||
func TestConfigTestCommandInRootCommand(t *testing.T) {
|
||||
// Test that configtest is available as a subcommand of root
|
||||
cmd, _, err := rootCmd.Find([]string{"configtest"})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "configtest", cmd.Name())
|
||||
assert.Equal(t, configTestCmd, cmd)
|
||||
}
|
||||
|
||||
func TestConfigTestCommandHelp(t *testing.T) {
|
||||
// Test that the command has proper help text
|
||||
assert.NotEmpty(t, configTestCmd.Short)
|
||||
assert.NotEmpty(t, configTestCmd.Long)
|
||||
assert.Contains(t, configTestCmd.Short, "configuration")
|
||||
assert.Contains(t, configTestCmd.Long, "test")
|
||||
assert.Contains(t, configTestCmd.Long, "configuration")
|
||||
}
|
||||
|
||||
// Note: We can't easily test the actual execution of configtest because:
|
||||
// 1. It depends on configuration files being present
|
||||
// 2. It calls log.Fatal() which would exit the test process
|
||||
// 3. It tries to initialize a full Headscale server
|
||||
//
|
||||
// In a real refactor, we would:
|
||||
// 1. Extract the configuration validation logic to a testable function
|
||||
// 2. Return errors instead of calling log.Fatal()
|
||||
// 3. Accept configuration as a parameter instead of loading from global state
|
||||
//
|
||||
// For now, we test the command structure and that it's properly wired up.
|
||||
@@ -1,7 +1,6 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||
@@ -15,6 +14,11 @@ const (
|
||||
errPreAuthKeyMalformed = Error("key is malformed. expected 64 hex characters with `nodekey` prefix")
|
||||
)
|
||||
|
||||
// Error is used to compare errors as per https://dave.cheney.net/2016/04/07/constant-errors
|
||||
type Error string
|
||||
|
||||
func (e Error) Error() string { return string(e) }
|
||||
|
||||
func init() {
|
||||
rootCmd.AddCommand(debugCmd)
|
||||
|
||||
@@ -25,6 +29,11 @@ func init() {
|
||||
}
|
||||
createNodeCmd.Flags().StringP("user", "u", "", "User")
|
||||
|
||||
createNodeCmd.Flags().StringP("namespace", "n", "", "User")
|
||||
createNodeNamespaceFlag := createNodeCmd.Flags().Lookup("namespace")
|
||||
createNodeNamespaceFlag.Deprecated = deprecateNamespaceMessage
|
||||
createNodeNamespaceFlag.Hidden = true
|
||||
|
||||
err = createNodeCmd.MarkFlagRequired("user")
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("")
|
||||
@@ -50,14 +59,17 @@ var createNodeCmd = &cobra.Command{
|
||||
Use: "create-node",
|
||||
Short: "Create a node that can be registered with `nodes register <>` command",
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
output := GetOutputFlag(cmd)
|
||||
output, _ := cmd.Flags().GetString("output")
|
||||
|
||||
user, err := cmd.Flags().GetString("user")
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output)
|
||||
return
|
||||
}
|
||||
|
||||
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
|
||||
defer cancel()
|
||||
defer conn.Close()
|
||||
|
||||
name, err := cmd.Flags().GetString("name")
|
||||
if err != nil {
|
||||
ErrorOutput(
|
||||
@@ -65,7 +77,6 @@ var createNodeCmd = &cobra.Command{
|
||||
fmt.Sprintf("Error getting node from flag: %s", err),
|
||||
output,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
registrationID, err := cmd.Flags().GetString("key")
|
||||
@@ -75,7 +86,6 @@ var createNodeCmd = &cobra.Command{
|
||||
fmt.Sprintf("Error getting key from flag: %s", err),
|
||||
output,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
_, err = types.RegistrationIDFromString(registrationID)
|
||||
@@ -85,7 +95,6 @@ var createNodeCmd = &cobra.Command{
|
||||
fmt.Sprintf("Failed to parse machine key from flag: %s", err),
|
||||
output,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
routes, err := cmd.Flags().GetStringSlice("route")
|
||||
@@ -95,32 +104,24 @@ var createNodeCmd = &cobra.Command{
|
||||
fmt.Sprintf("Error getting routes from flag: %s", err),
|
||||
output,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
err = WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error {
|
||||
request := &v1.DebugCreateNodeRequest{
|
||||
Key: registrationID,
|
||||
Name: name,
|
||||
User: user,
|
||||
Routes: routes,
|
||||
}
|
||||
request := &v1.DebugCreateNodeRequest{
|
||||
Key: registrationID,
|
||||
Name: name,
|
||||
User: user,
|
||||
Routes: routes,
|
||||
}
|
||||
|
||||
response, err := client.DebugCreateNode(ctx, request)
|
||||
if err != nil {
|
||||
ErrorOutput(
|
||||
err,
|
||||
"Cannot create node: "+status.Convert(err).Message(),
|
||||
output,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
SuccessOutput(response.GetNode(), "Node created", output)
|
||||
return nil
|
||||
})
|
||||
response, err := client.DebugCreateNode(ctx, request)
|
||||
if err != nil {
|
||||
return
|
||||
ErrorOutput(
|
||||
err,
|
||||
"Cannot create node: "+status.Convert(err).Message(),
|
||||
output,
|
||||
)
|
||||
}
|
||||
|
||||
SuccessOutput(response.GetNode(), "Node created", output)
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1,144 +0,0 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestDebugCommand(t *testing.T) {
|
||||
// Test that the debug command exists and is properly configured
|
||||
assert.NotNil(t, debugCmd)
|
||||
assert.Equal(t, "debug", debugCmd.Use)
|
||||
assert.Equal(t, "debug and testing commands", debugCmd.Short)
|
||||
assert.Equal(t, "debug contains extra commands used for debugging and testing headscale", debugCmd.Long)
|
||||
}
|
||||
|
||||
func TestDebugCommandInRootCommand(t *testing.T) {
|
||||
// Test that debug is available as a subcommand of root
|
||||
cmd, _, err := rootCmd.Find([]string{"debug"})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "debug", cmd.Name())
|
||||
assert.Equal(t, debugCmd, cmd)
|
||||
}
|
||||
|
||||
func TestCreateNodeCommand(t *testing.T) {
|
||||
// Test that the create-node command exists and is properly configured
|
||||
assert.NotNil(t, createNodeCmd)
|
||||
assert.Equal(t, "create-node", createNodeCmd.Use)
|
||||
assert.Equal(t, "Create a node that can be registered with `nodes register <>` command", createNodeCmd.Short)
|
||||
assert.NotNil(t, createNodeCmd.Run)
|
||||
}
|
||||
|
||||
func TestCreateNodeCommandInDebugCommand(t *testing.T) {
|
||||
// Test that create-node is available as a subcommand of debug
|
||||
cmd, _, err := rootCmd.Find([]string{"debug", "create-node"})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "create-node", cmd.Name())
|
||||
assert.Equal(t, createNodeCmd, cmd)
|
||||
}
|
||||
|
||||
func TestCreateNodeCommandFlags(t *testing.T) {
|
||||
// Test that create-node has the required flags
|
||||
|
||||
// Test name flag
|
||||
nameFlag := createNodeCmd.Flags().Lookup("name")
|
||||
assert.NotNil(t, nameFlag)
|
||||
assert.Equal(t, "", nameFlag.Shorthand) // No shorthand for name
|
||||
assert.Equal(t, "", nameFlag.DefValue)
|
||||
|
||||
// Test user flag
|
||||
userFlag := createNodeCmd.Flags().Lookup("user")
|
||||
assert.NotNil(t, userFlag)
|
||||
assert.Equal(t, "u", userFlag.Shorthand)
|
||||
|
||||
// Test key flag
|
||||
keyFlag := createNodeCmd.Flags().Lookup("key")
|
||||
assert.NotNil(t, keyFlag)
|
||||
assert.Equal(t, "k", keyFlag.Shorthand)
|
||||
|
||||
// Test route flag
|
||||
routeFlag := createNodeCmd.Flags().Lookup("route")
|
||||
assert.NotNil(t, routeFlag)
|
||||
assert.Equal(t, "r", routeFlag.Shorthand)
|
||||
|
||||
}
|
||||
|
||||
func TestCreateNodeCommandRequiredFlags(t *testing.T) {
|
||||
// Test that required flags are marked as required
|
||||
// We can't easily test the actual requirement enforcement without executing the command
|
||||
// But we can test that the flags exist and have the expected properties
|
||||
|
||||
// These flags should be required based on the init() function
|
||||
requiredFlags := []string{"name", "user", "key"}
|
||||
|
||||
for _, flagName := range requiredFlags {
|
||||
flag := createNodeCmd.Flags().Lookup(flagName)
|
||||
assert.NotNil(t, flag, "Required flag %s should exist", flagName)
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorType(t *testing.T) {
|
||||
// Test the Error type implementation
|
||||
err := errPreAuthKeyMalformed
|
||||
assert.Equal(t, "key is malformed. expected 64 hex characters with `nodekey` prefix", err.Error())
|
||||
assert.Equal(t, "key is malformed. expected 64 hex characters with `nodekey` prefix", string(err))
|
||||
|
||||
// Test that it implements the error interface
|
||||
var genericErr error = err
|
||||
assert.Equal(t, "key is malformed. expected 64 hex characters with `nodekey` prefix", genericErr.Error())
|
||||
}
|
||||
|
||||
func TestErrorConstants(t *testing.T) {
|
||||
// Test that error constants are defined properly
|
||||
assert.Equal(t, Error("key is malformed. expected 64 hex characters with `nodekey` prefix"), errPreAuthKeyMalformed)
|
||||
}
|
||||
|
||||
func TestDebugCommandStructure(t *testing.T) {
|
||||
// Test that debug has create-node as a subcommand
|
||||
found := false
|
||||
for _, subcmd := range debugCmd.Commands() {
|
||||
if subcmd.Name() == "create-node" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found, "create-node should be a subcommand of debug")
|
||||
}
|
||||
|
||||
func TestCreateNodeCommandHelp(t *testing.T) {
|
||||
// Test that the command has proper help text
|
||||
assert.NotEmpty(t, createNodeCmd.Short)
|
||||
assert.Contains(t, createNodeCmd.Short, "Create a node")
|
||||
assert.Contains(t, createNodeCmd.Short, "nodes register")
|
||||
}
|
||||
|
||||
func TestCreateNodeCommandFlagDescriptions(t *testing.T) {
|
||||
// Test that flags have appropriate usage descriptions
|
||||
nameFlag := createNodeCmd.Flags().Lookup("name")
|
||||
assert.Equal(t, "Name", nameFlag.Usage)
|
||||
|
||||
userFlag := createNodeCmd.Flags().Lookup("user")
|
||||
assert.Equal(t, "User", userFlag.Usage)
|
||||
|
||||
keyFlag := createNodeCmd.Flags().Lookup("key")
|
||||
assert.Equal(t, "Key", keyFlag.Usage)
|
||||
|
||||
routeFlag := createNodeCmd.Flags().Lookup("route")
|
||||
assert.Contains(t, routeFlag.Usage, "routes to advertise")
|
||||
|
||||
}
|
||||
|
||||
// Note: We can't easily test the actual execution of create-node because:
|
||||
// 1. It depends on gRPC client configuration
|
||||
// 2. It calls SuccessOutput/ErrorOutput which exit the process
|
||||
// 3. It requires valid registration keys and user setup
|
||||
//
|
||||
// In a real refactor, we would:
|
||||
// 1. Extract the business logic to testable functions
|
||||
// 2. Use dependency injection for the gRPC client
|
||||
// 3. Return errors instead of calling ErrorOutput/SuccessOutput
|
||||
// 4. Add validation functions that can be tested independently
|
||||
//
|
||||
// For now, we test the command structure and flag configuration.
|
||||
@@ -12,10 +12,9 @@ func init() {
}

var dumpConfigCmd = &cobra.Command{
	Use:     "dump-config",
	Short:   "Dump current config to /etc/headscale/config.dump.yaml, integration test only",
	Aliases: []string{"dumpConfig"},
	Hidden:  true,
	Use:    "dumpConfig",
	Short:  "dump current config to /etc/headscale/config.dump.yaml, integration test only",
	Hidden: true,
	Args: func(cmd *cobra.Command, args []string) error {
		return nil
	},
@@ -22,7 +22,7 @@ var generatePrivateKeyCmd = &cobra.Command{
	Use:   "private-key",
	Short: "Generate a private key for the headscale server",
	Run: func(cmd *cobra.Command, args []string) {
		output := GetOutputFlag(cmd)
		output, _ := cmd.Flags().GetString("output")
		machineKey := key.NewMachine()

		machineKeyStr, err := machineKey.MarshalText()
@@ -1,230 +0,0 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
func TestGenerateCommand(t *testing.T) {
|
||||
// Test that the generate command exists and shows help
|
||||
cmd := &cobra.Command{
|
||||
Use: "headscale",
|
||||
Short: "headscale - a Tailscale control server",
|
||||
}
|
||||
|
||||
cmd.AddCommand(generateCmd)
|
||||
|
||||
out := new(bytes.Buffer)
|
||||
cmd.SetOut(out)
|
||||
cmd.SetErr(out)
|
||||
cmd.SetArgs([]string{"generate", "--help"})
|
||||
|
||||
err := cmd.Execute()
|
||||
require.NoError(t, err)
|
||||
|
||||
outStr := out.String()
|
||||
assert.Contains(t, outStr, "Generate commands")
|
||||
assert.Contains(t, outStr, "private-key")
|
||||
assert.Contains(t, outStr, "Aliases:")
|
||||
assert.Contains(t, outStr, "gen")
|
||||
}
|
||||
|
||||
func TestGenerateCommandAlias(t *testing.T) {
|
||||
// Test that the "gen" alias works
|
||||
cmd := &cobra.Command{
|
||||
Use: "headscale",
|
||||
Short: "headscale - a Tailscale control server",
|
||||
}
|
||||
|
||||
cmd.AddCommand(generateCmd)
|
||||
|
||||
out := new(bytes.Buffer)
|
||||
cmd.SetOut(out)
|
||||
cmd.SetErr(out)
|
||||
cmd.SetArgs([]string{"gen", "--help"})
|
||||
|
||||
err := cmd.Execute()
|
||||
require.NoError(t, err)
|
||||
|
||||
outStr := out.String()
|
||||
assert.Contains(t, outStr, "Generate commands")
|
||||
}
|
||||
|
||||
func TestGeneratePrivateKeyCommand(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
expectJSON bool
|
||||
expectYAML bool
|
||||
}{
|
||||
{
|
||||
name: "default output",
|
||||
args: []string{"generate", "private-key"},
|
||||
expectJSON: false,
|
||||
expectYAML: false,
|
||||
},
|
||||
{
|
||||
name: "json output",
|
||||
args: []string{"generate", "private-key", "--output", "json"},
|
||||
expectJSON: true,
|
||||
expectYAML: false,
|
||||
},
|
||||
{
|
||||
name: "yaml output",
|
||||
args: []string{"generate", "private-key", "--output", "yaml"},
|
||||
expectJSON: false,
|
||||
expectYAML: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Note: This command calls SuccessOutput which exits the process
|
||||
// We can't test the actual execution easily without mocking
|
||||
// Instead, we test the command structure and that it exists
|
||||
|
||||
cmd := &cobra.Command{
|
||||
Use: "headscale",
|
||||
Short: "headscale - a Tailscale control server",
|
||||
}
|
||||
|
||||
cmd.AddCommand(generateCmd)
|
||||
cmd.PersistentFlags().StringP("output", "o", "", "Output format")
|
||||
|
||||
// Test that the command exists and can be found
|
||||
privateKeyCmd, _, err := cmd.Find([]string{"generate", "private-key"})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "private-key", privateKeyCmd.Name())
|
||||
assert.Equal(t, "Generate a private key for the headscale server", privateKeyCmd.Short)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeneratePrivateKeyHelp(t *testing.T) {
|
||||
cmd := &cobra.Command{
|
||||
Use: "headscale",
|
||||
Short: "headscale - a Tailscale control server",
|
||||
}
|
||||
|
||||
cmd.AddCommand(generateCmd)
|
||||
|
||||
out := new(bytes.Buffer)
|
||||
cmd.SetOut(out)
|
||||
cmd.SetErr(out)
|
||||
cmd.SetArgs([]string{"generate", "private-key", "--help"})
|
||||
|
||||
err := cmd.Execute()
|
||||
require.NoError(t, err)
|
||||
|
||||
outStr := out.String()
|
||||
assert.Contains(t, outStr, "Generate a private key for the headscale server")
|
||||
assert.Contains(t, outStr, "Usage:")
|
||||
}
|
||||
|
||||
// Test the key generation logic in isolation (without SuccessOutput/ErrorOutput)
|
||||
func TestPrivateKeyGeneration(t *testing.T) {
|
||||
// We can't easily test the full command because it calls SuccessOutput which exits
|
||||
// But we can test that the key generation produces valid output format
|
||||
|
||||
// This is testing the core logic that would be in the command
|
||||
// In a real refactor, we'd extract this to a testable function
|
||||
|
||||
// For now, we can test that the command structure is correct
|
||||
assert.NotNil(t, generatePrivateKeyCmd)
|
||||
assert.Equal(t, "private-key", generatePrivateKeyCmd.Use)
|
||||
assert.Equal(t, "Generate a private key for the headscale server", generatePrivateKeyCmd.Short)
|
||||
assert.NotNil(t, generatePrivateKeyCmd.Run)
|
||||
}
|
||||
|
||||
func TestGenerateCommandStructure(t *testing.T) {
|
||||
// Test the command hierarchy
|
||||
assert.Equal(t, "generate", generateCmd.Use)
|
||||
assert.Equal(t, "Generate commands", generateCmd.Short)
|
||||
assert.Contains(t, generateCmd.Aliases, "gen")
|
||||
|
||||
// Test that private-key is a subcommand
|
||||
found := false
|
||||
for _, subcmd := range generateCmd.Commands() {
|
||||
if subcmd.Name() == "private-key" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found, "private-key should be a subcommand of generate")
|
||||
}
|
||||
|
||||
// Helper function to test output formats (would be used if we refactored the command)
|
||||
func validatePrivateKeyOutput(t *testing.T, output string, format string) {
|
||||
switch format {
|
||||
case "json":
|
||||
var result map[string]interface{}
|
||||
err := json.Unmarshal([]byte(output), &result)
|
||||
require.NoError(t, err, "Output should be valid JSON")
|
||||
|
||||
privateKey, exists := result["private_key"]
|
||||
require.True(t, exists, "JSON should contain private_key field")
|
||||
|
||||
keyStr, ok := privateKey.(string)
|
||||
require.True(t, ok, "private_key should be a string")
|
||||
require.NotEmpty(t, keyStr, "private_key should not be empty")
|
||||
|
||||
// Basic validation that it looks like a machine key
|
||||
assert.True(t, strings.HasPrefix(keyStr, "mkey:"), "Machine key should start with mkey:")
|
||||
|
||||
case "yaml":
|
||||
var result map[string]interface{}
|
||||
err := yaml.Unmarshal([]byte(output), &result)
|
||||
require.NoError(t, err, "Output should be valid YAML")
|
||||
|
||||
privateKey, exists := result["private_key"]
|
||||
require.True(t, exists, "YAML should contain private_key field")
|
||||
|
||||
keyStr, ok := privateKey.(string)
|
||||
require.True(t, ok, "private_key should be a string")
|
||||
require.NotEmpty(t, keyStr, "private_key should not be empty")
|
||||
|
||||
assert.True(t, strings.HasPrefix(keyStr, "mkey:"), "Machine key should start with mkey:")
|
||||
|
||||
default:
|
||||
// Default format should just be the key itself
|
||||
assert.True(t, strings.HasPrefix(output, "mkey:"), "Default output should be the machine key")
|
||||
assert.NotContains(t, output, "{", "Default output should not contain JSON")
|
||||
assert.NotContains(t, output, "private_key:", "Default output should not contain YAML structure")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrivateKeyOutputFormats(t *testing.T) {
|
||||
// Test cases for different output formats
|
||||
// These test the validation logic we would use after refactoring
|
||||
|
||||
tests := []struct {
|
||||
format string
|
||||
sample string
|
||||
}{
|
||||
{
|
||||
format: "json",
|
||||
sample: `{"private_key": "mkey:abcd1234567890abcd1234567890abcd1234567890abcd1234567890abcd1234"}`,
|
||||
},
|
||||
{
|
||||
format: "yaml",
|
||||
sample: "private_key: mkey:abcd1234567890abcd1234567890abcd1234567890abcd1234567890abcd1234\n",
|
||||
},
|
||||
{
|
||||
format: "",
|
||||
sample: "mkey:abcd1234567890abcd1234567890abcd1234567890abcd1234567890abcd1234",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run("format_"+tt.format, func(t *testing.T) {
|
||||
validatePrivateKeyOutput(t, tt.sample, tt.format)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -15,11 +15,6 @@ import (
	"github.com/spf13/cobra"
)

// Error is used to compare errors as per https://dave.cheney.net/2016/04/07/constant-errors
type Error string

func (e Error) Error() string { return string(e) }

const (
	errMockOidcClientIDNotDefined     = Error("MOCKOIDC_CLIENT_ID not defined")
	errMockOidcClientSecretNotDefined = Error("MOCKOIDC_CLIENT_SECRET not defined")
@@ -1,7 +1,6 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/netip"
|
||||
@@ -22,23 +21,25 @@ import (
|
||||
|
||||
func init() {
|
||||
rootCmd.AddCommand(nodeCmd)
|
||||
// User filtering
|
||||
listNodesCmd.Flags().StringP("user", "u", "", "Filter by user")
|
||||
// Node filtering
|
||||
listNodesCmd.Flags().StringP("node", "", "", "Filter by node (ID, name, hostname, or IP)")
|
||||
listNodesCmd.Flags().Uint64P("id", "", 0, "Filter by node ID")
|
||||
listNodesCmd.Flags().StringP("name", "", "", "Filter by node hostname")
|
||||
listNodesCmd.Flags().StringP("ip", "", "", "Filter by node IP address")
|
||||
// Display options
|
||||
listNodesCmd.Flags().BoolP("tags", "t", false, "Show tags")
|
||||
listNodesCmd.Flags().String("columns", "", "Comma-separated list of columns to display")
|
||||
|
||||
listNodesCmd.Flags().StringP("namespace", "n", "", "User")
|
||||
listNodesNamespaceFlag := listNodesCmd.Flags().Lookup("namespace")
|
||||
listNodesNamespaceFlag.Deprecated = deprecateNamespaceMessage
|
||||
listNodesNamespaceFlag.Hidden = true
|
||||
nodeCmd.AddCommand(listNodesCmd)
|
||||
|
||||
listNodeRoutesCmd.Flags().StringP("node", "n", "", "Node identifier (ID, name, hostname, or IP)")
|
||||
listNodeRoutesCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)")
|
||||
nodeCmd.AddCommand(listNodeRoutesCmd)
|
||||
|
||||
registerNodeCmd.Flags().StringP("user", "u", "", "User")
|
||||
|
||||
registerNodeCmd.Flags().StringP("namespace", "n", "", "User")
|
||||
registerNodeNamespaceFlag := registerNodeCmd.Flags().Lookup("namespace")
|
||||
registerNodeNamespaceFlag.Deprecated = deprecateNamespaceMessage
|
||||
registerNodeNamespaceFlag.Hidden = true
|
||||
|
||||
err := registerNodeCmd.MarkFlagRequired("user")
|
||||
if err != nil {
|
||||
log.Fatal(err.Error())
|
||||
@@ -50,43 +51,54 @@ func init() {
|
||||
}
|
||||
nodeCmd.AddCommand(registerNodeCmd)
|
||||
|
||||
expireNodeCmd.Flags().StringP("node", "n", "", "Node identifier (ID, name, hostname, or IP)")
|
||||
expireNodeCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)")
|
||||
err = expireNodeCmd.MarkFlagRequired("identifier")
|
||||
if err != nil {
|
||||
log.Fatal(err.Error())
|
||||
}
|
||||
nodeCmd.AddCommand(expireNodeCmd)
|
||||
|
||||
renameNodeCmd.Flags().StringP("node", "n", "", "Node identifier (ID, name, hostname, or IP)")
|
||||
renameNodeCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)")
|
||||
err = renameNodeCmd.MarkFlagRequired("identifier")
|
||||
if err != nil {
|
||||
log.Fatal(err.Error())
|
||||
}
|
||||
nodeCmd.AddCommand(renameNodeCmd)
|
||||
|
||||
deleteNodeCmd.Flags().StringP("node", "n", "", "Node identifier (ID, name, hostname, or IP)")
|
||||
deleteNodeCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)")
|
||||
err = deleteNodeCmd.MarkFlagRequired("identifier")
|
||||
if err != nil {
|
||||
log.Fatal(err.Error())
|
||||
}
|
||||
nodeCmd.AddCommand(deleteNodeCmd)
|
||||
|
||||
moveNodeCmd.Flags().StringP("node", "n", "", "Node identifier (ID, name, hostname, or IP)")
|
||||
moveNodeCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)")
|
||||
|
||||
err = moveNodeCmd.MarkFlagRequired("identifier")
|
||||
if err != nil {
|
||||
log.Fatal(err.Error())
|
||||
}
|
||||
|
||||
moveNodeCmd.Flags().StringP("user", "u", "", "New user (ID, name, or email)")
|
||||
moveNodeCmd.Flags().String("name", "", "New username")
|
||||
moveNodeCmd.Flags().Uint64P("user", "u", 0, "New user")
|
||||
|
||||
// One of --user or --name is required (checked in GetUserIdentifier)
|
||||
moveNodeCmd.Flags().StringP("namespace", "n", "", "User")
|
||||
moveNodeNamespaceFlag := moveNodeCmd.Flags().Lookup("namespace")
|
||||
moveNodeNamespaceFlag.Deprecated = deprecateNamespaceMessage
|
||||
moveNodeNamespaceFlag.Hidden = true
|
||||
|
||||
err = moveNodeCmd.MarkFlagRequired("user")
|
||||
if err != nil {
|
||||
log.Fatal(err.Error())
|
||||
}
|
||||
nodeCmd.AddCommand(moveNodeCmd)
|
||||
|
||||
tagCmd.Flags().StringP("node", "n", "", "Node identifier (ID, name, hostname, or IP)")
|
||||
tagCmd.MarkFlagRequired("node")
|
||||
tagCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)")
|
||||
tagCmd.MarkFlagRequired("identifier")
|
||||
tagCmd.Flags().StringSliceP("tags", "t", []string{}, "List of tags to add to the node")
|
||||
nodeCmd.AddCommand(tagCmd)
|
||||
|
||||
approveRoutesCmd.Flags().StringP("node", "n", "", "Node identifier (ID, name, hostname, or IP)")
|
||||
approveRoutesCmd.MarkFlagRequired("node")
|
||||
approveRoutesCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)")
|
||||
approveRoutesCmd.MarkFlagRequired("identifier")
|
||||
approveRoutesCmd.Flags().StringSliceP("routes", "r", []string{}, `List of routes that will be approved (comma-separated, e.g. "10.0.0.0/8,192.168.0.0/24" or empty string to remove all approved routes)`)
|
||||
nodeCmd.AddCommand(approveRoutesCmd)
|
||||
|
||||
@@ -103,13 +115,16 @@ var registerNodeCmd = &cobra.Command{
|
||||
Use: "register",
|
||||
Short: "Registers a node to your network",
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
output := GetOutputFlag(cmd)
|
||||
output, _ := cmd.Flags().GetString("output")
|
||||
user, err := cmd.Flags().GetString("user")
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output)
|
||||
return
|
||||
}
|
||||
|
||||
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
|
||||
defer cancel()
|
||||
defer conn.Close()
|
||||
|
||||
registrationID, err := cmd.Flags().GetString("key")
|
||||
if err != nil {
|
||||
ErrorOutput(
|
||||
@@ -117,36 +132,28 @@ var registerNodeCmd = &cobra.Command{
|
||||
fmt.Sprintf("Error getting node key from flag: %s", err),
|
||||
output,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
err = WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error {
|
||||
request := &v1.RegisterNodeRequest{
|
||||
Key: registrationID,
|
||||
User: user,
|
||||
}
|
||||
request := &v1.RegisterNodeRequest{
|
||||
Key: registrationID,
|
||||
User: user,
|
||||
}
|
||||
|
||||
response, err := client.RegisterNode(ctx, request)
|
||||
if err != nil {
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf(
|
||||
"Cannot register node: %s\n",
|
||||
status.Convert(err).Message(),
|
||||
),
|
||||
output,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
SuccessOutput(
|
||||
response.GetNode(),
|
||||
fmt.Sprintf("Node %s registered", response.GetNode().GetGivenName()), output)
|
||||
return nil
|
||||
})
|
||||
response, err := client.RegisterNode(ctx, request)
|
||||
if err != nil {
|
||||
return
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf(
|
||||
"Cannot register node: %s\n",
|
||||
status.Convert(err).Message(),
|
||||
),
|
||||
output,
|
||||
)
|
||||
}
|
||||
|
||||
SuccessOutput(
|
||||
response.GetNode(),
|
||||
fmt.Sprintf("Node %s registered", response.GetNode().GetGivenName()), output)
|
||||
},
|
||||
}
|
||||
|
||||
@@ -155,79 +162,49 @@ var listNodesCmd = &cobra.Command{
|
||||
Short: "List nodes",
|
||||
Aliases: []string{"ls", "show"},
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
output := GetOutputFlag(cmd)
|
||||
output, _ := cmd.Flags().GetString("output")
|
||||
user, err := cmd.Flags().GetString("user")
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output)
|
||||
}
|
||||
showTags, err := cmd.Flags().GetBool("tags")
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Error getting tags flag: %s", err), output)
|
||||
return
|
||||
}
|
||||
|
||||
err = WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error {
|
||||
request := &v1.ListNodesRequest{}
|
||||
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
|
||||
defer cancel()
|
||||
defer conn.Close()
|
||||
|
||||
// Handle user filtering (existing functionality)
|
||||
if user, _ := cmd.Flags().GetString("user"); user != "" {
|
||||
request.User = user
|
||||
}
|
||||
request := &v1.ListNodesRequest{
|
||||
User: user,
|
||||
}
|
||||
|
||||
// Handle node filtering (new functionality)
|
||||
if nodeFlag, _ := cmd.Flags().GetString("node"); nodeFlag != "" {
|
||||
// Use smart lookup to determine filter type
|
||||
if id, err := strconv.ParseUint(nodeFlag, 10, 64); err == nil && id > 0 {
|
||||
request.Id = id
|
||||
} else if isIPAddress(nodeFlag) {
|
||||
request.IpAddresses = []string{nodeFlag}
|
||||
} else {
|
||||
request.Name = nodeFlag
|
||||
}
|
||||
} else {
|
||||
// Check specific filter flags
|
||||
if id, _ := cmd.Flags().GetUint64("id"); id > 0 {
|
||||
request.Id = id
|
||||
} else if name, _ := cmd.Flags().GetString("name"); name != "" {
|
||||
request.Name = name
|
||||
} else if ip, _ := cmd.Flags().GetString("ip"); ip != "" {
|
||||
request.IpAddresses = []string{ip}
|
||||
}
|
||||
}
|
||||
|
||||
response, err := client.ListNodes(ctx, request)
|
||||
if err != nil {
|
||||
ErrorOutput(
|
||||
err,
|
||||
"Cannot get nodes: "+status.Convert(err).Message(),
|
||||
output,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
if output != "" {
|
||||
SuccessOutput(response.GetNodes(), "", output)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get user for table display (if filtering by user)
|
||||
userFilter := request.User
|
||||
tableData, err := nodesToPtables(userFilter, showTags, response.GetNodes())
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output)
|
||||
return err
|
||||
}
|
||||
|
||||
tableData = FilterTableColumns(cmd, tableData)
|
||||
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
|
||||
if err != nil {
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf("Failed to render pterm table: %s", err),
|
||||
output,
|
||||
)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
response, err := client.ListNodes(ctx, request)
|
||||
if err != nil {
|
||||
return
|
||||
ErrorOutput(
|
||||
err,
|
||||
"Cannot get nodes: "+status.Convert(err).Message(),
|
||||
output,
|
||||
)
|
||||
}
|
||||
|
||||
if output != "" {
|
||||
SuccessOutput(response.GetNodes(), "", output)
|
||||
}
|
||||
|
||||
tableData, err := nodesToPtables(user, showTags, response.GetNodes())
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output)
|
||||
}
|
||||
|
||||
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
|
||||
if err != nil {
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf("Failed to render pterm table: %s", err),
|
||||
output,
|
||||
)
|
||||
}
|
||||
},
|
||||
}
|
||||
@@ -237,68 +214,63 @@ var listNodeRoutesCmd = &cobra.Command{
|
||||
Short: "List routes available on nodes",
|
||||
Aliases: []string{"lsr", "routes"},
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
output := GetOutputFlag(cmd)
|
||||
identifier, err := GetNodeIdentifier(cmd)
|
||||
output, _ := cmd.Flags().GetString("output")
|
||||
identifier, err := cmd.Flags().GetUint64("identifier")
|
||||
if err != nil {
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf("Error getting node identifier: %s", err),
|
||||
fmt.Sprintf("Error converting ID to integer: %s", err),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
err = WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error {
|
||||
request := &v1.ListNodesRequest{}
|
||||
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
|
||||
defer cancel()
|
||||
defer conn.Close()
|
||||
|
||||
response, err := client.ListNodes(ctx, request)
|
||||
if err != nil {
|
||||
ErrorOutput(
|
||||
err,
|
||||
"Cannot get nodes: "+status.Convert(err).Message(),
|
||||
output,
|
||||
)
|
||||
return err
|
||||
}
|
||||
request := &v1.ListNodesRequest{}
|
||||
|
||||
if output != "" {
|
||||
SuccessOutput(response.GetNodes(), "", output)
|
||||
return nil
|
||||
}
|
||||
response, err := client.ListNodes(ctx, request)
|
||||
if err != nil {
|
||||
ErrorOutput(
|
||||
err,
|
||||
"Cannot get nodes: "+status.Convert(err).Message(),
|
||||
output,
|
||||
)
|
||||
}
|
||||
|
||||
nodes := response.GetNodes()
|
||||
if identifier != 0 {
|
||||
for _, node := range response.GetNodes() {
|
||||
if node.GetId() == identifier {
|
||||
nodes = []*v1.Node{node}
|
||||
break
|
||||
}
|
||||
if output != "" {
|
||||
SuccessOutput(response.GetNodes(), "", output)
|
||||
}
|
||||
|
||||
nodes := response.GetNodes()
|
||||
if identifier != 0 {
|
||||
for _, node := range response.GetNodes() {
|
||||
if node.GetId() == identifier {
|
||||
nodes = []*v1.Node{node}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
nodes = lo.Filter(nodes, func(n *v1.Node, _ int) bool {
|
||||
return (n.GetSubnetRoutes() != nil && len(n.GetSubnetRoutes()) > 0) || (n.GetApprovedRoutes() != nil && len(n.GetApprovedRoutes()) > 0) || (n.GetAvailableRoutes() != nil && len(n.GetAvailableRoutes()) > 0)
|
||||
})
|
||||
|
||||
tableData, err := nodeRoutesToPtables(nodes)
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output)
|
||||
return err
|
||||
}
|
||||
|
||||
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
|
||||
if err != nil {
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf("Failed to render pterm table: %s", err),
|
||||
output,
|
||||
)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
nodes = lo.Filter(nodes, func(n *v1.Node, _ int) bool {
|
||||
return (n.GetSubnetRoutes() != nil && len(n.GetSubnetRoutes()) > 0) || (n.GetApprovedRoutes() != nil && len(n.GetApprovedRoutes()) > 0) || (n.GetAvailableRoutes() != nil && len(n.GetAvailableRoutes()) > 0)
|
||||
})
|
||||
|
||||
tableData, err := nodeRoutesToPtables(nodes)
|
||||
if err != nil {
|
||||
return
|
||||
ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output)
|
||||
}
|
||||
|
||||
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
|
||||
if err != nil {
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf("Failed to render pterm table: %s", err),
|
||||
output,
|
||||
)
|
||||
}
|
||||
},
|
||||
}
|
||||
@@ -309,42 +281,42 @@ var expireNodeCmd = &cobra.Command{
|
||||
Long: "Expiring a node will keep the node in the database and force it to reauthenticate.",
|
||||
Aliases: []string{"logout", "exp", "e"},
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
output := GetOutputFlag(cmd)
|
||||
output, _ := cmd.Flags().GetString("output")
|
||||
|
||||
identifier, err := GetNodeIdentifier(cmd)
|
||||
identifier, err := cmd.Flags().GetUint64("identifier")
|
||||
if err != nil {
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf("Error getting node identifier: %s", err),
|
||||
fmt.Sprintf("Error converting ID to integer: %s", err),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
err = WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error {
|
||||
request := &v1.ExpireNodeRequest{
|
||||
NodeId: identifier,
|
||||
}
|
||||
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
|
||||
defer cancel()
|
||||
defer conn.Close()
|
||||
|
||||
response, err := client.ExpireNode(ctx, request)
|
||||
if err != nil {
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf(
|
||||
"Cannot expire node: %s\n",
|
||||
status.Convert(err).Message(),
|
||||
),
|
||||
output,
|
||||
)
|
||||
return err
|
||||
}
|
||||
request := &v1.ExpireNodeRequest{
|
||||
NodeId: identifier,
|
||||
}
|
||||
|
||||
SuccessOutput(response.GetNode(), "Node expired", output)
|
||||
return nil
|
||||
})
|
||||
response, err := client.ExpireNode(ctx, request)
|
||||
if err != nil {
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf(
|
||||
"Cannot expire node: %s\n",
|
||||
status.Convert(err).Message(),
|
||||
),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
SuccessOutput(response.GetNode(), "Node expired", output)
|
||||
},
|
||||
}
|
||||
|
||||
@@ -352,48 +324,47 @@ var renameNodeCmd = &cobra.Command{
|
||||
Use: "rename NEW_NAME",
|
||||
Short: "Renames a node in your network",
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
output := GetOutputFlag(cmd)
|
||||
output, _ := cmd.Flags().GetString("output")
|
||||
|
||||
identifier, err := GetNodeIdentifier(cmd)
|
||||
identifier, err := cmd.Flags().GetUint64("identifier")
|
||||
if err != nil {
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf("Error getting node identifier: %s", err),
|
||||
fmt.Sprintf("Error converting ID to integer: %s", err),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
|
||||
defer cancel()
|
||||
defer conn.Close()
|
||||
|
||||
newName := ""
|
||||
if len(args) > 0 {
|
||||
newName = args[0]
|
||||
}
|
||||
request := &v1.RenameNodeRequest{
|
||||
NodeId: identifier,
|
||||
NewName: newName,
|
||||
}
|
||||
|
||||
err = WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error {
|
||||
request := &v1.RenameNodeRequest{
|
||||
NodeId: identifier,
|
||||
NewName: newName,
|
||||
}
|
||||
|
||||
response, err := client.RenameNode(ctx, request)
|
||||
if err != nil {
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf(
|
||||
"Cannot rename node: %s\n",
|
||||
status.Convert(err).Message(),
|
||||
),
|
||||
output,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
SuccessOutput(response.GetNode(), "Node renamed", output)
|
||||
return nil
|
||||
})
|
||||
response, err := client.RenameNode(ctx, request)
|
||||
if err != nil {
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf(
|
||||
"Cannot rename node: %s\n",
|
||||
status.Convert(err).Message(),
|
||||
),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
SuccessOutput(response.GetNode(), "Node renamed", output)
|
||||
},
|
||||
}
|
||||
|
||||
@@ -402,47 +373,49 @@ var deleteNodeCmd = &cobra.Command{
|
||||
Short: "Delete a node",
|
||||
Aliases: []string{"del"},
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
output := GetOutputFlag(cmd)
|
||||
output, _ := cmd.Flags().GetString("output")
|
||||
|
||||
identifier, err := GetNodeIdentifier(cmd)
|
||||
identifier, err := cmd.Flags().GetUint64("identifier")
|
||||
if err != nil {
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf("Error getting node identifier: %s", err),
|
||||
fmt.Sprintf("Error converting ID to integer: %s", err),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
var nodeName string
|
||||
err = WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error {
|
||||
getRequest := &v1.GetNodeRequest{
|
||||
NodeId: identifier,
|
||||
}
|
||||
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
|
||||
defer cancel()
|
||||
defer conn.Close()
|
||||
|
||||
getResponse, err := client.GetNode(ctx, getRequest)
|
||||
if err != nil {
|
||||
ErrorOutput(
|
||||
err,
|
||||
"Error getting node node: "+status.Convert(err).Message(),
|
||||
output,
|
||||
)
|
||||
return err
|
||||
}
|
||||
nodeName = getResponse.GetNode().GetName()
|
||||
return nil
|
||||
})
|
||||
getRequest := &v1.GetNodeRequest{
|
||||
NodeId: identifier,
|
||||
}
|
||||
|
||||
getResponse, err := client.GetNode(ctx, getRequest)
|
||||
if err != nil {
|
||||
ErrorOutput(
|
||||
err,
|
||||
"Error getting node node: "+status.Convert(err).Message(),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
deleteRequest := &v1.DeleteNodeRequest{
|
||||
NodeId: identifier,
|
||||
}
|
||||
|
||||
confirm := false
|
||||
force, _ := cmd.Flags().GetBool("force")
|
||||
if !force {
|
||||
prompt := &survey.Confirm{
|
||||
Message: fmt.Sprintf(
|
||||
"Do you want to remove the node %s?",
|
||||
nodeName,
|
||||
getResponse.GetNode().GetName(),
|
||||
),
|
||||
}
|
||||
err = survey.AskOne(prompt, &confirm)
|
||||
@@ -452,34 +425,26 @@ var deleteNodeCmd = &cobra.Command{
|
||||
}
|
||||
|
||||
if confirm || force {
|
||||
err = WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error {
|
||||
deleteRequest := &v1.DeleteNodeRequest{
|
||||
NodeId: identifier,
|
||||
}
|
||||
response, err := client.DeleteNode(ctx, deleteRequest)
|
||||
if output != "" {
|
||||
SuccessOutput(response, "", output)
|
||||
|
||||
response, err := client.DeleteNode(ctx, deleteRequest)
|
||||
if output != "" {
|
||||
SuccessOutput(response, "", output)
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
ErrorOutput(
|
||||
err,
|
||||
"Error deleting node: "+status.Convert(err).Message(),
|
||||
output,
|
||||
)
|
||||
return err
|
||||
}
|
||||
SuccessOutput(
|
||||
map[string]string{"Result": "Node deleted"},
|
||||
"Node deleted",
|
||||
output,
|
||||
)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
ErrorOutput(
|
||||
err,
|
||||
"Error deleting node: "+status.Convert(err).Message(),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
SuccessOutput(
|
||||
map[string]string{"Result": "Node deleted"},
|
||||
"Node deleted",
|
||||
output,
|
||||
)
|
||||
} else {
|
||||
SuccessOutput(map[string]string{"Result": "Node not deleted"}, "Node not deleted", output)
|
||||
}
|
||||
@@ -491,71 +456,72 @@ var moveNodeCmd = &cobra.Command{
|
||||
Short: "Move node to another user",
|
||||
Aliases: []string{"mv"},
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
output := GetOutputFlag(cmd)
|
||||
output, _ := cmd.Flags().GetString("output")
|
||||
|
||||
identifier, err := GetNodeIdentifier(cmd)
|
||||
identifier, err := cmd.Flags().GetUint64("identifier")
|
||||
if err != nil {
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf("Error getting node identifier: %s", err),
|
||||
fmt.Sprintf("Error converting ID to integer: %s", err),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
userID, err := GetUserIdentifier(cmd)
|
||||
user, err := cmd.Flags().GetUint64("user")
|
||||
if err != nil {
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf("Error getting user: %s", err),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
err = WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error {
|
||||
getRequest := &v1.GetNodeRequest{
|
||||
NodeId: identifier,
|
||||
}
|
||||
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
|
||||
defer cancel()
|
||||
defer conn.Close()
|
||||
|
||||
_, err := client.GetNode(ctx, getRequest)
|
||||
if err != nil {
|
||||
ErrorOutput(
|
||||
err,
|
||||
"Error getting node: "+status.Convert(err).Message(),
|
||||
output,
|
||||
)
|
||||
return err
|
||||
}
|
||||
getRequest := &v1.GetNodeRequest{
|
||||
NodeId: identifier,
|
||||
}
|
||||
|
||||
moveRequest := &v1.MoveNodeRequest{
|
||||
NodeId: identifier,
|
||||
User: userID,
|
||||
}
|
||||
|
||||
moveResponse, err := client.MoveNode(ctx, moveRequest)
|
||||
if err != nil {
|
||||
ErrorOutput(
|
||||
err,
|
||||
"Error moving node: "+status.Convert(err).Message(),
|
||||
output,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
SuccessOutput(moveResponse.GetNode(), "Node moved to another user", output)
|
||||
return nil
|
||||
})
|
||||
_, err = client.GetNode(ctx, getRequest)
|
||||
if err != nil {
|
||||
ErrorOutput(
|
||||
err,
|
||||
"Error getting node: "+status.Convert(err).Message(),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
moveRequest := &v1.MoveNodeRequest{
|
||||
NodeId: identifier,
|
||||
User: user,
|
||||
}
|
||||
|
||||
moveResponse, err := client.MoveNode(ctx, moveRequest)
|
||||
if err != nil {
|
||||
ErrorOutput(
|
||||
err,
|
||||
"Error moving node: "+status.Convert(err).Message(),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
SuccessOutput(moveResponse.GetNode(), "Node moved to another user", output)
|
||||
},
|
||||
}
|
||||
|
||||
var backfillNodeIPsCmd = &cobra.Command{
|
||||
Use: "backfill-ips",
|
||||
Short: "Backfill IPs missing from nodes",
|
||||
Aliases: []string{"backfillips"},
|
||||
Use: "backfillips",
|
||||
Short: "Backfill IPs missing from nodes",
|
||||
Long: `
|
||||
Backfill IPs can be used to add/remove IPs from nodes
|
||||
based on the current configuration of Headscale.
|
||||
@@ -570,7 +536,7 @@ it can be run to remove the IPs that should no longer
|
||||
be assigned to nodes.`,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
var err error
|
||||
output := GetOutputFlag(cmd)
|
||||
output, _ := cmd.Flags().GetString("output")
|
||||
|
||||
confirm := false
|
||||
prompt := &survey.Confirm{
|
||||
@@ -581,23 +547,22 @@ be assigned to nodes.`,
|
||||
return
|
||||
}
|
||||
if confirm {
|
||||
err = WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error {
|
||||
changes, err := client.BackfillNodeIPs(ctx, &v1.BackfillNodeIPsRequest{Confirmed: confirm})
|
||||
if err != nil {
|
||||
ErrorOutput(
|
||||
err,
|
||||
"Error backfilling IPs: "+status.Convert(err).Message(),
|
||||
output,
|
||||
)
|
||||
return err
|
||||
}
|
||||
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
|
||||
defer cancel()
|
||||
defer conn.Close()
|
||||
|
||||
SuccessOutput(changes, "Node IPs backfilled successfully", output)
|
||||
return nil
|
||||
})
|
||||
changes, err := client.BackfillNodeIPs(ctx, &v1.BackfillNodeIPsRequest{Confirmed: confirm})
|
||||
if err != nil {
|
||||
ErrorOutput(
|
||||
err,
|
||||
"Error backfilling IPs: "+status.Convert(err).Message(),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
SuccessOutput(changes, "Node IPs backfilled successfully", output)
|
||||
}
|
||||
},
|
||||
}
|
||||
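// The backfill-ips command above boils down to a single gRPC call. A minimal sketch of
// driving the same RPC through the WithClient helper; backfillAllNodeIPs is hypothetical,
// and Confirmed mirrors the interactive confirmation the CLI collects before sending the
// request:

func backfillAllNodeIPs() error {
	return WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error {
		changes, err := client.BackfillNodeIPs(ctx, &v1.BackfillNodeIPsRequest{Confirmed: true})
		if err != nil {
			return err
		}

		// changes describes which node IP assignments were added or removed.
		fmt.Println(changes)

		return nil
	})
}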
@@ -640,14 +605,14 @@ func nodesToPtables(
	var lastSeenTime string
	if node.GetLastSeen() != nil {
		lastSeen = node.GetLastSeen().AsTime()
		lastSeenTime = lastSeen.Format(HeadscaleDateTimeFormat)
		lastSeenTime = lastSeen.Format("2006-01-02 15:04:05")
	}

	var expiry time.Time
	var expiryTime string
	if node.GetExpiry() != nil {
		expiry = node.GetExpiry().AsTime()
		expiryTime = expiry.Format(HeadscaleDateTimeFormat)
		expiryTime = expiry.Format("2006-01-02 15:04:05")
	} else {
		expiryTime = "N/A"
	}
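// Both layouts in the hunk above are Go reference-time layouts: the HeadscaleDateTimeFormat
// constant holds the same "2006-01-02 15:04:05" pattern as the literal it sits next to, so
// the rendered timestamps do not change. A small illustration (exampleTimestamp is
// illustrative only, not part of the diff):

func exampleTimestamp() string {
	// Go describes formats by example, using the reference time Mon Jan 2 15:04:05 MST 2006,
	// so this returns "2024-03-01 09:30:00".
	ts := time.Date(2024, time.March, 1, 9, 30, 0, 0, time.UTC)
	return ts.Format(HeadscaleDateTimeFormat)
}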
@@ -780,16 +745,20 @@ var tagCmd = &cobra.Command{
|
||||
Short: "Manage the tags of a node",
|
||||
Aliases: []string{"tags", "t"},
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
output := GetOutputFlag(cmd)
|
||||
output, _ := cmd.Flags().GetString("output")
|
||||
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
|
||||
defer cancel()
|
||||
defer conn.Close()
|
||||
|
||||
// retrieve flags from CLI
|
||||
identifier, err := GetNodeIdentifier(cmd)
|
||||
identifier, err := cmd.Flags().GetUint64("identifier")
|
||||
if err != nil {
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf("Error getting node identifier: %s", err),
|
||||
fmt.Sprintf("Error converting ID to integer: %s", err),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
tagsToSet, err := cmd.Flags().GetStringSlice("tags")
|
||||
@@ -799,37 +768,33 @@ var tagCmd = &cobra.Command{
|
||||
fmt.Sprintf("Error retrieving list of tags to add to node, %v", err),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
err = WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error {
|
||||
// Sending tags to node
|
||||
request := &v1.SetTagsRequest{
|
||||
NodeId: identifier,
|
||||
Tags: tagsToSet,
|
||||
}
|
||||
resp, err := client.SetTags(ctx, request)
|
||||
if err != nil {
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf("Error while sending tags to headscale: %s", err),
|
||||
output,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
if resp != nil {
|
||||
SuccessOutput(
|
||||
resp.GetNode(),
|
||||
"Node updated",
|
||||
output,
|
||||
)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
// Sending tags to node
|
||||
request := &v1.SetTagsRequest{
|
||||
NodeId: identifier,
|
||||
Tags: tagsToSet,
|
||||
}
|
||||
resp, err := client.SetTags(ctx, request)
|
||||
if err != nil {
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf("Error while sending tags to headscale: %s", err),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if resp != nil {
|
||||
SuccessOutput(
|
||||
resp.GetNode(),
|
||||
"Node updated",
|
||||
output,
|
||||
)
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
@@ -837,16 +802,20 @@ var approveRoutesCmd = &cobra.Command{
|
||||
Use: "approve-routes",
|
||||
Short: "Manage the approved routes of a node",
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
output := GetOutputFlag(cmd)
|
||||
output, _ := cmd.Flags().GetString("output")
|
||||
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
|
||||
defer cancel()
|
||||
defer conn.Close()
|
||||
|
||||
// retrieve flags from CLI
|
||||
identifier, err := GetNodeIdentifier(cmd)
|
||||
identifier, err := cmd.Flags().GetUint64("identifier")
|
||||
if err != nil {
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf("Error getting node identifier: %s", err),
|
||||
fmt.Sprintf("Error converting ID to integer: %s", err),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
routes, err := cmd.Flags().GetStringSlice("routes")
|
||||
@@ -856,36 +825,32 @@ var approveRoutesCmd = &cobra.Command{
|
||||
fmt.Sprintf("Error retrieving list of routes to add to node, %v", err),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
err = WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error {
|
||||
// Sending routes to node
|
||||
request := &v1.SetApprovedRoutesRequest{
|
||||
NodeId: identifier,
|
||||
Routes: routes,
|
||||
}
|
||||
resp, err := client.SetApprovedRoutes(ctx, request)
|
||||
if err != nil {
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf("Error while sending routes to headscale: %s", err),
|
||||
output,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
if resp != nil {
|
||||
SuccessOutput(
|
||||
resp.GetNode(),
|
||||
"Node updated",
|
||||
output,
|
||||
)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
// Sending routes to node
|
||||
request := &v1.SetApprovedRoutesRequest{
|
||||
NodeId: identifier,
|
||||
Routes: routes,
|
||||
}
|
||||
resp, err := client.SetApprovedRoutes(ctx, request)
|
||||
if err != nil {
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf("Error while sending routes to headscale: %s", err),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if resp != nil {
|
||||
SuccessOutput(
|
||||
resp.GetNode(),
|
||||
"Node updated",
|
||||
output,
|
||||
)
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
@@ -41,26 +40,22 @@ var getPolicy = &cobra.Command{
|
||||
Short: "Print the current ACL Policy",
|
||||
Aliases: []string{"show", "view", "fetch"},
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
output := GetOutputFlag(cmd)
|
||||
output, _ := cmd.Flags().GetString("output")
|
||||
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
|
||||
defer cancel()
|
||||
defer conn.Close()
|
||||
|
||||
err := WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error {
|
||||
request := &v1.GetPolicyRequest{}
|
||||
request := &v1.GetPolicyRequest{}
|
||||
|
||||
response, err := client.GetPolicy(ctx, request)
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Failed loading ACL Policy: %s", err), output)
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO(pallabpain): Maybe print this better?
|
||||
// This does not pass output as we dont support yaml, json or json-line
|
||||
// output for this command. It is HuJSON already.
|
||||
SuccessOutput("", response.GetPolicy(), "")
|
||||
return nil
|
||||
})
|
||||
response, err := client.GetPolicy(ctx, request)
|
||||
if err != nil {
|
||||
return
|
||||
ErrorOutput(err, fmt.Sprintf("Failed loading ACL Policy: %s", err), output)
|
||||
}
|
||||
|
||||
// TODO(pallabpain): Maybe print this better?
|
||||
// This does not pass output as we dont support yaml, json or json-line
|
||||
// output for this command. It is HuJSON already.
|
||||
SuccessOutput("", response.GetPolicy(), "")
|
||||
},
|
||||
}
|
||||
|
||||
@@ -72,36 +67,31 @@ var setPolicy = &cobra.Command{
|
||||
This command only works when the acl.policy_mode is set to "db", and the policy will be stored in the database.`,
|
||||
Aliases: []string{"put", "update"},
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
output := GetOutputFlag(cmd)
|
||||
output, _ := cmd.Flags().GetString("output")
|
||||
policyPath, _ := cmd.Flags().GetString("file")
|
||||
|
||||
f, err := os.Open(policyPath)
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Error opening the policy file: %s", err), output)
|
||||
return
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
policyBytes, err := io.ReadAll(f)
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Error reading the policy file: %s", err), output)
|
||||
return
|
||||
}
|
||||
|
||||
request := &v1.SetPolicyRequest{Policy: string(policyBytes)}
|
||||
|
||||
err = WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error {
|
||||
if _, err := client.SetPolicy(ctx, request); err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Failed to set ACL Policy: %s", err), output)
|
||||
return err
|
||||
}
|
||||
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
|
||||
defer cancel()
|
||||
defer conn.Close()
|
||||
|
||||
SuccessOutput(nil, "Policy updated.", "")
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
if _, err := client.SetPolicy(ctx, request); err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Failed to set ACL Policy: %s", err), output)
|
||||
}
|
||||
|
||||
SuccessOutput(nil, "Policy updated.", "")
|
||||
},
|
||||
}
|
||||
|
||||
@@ -109,26 +99,23 @@ var checkPolicy = &cobra.Command{
|
||||
Use: "check",
|
||||
Short: "Check the Policy file for errors",
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
output := GetOutputFlag(cmd)
|
||||
output, _ := cmd.Flags().GetString("output")
|
||||
policyPath, _ := cmd.Flags().GetString("file")
|
||||
|
||||
f, err := os.Open(policyPath)
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Error opening the policy file: %s", err), output)
|
||||
return
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
policyBytes, err := io.ReadAll(f)
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Error reading the policy file: %s", err), output)
|
||||
return
|
||||
}
|
||||
|
||||
_, err = policy.NewPolicyManager(policyBytes, nil, views.Slice[types.NodeView]{})
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Error parsing the policy file: %s", err), output)
|
||||
return
|
||||
}
|
||||
|
||||
SuccessOutput(nil, "Policy is valid", "")
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -15,10 +14,19 @@ import (
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultPreAuthKeyExpiry = "1h"
|
||||
)
|
||||
|
||||
func init() {
|
||||
rootCmd.AddCommand(preauthkeysCmd)
|
||||
preauthkeysCmd.PersistentFlags().Uint64P("user", "u", 0, "User identifier (ID)")
|
||||
|
||||
preauthkeysCmd.PersistentFlags().StringP("namespace", "n", "", "User")
|
||||
pakNamespaceFlag := preauthkeysCmd.PersistentFlags().Lookup("namespace")
|
||||
pakNamespaceFlag.Deprecated = deprecateNamespaceMessage
|
||||
pakNamespaceFlag.Hidden = true
|
||||
|
||||
err := preauthkeysCmd.MarkPersistentFlagRequired("user")
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("")
|
||||
@@ -47,85 +55,81 @@ var listPreAuthKeys = &cobra.Command{
|
||||
Short: "List the preauthkeys for this user",
|
||||
Aliases: []string{"ls", "show"},
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
output := GetOutputFlag(cmd)
|
||||
output, _ := cmd.Flags().GetString("output")
|
||||
|
||||
user, err := cmd.Flags().GetUint64("user")
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output)
|
||||
}
|
||||
|
||||
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
|
||||
defer cancel()
|
||||
defer conn.Close()
|
||||
|
||||
request := &v1.ListPreAuthKeysRequest{
|
||||
User: user,
|
||||
}
|
||||
|
||||
response, err := client.ListPreAuthKeys(ctx, request)
|
||||
if err != nil {
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf("Error getting the list of keys: %s", err),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
err = WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error {
|
||||
request := &v1.ListPreAuthKeysRequest{
|
||||
User: user,
|
||||
if output != "" {
|
||||
SuccessOutput(response.GetPreAuthKeys(), "", output)
|
||||
}
|
||||
|
||||
tableData := pterm.TableData{
|
||||
{
|
||||
"ID",
|
||||
"Key",
|
||||
"Reusable",
|
||||
"Ephemeral",
|
||||
"Used",
|
||||
"Expiration",
|
||||
"Created",
|
||||
"Tags",
|
||||
},
|
||||
}
|
||||
for _, key := range response.GetPreAuthKeys() {
|
||||
expiration := "-"
|
||||
if key.GetExpiration() != nil {
|
||||
expiration = ColourTime(key.GetExpiration().AsTime())
|
||||
}
|
||||
|
||||
response, err := client.ListPreAuthKeys(ctx, request)
|
||||
if err != nil {
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf("Error getting the list of keys: %s", err),
|
||||
output,
|
||||
)
|
||||
return err
|
||||
aclTags := ""
|
||||
|
||||
for _, tag := range key.GetAclTags() {
|
||||
aclTags += "," + tag
|
||||
}
|
||||
|
||||
if output != "" {
|
||||
SuccessOutput(response.GetPreAuthKeys(), "", output)
|
||||
return nil
|
||||
}
|
||||
aclTags = strings.TrimLeft(aclTags, ",")
|
||||
|
||||
tableData := pterm.TableData{
|
||||
{
|
||||
"ID",
|
||||
"Key",
|
||||
"Reusable",
|
||||
"Ephemeral",
|
||||
"Used",
|
||||
"Expiration",
|
||||
"Created",
|
||||
"Tags",
|
||||
},
|
||||
}
|
||||
for _, key := range response.GetPreAuthKeys() {
|
||||
expiration := "-"
|
||||
if key.GetExpiration() != nil {
|
||||
expiration = ColourTime(key.GetExpiration().AsTime())
|
||||
}
|
||||
tableData = append(tableData, []string{
|
||||
strconv.FormatUint(key.GetId(), 10),
|
||||
key.GetKey(),
|
||||
strconv.FormatBool(key.GetReusable()),
|
||||
strconv.FormatBool(key.GetEphemeral()),
|
||||
strconv.FormatBool(key.GetUsed()),
|
||||
expiration,
|
||||
key.GetCreatedAt().AsTime().Format("2006-01-02 15:04:05"),
|
||||
aclTags,
|
||||
})
|
||||
|
||||
aclTags := ""
|
||||
|
||||
for _, tag := range key.GetAclTags() {
|
||||
aclTags += "," + tag
|
||||
}
|
||||
|
||||
aclTags = strings.TrimLeft(aclTags, ",")
|
||||
|
||||
tableData = append(tableData, []string{
|
||||
strconv.FormatUint(key.GetId(), 10),
|
||||
key.GetKey(),
|
||||
strconv.FormatBool(key.GetReusable()),
|
||||
strconv.FormatBool(key.GetEphemeral()),
|
||||
strconv.FormatBool(key.GetUsed()),
|
||||
expiration,
|
||||
key.GetCreatedAt().AsTime().Format(HeadscaleDateTimeFormat),
|
||||
aclTags,
|
||||
})
|
||||
|
||||
}
|
||||
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
|
||||
if err != nil {
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf("Failed to render pterm table: %s", err),
|
||||
output,
|
||||
)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
|
||||
if err != nil {
|
||||
return
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf("Failed to render pterm table: %s", err),
|
||||
output,
|
||||
)
|
||||
}
|
||||
},
|
||||
}
|
||||
@@ -135,12 +139,11 @@ var createPreAuthKeyCmd = &cobra.Command{
|
||||
Short: "Creates a new preauthkey in the specified user",
|
||||
Aliases: []string{"c", "new"},
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
output := GetOutputFlag(cmd)
|
||||
output, _ := cmd.Flags().GetString("output")
|
||||
|
||||
user, err := cmd.Flags().GetUint64("user")
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output)
|
||||
return
|
||||
}
|
||||
|
||||
reusable, _ := cmd.Flags().GetBool("reusable")
|
||||
@@ -163,7 +166,6 @@ var createPreAuthKeyCmd = &cobra.Command{
|
||||
fmt.Sprintf("Could not parse duration: %s\n", err),
|
||||
output,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
expiration := time.Now().UTC().Add(time.Duration(duration))
|
||||
@@ -174,23 +176,20 @@ var createPreAuthKeyCmd = &cobra.Command{
|
||||
|
||||
request.Expiration = timestamppb.New(expiration)
|
||||
|
||||
err = WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error {
|
||||
response, err := client.CreatePreAuthKey(ctx, request)
|
||||
if err != nil {
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf("Cannot create Pre Auth Key: %s\n", err),
|
||||
output,
|
||||
)
|
||||
return err
|
||||
}
|
||||
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
|
||||
defer cancel()
|
||||
defer conn.Close()
|
||||
|
||||
SuccessOutput(response.GetPreAuthKey(), response.GetPreAuthKey().GetKey(), output)
|
||||
return nil
|
||||
})
|
||||
response, err := client.CreatePreAuthKey(ctx, request)
|
||||
if err != nil {
|
||||
return
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf("Cannot create Pre Auth Key: %s\n", err),
|
||||
output,
|
||||
)
|
||||
}
|
||||
|
||||
SuccessOutput(response.GetPreAuthKey(), response.GetPreAuthKey().GetKey(), output)
|
||||
},
|
||||
}
|
||||
|
||||
@@ -206,34 +205,30 @@ var expirePreAuthKeyCmd = &cobra.Command{
|
||||
return nil
|
||||
},
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
output := GetOutputFlag(cmd)
|
||||
output, _ := cmd.Flags().GetString("output")
|
||||
user, err := cmd.Flags().GetUint64("user")
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output)
|
||||
return
|
||||
}
|
||||
|
||||
err = WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error {
|
||||
request := &v1.ExpirePreAuthKeyRequest{
|
||||
User: user,
|
||||
Key: args[0],
|
||||
}
|
||||
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
|
||||
defer cancel()
|
||||
defer conn.Close()
|
||||
|
||||
response, err := client.ExpirePreAuthKey(ctx, request)
|
||||
if err != nil {
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf("Cannot expire Pre Auth Key: %s\n", err),
|
||||
output,
|
||||
)
|
||||
return err
|
||||
}
|
||||
request := &v1.ExpirePreAuthKeyRequest{
|
||||
User: user,
|
||||
Key: args[0],
|
||||
}
|
||||
|
||||
SuccessOutput(response, "Key expired", output)
|
||||
return nil
|
||||
})
|
||||
response, err := client.ExpirePreAuthKey(ctx, request)
|
||||
if err != nil {
|
||||
return
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf("Cannot expire Pre Auth Key: %s\n", err),
|
||||
output,
|
||||
)
|
||||
}
|
||||
|
||||
SuccessOutput(response, "Key expired", output)
|
||||
},
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
)
|
||||
|
||||
func ColourTime(date time.Time) string {
|
||||
dateStr := date.Format(HeadscaleDateTimeFormat)
|
||||
dateStr := date.Format("2006-01-02 15:04:05")
|
||||
|
||||
if date.After(time.Now()) {
|
||||
dateStr = pterm.LightGreen(dateStr)
|
||||
|
||||
@@ -14,6 +14,10 @@ import (
|
||||
"github.com/tcnksm/go-latest"
|
||||
)
|
||||
|
||||
const (
|
||||
deprecateNamespaceMessage = "use --user"
|
||||
)
|
||||
|
||||
var cfgFile string = ""
|
||||
|
||||
func init() {
|
||||
|
||||
@@ -1,70 +0,0 @@
package cli

import (
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestServeCommand(t *testing.T) {
	// Test that the serve command exists and is properly configured
	assert.NotNil(t, serveCmd)
	assert.Equal(t, "serve", serveCmd.Use)
	assert.Equal(t, "Launches the headscale server", serveCmd.Short)
	assert.NotNil(t, serveCmd.Run)
	assert.NotNil(t, serveCmd.Args)
}

func TestServeCommandInRootCommand(t *testing.T) {
	// Test that serve is available as a subcommand of root
	cmd, _, err := rootCmd.Find([]string{"serve"})
	require.NoError(t, err)
	assert.Equal(t, "serve", cmd.Name())
	assert.Equal(t, serveCmd, cmd)
}

func TestServeCommandArgs(t *testing.T) {
	// Test that the Args function is defined and accepts any arguments
	// The current implementation always returns nil (accepts any args)
	assert.NotNil(t, serveCmd.Args)

	// Test the args function directly
	err := serveCmd.Args(serveCmd, []string{})
	assert.NoError(t, err, "Args function should accept empty arguments")

	err = serveCmd.Args(serveCmd, []string{"extra", "args"})
	assert.NoError(t, err, "Args function should accept extra arguments")
}

func TestServeCommandHelp(t *testing.T) {
	// Test that the command has proper help text
	assert.NotEmpty(t, serveCmd.Short)
	assert.Contains(t, serveCmd.Short, "server")
	assert.Contains(t, serveCmd.Short, "headscale")
}

func TestServeCommandStructure(t *testing.T) {
	// Test basic command structure
	assert.Equal(t, "serve", serveCmd.Name())
	assert.Equal(t, "Launches the headscale server", serveCmd.Short)

	// Test that it has no subcommands (it's a leaf command)
	subcommands := serveCmd.Commands()
	assert.Empty(t, subcommands, "Serve command should not have subcommands")
}

// Note: We can't easily test the actual execution of serve because:
// 1. It depends on configuration files being present and valid
// 2. It calls log.Fatal() which would exit the test process
// 3. It tries to start an actual HTTP server which would block forever
// 4. It requires database connections and other infrastructure
//
// In a real refactor, we would:
// 1. Extract server initialization logic to a testable function
// 2. Use dependency injection for configuration and dependencies
// 3. Return errors instead of calling log.Fatal()
// 4. Add graceful shutdown capabilities for testing
// 5. Allow server startup to be cancelled via context
//
// For now, we test the command structure and basic properties.
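// The note above sketches how serve could be made testable. A minimal illustration of that
// direction, with hypothetical names (runServe is not part of the current codebase) and
// assuming the *hscontrol.Headscale value returned by newHeadscaleServerWithConfig exposes
// a blocking Serve method: errors are returned instead of calling log.Fatal(), and the
// server loop stops when the caller's context is cancelled.

func runServe(ctx context.Context) error {
	app, err := newHeadscaleServerWithConfig()
	if err != nil {
		return fmt.Errorf("initializing headscale: %w", err)
	}

	errCh := make(chan error, 1)
	go func() { errCh <- app.Serve() }()

	select {
	case <-ctx.Done():
		// A test (or a signal handler) cancelled the context; report that instead of blocking.
		return ctx.Err()
	case err := <-errCh:
		return err
	}
}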
@@ -1,55 +0,0 @@
package cli

import (
	"strings"

	"github.com/pterm/pterm"
	"github.com/spf13/cobra"
)

const (
	HeadscaleDateTimeFormat = "2006-01-02 15:04:05"
	DefaultAPIKeyExpiry     = "90d"
	DefaultPreAuthKeyExpiry = "1h"
)

// FilterTableColumns filters table columns based on --columns flag
func FilterTableColumns(cmd *cobra.Command, tableData pterm.TableData) pterm.TableData {
	columns, _ := cmd.Flags().GetString("columns")
	if columns == "" || len(tableData) == 0 {
		return tableData
	}

	headers := tableData[0]
	wantedColumns := strings.Split(columns, ",")

	// Find column indices
	var indices []int
	for _, wanted := range wantedColumns {
		wanted = strings.TrimSpace(wanted)
		for i, header := range headers {
			if strings.EqualFold(header, wanted) {
				indices = append(indices, i)
				break
			}
		}
	}

	if len(indices) == 0 {
		return tableData
	}

	// Filter all rows
	filtered := make(pterm.TableData, len(tableData))
	for i, row := range tableData {
		newRow := make([]string, len(indices))
		for j, idx := range indices {
			if idx < len(row) {
				newRow[j] = row[idx]
			}
		}
		filtered[i] = newRow
	}

	return filtered
}
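// A short usage sketch for FilterTableColumns as defined above (the table contents are
// illustrative only): with --columns=name,email the helper keeps just those two columns,
// matching header names case-insensitively and ignoring unknown names; if nothing matches,
// the original table is returned unchanged.

func exampleFilterTableColumns(cmd *cobra.Command) error {
	table := pterm.TableData{
		{"ID", "Name", "Email"},
		{"1", "alice", "alice@example.com"},
		{"2", "bob", "bob@example.com"},
	}

	filtered := FilterTableColumns(cmd, table)

	return pterm.DefaultTable.WithHasHeader().WithData(filtered).Render()
}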
@@ -1,12 +1,10 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
survey "github.com/AlecAivazis/survey/v2"
|
||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||
@@ -17,23 +15,25 @@ import (
|
||||
)
|
||||
|
||||
func usernameAndIDFlag(cmd *cobra.Command) {
|
||||
cmd.Flags().StringP("user", "u", "", "User identifier (ID, name, or email)")
|
||||
cmd.Flags().Int64P("identifier", "i", -1, "User identifier (ID)")
|
||||
cmd.Flags().StringP("name", "n", "", "Username")
|
||||
}
|
||||
|
||||
// userIDFromFlag returns the user ID using smart lookup.
|
||||
// If no user is specified, it will exit the program with an error.
|
||||
func userIDFromFlag(cmd *cobra.Command) uint64 {
|
||||
userID, err := GetUserIdentifier(cmd)
|
||||
if err != nil {
|
||||
// usernameAndIDFromFlag returns the username and ID from the flags of the command.
|
||||
// If both are empty, it will exit the program with an error.
|
||||
func usernameAndIDFromFlag(cmd *cobra.Command) (uint64, string) {
|
||||
username, _ := cmd.Flags().GetString("name")
|
||||
identifier, _ := cmd.Flags().GetInt64("identifier")
|
||||
if username == "" && identifier < 0 {
|
||||
err := errors.New("--name or --identifier flag is required")
|
||||
ErrorOutput(
|
||||
err,
|
||||
"Cannot identify user: "+err.Error(),
|
||||
GetOutputFlag(cmd),
|
||||
"Cannot rename user: "+status.Convert(err).Message(),
|
||||
"",
|
||||
)
|
||||
}
|
||||
|
||||
return userID
|
||||
return uint64(identifier), username
|
||||
}
|
||||
|
||||
func init() {
|
||||
@@ -43,18 +43,14 @@ func init() {
|
||||
createUserCmd.Flags().StringP("email", "e", "", "Email")
|
||||
createUserCmd.Flags().StringP("picture-url", "p", "", "Profile picture URL")
|
||||
userCmd.AddCommand(listUsersCmd)
|
||||
// Smart lookup filters - can be used individually or combined
|
||||
listUsersCmd.Flags().StringP("user", "u", "", "Filter by user (ID, name, or email)")
|
||||
listUsersCmd.Flags().Uint64P("id", "", 0, "Filter by user ID")
|
||||
listUsersCmd.Flags().StringP("name", "n", "", "Filter by username")
|
||||
listUsersCmd.Flags().StringP("email", "e", "", "Filter by email address")
|
||||
listUsersCmd.Flags().String("columns", "", "Comma-separated list of columns to display (ID,Name,Username,Email,Created)")
|
||||
usernameAndIDFlag(listUsersCmd)
|
||||
listUsersCmd.Flags().StringP("email", "e", "", "Email")
|
||||
userCmd.AddCommand(destroyUserCmd)
|
||||
usernameAndIDFlag(destroyUserCmd)
|
||||
userCmd.AddCommand(renameUserCmd)
|
||||
usernameAndIDFlag(renameUserCmd)
|
||||
renameUserCmd.Flags().StringP("new-name", "r", "", "New username")
|
||||
renameUserCmd.MarkFlagRequired("new-name")
|
||||
renameNodeCmd.MarkFlagRequired("new-name")
|
||||
}
|
||||
|
||||
var errMissingParameter = errors.New("missing parameters")
|
||||
@@ -77,9 +73,16 @@ var createUserCmd = &cobra.Command{
|
||||
return nil
|
||||
},
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
output := GetOutputFlag(cmd)
|
||||
output, _ := cmd.Flags().GetString("output")
|
||||
|
||||
userName := args[0]
|
||||
|
||||
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
|
||||
defer cancel()
|
||||
defer conn.Close()
|
||||
|
||||
log.Trace().Interface("client", client).Msg("Obtained gRPC client")
|
||||
|
||||
request := &v1.CreateUserRequest{Name: userName}
|
||||
|
||||
if displayName, _ := cmd.Flags().GetString("display-name"); displayName != "" {
|
||||
@@ -100,75 +103,61 @@ var createUserCmd = &cobra.Command{
|
||||
),
|
||||
output,
|
||||
)
|
||||
return
|
||||
}
|
||||
request.PictureUrl = pictureURL
|
||||
}
|
||||
|
||||
err := WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error {
|
||||
log.Trace().Interface("client", client).Msg("Obtained gRPC client")
|
||||
log.Trace().Interface("request", request).Msg("Sending CreateUser request")
|
||||
|
||||
response, err := client.CreateUser(ctx, request)
|
||||
if err != nil {
|
||||
ErrorOutput(
|
||||
err,
|
||||
"Cannot create user: "+status.Convert(err).Message(),
|
||||
output,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
SuccessOutput(response.GetUser(), "User created", output)
|
||||
return nil
|
||||
})
|
||||
log.Trace().Interface("request", request).Msg("Sending CreateUser request")
|
||||
response, err := client.CreateUser(ctx, request)
|
||||
if err != nil {
|
||||
return
|
||||
ErrorOutput(
|
||||
err,
|
||||
"Cannot create user: "+status.Convert(err).Message(),
|
||||
output,
|
||||
)
|
||||
}
|
||||
|
||||
SuccessOutput(response.GetUser(), "User created", output)
|
||||
},
|
||||
}
|
||||
|
||||
var destroyUserCmd = &cobra.Command{
|
||||
Use: "destroy --user USER",
|
||||
Use: "destroy --identifier ID or --name NAME",
|
||||
Short: "Destroys a user",
|
||||
Aliases: []string{"delete"},
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
output := GetOutputFlag(cmd)
|
||||
output, _ := cmd.Flags().GetString("output")
|
||||
|
||||
id := userIDFromFlag(cmd)
|
||||
id, username := usernameAndIDFromFlag(cmd)
|
||||
request := &v1.ListUsersRequest{
|
||||
Id: id,
|
||||
Name: username,
|
||||
Id: id,
|
||||
}
|
||||
|
||||
var user *v1.User
|
||||
err := WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error {
|
||||
users, err := client.ListUsers(ctx, request)
|
||||
if err != nil {
|
||||
ErrorOutput(
|
||||
err,
|
||||
"Error: "+status.Convert(err).Message(),
|
||||
output,
|
||||
)
|
||||
return err
|
||||
}
|
||||
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
|
||||
defer cancel()
|
||||
defer conn.Close()
|
||||
|
||||
if len(users.GetUsers()) != 1 {
|
||||
err := errors.New("Unable to determine user to delete, query returned multiple users, use ID")
|
||||
ErrorOutput(
|
||||
err,
|
||||
"Error: "+status.Convert(err).Message(),
|
||||
output,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
user = users.GetUsers()[0]
|
||||
return nil
|
||||
})
|
||||
users, err := client.ListUsers(ctx, request)
|
||||
if err != nil {
|
||||
return
|
||||
ErrorOutput(
|
||||
err,
|
||||
"Error: "+status.Convert(err).Message(),
|
||||
output,
|
||||
)
|
||||
}
|
||||
|
||||
if len(users.GetUsers()) != 1 {
|
||||
err := errors.New("Unable to determine user to delete, query returned multiple users, use ID")
|
||||
ErrorOutput(
|
||||
err,
|
||||
"Error: "+status.Convert(err).Message(),
|
||||
output,
|
||||
)
|
||||
}
|
||||
|
||||
user := users.GetUsers()[0]
|
||||
|
||||
confirm := false
|
||||
force, _ := cmd.Flags().GetBool("force")
|
||||
if !force {
|
||||
@@ -185,24 +174,17 @@ var destroyUserCmd = &cobra.Command{
|
||||
}
|
||||
|
||||
if confirm || force {
|
||||
err = WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error {
|
||||
request := &v1.DeleteUserRequest{Id: user.GetId()}
|
||||
request := &v1.DeleteUserRequest{Id: user.GetId()}
|
||||
|
||||
response, err := client.DeleteUser(ctx, request)
|
||||
if err != nil {
|
||||
ErrorOutput(
|
||||
err,
|
||||
"Cannot destroy user: "+status.Convert(err).Message(),
|
||||
output,
|
||||
)
|
||||
return err
|
||||
}
|
||||
SuccessOutput(response, "User destroyed", output)
|
||||
return nil
|
||||
})
|
||||
response, err := client.DeleteUser(ctx, request)
|
||||
if err != nil {
|
||||
return
|
||||
ErrorOutput(
|
||||
err,
|
||||
"Cannot destroy user: "+status.Convert(err).Message(),
|
||||
output,
|
||||
)
|
||||
}
|
||||
SuccessOutput(response, "User destroyed", output)
|
||||
} else {
|
||||
SuccessOutput(map[string]string{"Result": "User not destroyed"}, "User not destroyed", output)
|
||||
}
|
||||
@@ -214,76 +196,61 @@ var listUsersCmd = &cobra.Command{
|
||||
Short: "List all the users",
|
||||
Aliases: []string{"ls", "show"},
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
output := GetOutputFlag(cmd)
|
||||
output, _ := cmd.Flags().GetString("output")
|
||||
|
||||
err := WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error {
|
||||
request := &v1.ListUsersRequest{}
|
||||
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
|
||||
defer cancel()
|
||||
defer conn.Close()
|
||||
|
||||
// Check for smart lookup flag first
|
||||
userFlag, _ := cmd.Flags().GetString("user")
|
||||
if userFlag != "" {
|
||||
// Use smart lookup to determine filter type
|
||||
if id, err := strconv.ParseUint(userFlag, 10, 64); err == nil && id > 0 {
|
||||
request.Id = id
|
||||
} else if strings.Contains(userFlag, "@") {
|
||||
request.Email = userFlag
|
||||
} else {
|
||||
request.Name = userFlag
|
||||
}
|
||||
} else {
|
||||
// Check specific filter flags
|
||||
if id, _ := cmd.Flags().GetUint64("id"); id > 0 {
|
||||
request.Id = id
|
||||
} else if name, _ := cmd.Flags().GetString("name"); name != "" {
|
||||
request.Name = name
|
||||
} else if email, _ := cmd.Flags().GetString("email"); email != "" {
|
||||
request.Email = email
|
||||
}
|
||||
}
|
||||
request := &v1.ListUsersRequest{}
|
||||
|
||||
response, err := client.ListUsers(ctx, request)
|
||||
if err != nil {
|
||||
ErrorOutput(
|
||||
err,
|
||||
"Cannot get users: "+status.Convert(err).Message(),
|
||||
output,
|
||||
)
|
||||
return err
|
||||
}
|
||||
id, _ := cmd.Flags().GetInt64("identifier")
|
||||
username, _ := cmd.Flags().GetString("name")
|
||||
email, _ := cmd.Flags().GetString("email")
|
||||
|
||||
if output != "" {
|
||||
SuccessOutput(response.GetUsers(), "", output)
|
||||
return nil
|
||||
}
|
||||
// filter by one param at most
|
||||
switch {
|
||||
case id > 0:
|
||||
request.Id = uint64(id)
|
||||
case username != "":
|
||||
request.Name = username
|
||||
case email != "":
|
||||
request.Email = email
|
||||
}
|
||||
|
||||
tableData := pterm.TableData{{"ID", "Name", "Username", "Email", "Created"}}
|
||||
for _, user := range response.GetUsers() {
|
||||
tableData = append(
|
||||
tableData,
|
||||
[]string{
|
||||
strconv.FormatUint(user.GetId(), 10),
|
||||
user.GetDisplayName(),
|
||||
user.GetName(),
|
||||
user.GetEmail(),
|
||||
user.GetCreatedAt().AsTime().Format(HeadscaleDateTimeFormat),
|
||||
},
|
||||
)
|
||||
}
|
||||
tableData = FilterTableColumns(cmd, tableData)
|
||||
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
|
||||
if err != nil {
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf("Failed to render pterm table: %s", err),
|
||||
output,
|
||||
)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
response, err := client.ListUsers(ctx, request)
|
||||
if err != nil {
|
||||
// Error already handled in closure
|
||||
return
|
||||
ErrorOutput(
|
||||
err,
|
||||
"Cannot get users: "+status.Convert(err).Message(),
|
||||
output,
|
||||
)
|
||||
}
|
||||
|
||||
if output != "" {
|
||||
SuccessOutput(response.GetUsers(), "", output)
|
||||
}
|
||||
|
||||
tableData := pterm.TableData{{"ID", "Name", "Username", "Email", "Created"}}
|
||||
for _, user := range response.GetUsers() {
|
||||
tableData = append(
|
||||
tableData,
|
||||
[]string{
|
||||
strconv.FormatUint(user.GetId(), 10),
|
||||
user.GetDisplayName(),
|
||||
user.GetName(),
|
||||
user.GetEmail(),
|
||||
user.GetCreatedAt().AsTime().Format("2006-01-02 15:04:05"),
|
||||
},
|
||||
)
|
||||
}
|
||||
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
|
||||
if err != nil {
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf("Failed to render pterm table: %s", err),
|
||||
output,
|
||||
)
|
||||
}
|
||||
},
|
||||
}
|
||||
@@ -293,56 +260,52 @@ var renameUserCmd = &cobra.Command{
|
||||
Short: "Renames a user",
|
||||
Aliases: []string{"mv"},
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
output := GetOutputFlag(cmd)
|
||||
output, _ := cmd.Flags().GetString("output")
|
||||
|
||||
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
|
||||
defer cancel()
|
||||
defer conn.Close()
|
||||
|
||||
id, username := usernameAndIDFromFlag(cmd)
|
||||
listReq := &v1.ListUsersRequest{
|
||||
Name: username,
|
||||
Id: id,
|
||||
}
|
||||
|
||||
users, err := client.ListUsers(ctx, listReq)
|
||||
if err != nil {
|
||||
ErrorOutput(
|
||||
err,
|
||||
"Error: "+status.Convert(err).Message(),
|
||||
output,
|
||||
)
|
||||
}
|
||||
|
||||
if len(users.GetUsers()) != 1 {
|
||||
err := errors.New("Unable to determine user to delete, query returned multiple users, use ID")
|
||||
ErrorOutput(
|
||||
err,
|
||||
"Error: "+status.Convert(err).Message(),
|
||||
output,
|
||||
)
|
||||
}
|
||||
|
||||
id := userIDFromFlag(cmd)
|
||||
newName, _ := cmd.Flags().GetString("new-name")
|
||||
|
||||
err := WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error {
|
||||
listReq := &v1.ListUsersRequest{
|
||||
Id: id,
|
||||
}
|
||||
|
||||
users, err := client.ListUsers(ctx, listReq)
|
||||
if err != nil {
|
||||
ErrorOutput(
|
||||
err,
|
||||
"Error: "+status.Convert(err).Message(),
|
||||
output,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
if len(users.GetUsers()) != 1 {
|
||||
err := errors.New("Unable to determine user to delete, query returned multiple users, use ID")
|
||||
ErrorOutput(
|
||||
err,
|
||||
"Error: "+status.Convert(err).Message(),
|
||||
output,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
renameReq := &v1.RenameUserRequest{
|
||||
OldId: id,
|
||||
NewName: newName,
|
||||
}
|
||||
|
||||
response, err := client.RenameUser(ctx, renameReq)
|
||||
if err != nil {
|
||||
ErrorOutput(
|
||||
err,
|
||||
"Cannot rename user: "+status.Convert(err).Message(),
|
||||
output,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
SuccessOutput(response.GetUser(), "User renamed", output)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
renameReq := &v1.RenameUserRequest{
|
||||
OldId: id,
|
||||
NewName: newName,
|
||||
}
|
||||
|
||||
response, err := client.RenameUser(ctx, renameReq)
|
||||
if err != nil {
|
||||
ErrorOutput(
|
||||
err,
|
||||
"Cannot rename user: "+status.Convert(err).Message(),
|
||||
output,
|
||||
)
|
||||
}
|
||||
|
||||
SuccessOutput(response.GetUser(), "User renamed", output)
|
||||
},
|
||||
}
|
||||
|
||||
@@ -5,23 +5,24 @@ import (
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||
"github.com/juanfont/headscale/hscontrol"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
const (
|
||||
HeadscaleDateTimeFormat = "2006-01-02 15:04:05"
|
||||
SocketWritePermissions = 0o666
|
||||
)
|
||||
|
||||
func newHeadscaleServerWithConfig() (*hscontrol.Headscale, error) {
|
||||
cfg, err := types.LoadServerConfig()
|
||||
if err != nil {
|
||||
@@ -71,7 +72,7 @@ func newHeadscaleCLIWithConfig() (context.Context, v1.HeadscaleServiceClient, *g
|
||||
|
||||
// Try to give the user better feedback if we cannot write to the headscale
|
||||
// socket.
|
||||
socket, err := os.OpenFile(cfg.UnixSocket, os.O_WRONLY, 0o666) // nolint
|
||||
socket, err := os.OpenFile(cfg.UnixSocket, os.O_WRONLY, SocketWritePermissions) // nolint
|
||||
if err != nil {
|
||||
if os.IsPermission(err) {
|
||||
log.Fatal().
|
||||
@@ -199,152 +200,3 @@ func (t tokenAuth) GetRequestMetadata(
|
||||
func (tokenAuth) RequireTransportSecurity() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// GetOutputFlag returns the output flag value (never fails)
|
||||
func GetOutputFlag(cmd *cobra.Command) string {
|
||||
output, _ := cmd.Flags().GetString("output")
|
||||
return output
|
||||
}
|
||||
|
||||
|
||||
// GetNodeIdentifier returns the node ID using smart lookup via gRPC ListNodes call
|
||||
func GetNodeIdentifier(cmd *cobra.Command) (uint64, error) {
|
||||
nodeFlag, _ := cmd.Flags().GetString("node")
|
||||
|
||||
// Use --node flag
|
||||
if nodeFlag == "" {
|
||||
return 0, fmt.Errorf("--node flag is required")
|
||||
}
|
||||
|
||||
// Use smart lookup via gRPC
|
||||
return lookupNodeBySpecifier(nodeFlag)
|
||||
}
|
||||
|
||||
// lookupNodeBySpecifier performs smart lookup of a node by ID, name, hostname, or IP
|
||||
func lookupNodeBySpecifier(specifier string) (uint64, error) {
|
||||
var nodeID uint64
|
||||
|
||||
err := WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error {
|
||||
request := &v1.ListNodesRequest{}
|
||||
|
||||
// Detect what type of specifier this is and set appropriate filter
|
||||
if id, err := strconv.ParseUint(specifier, 10, 64); err == nil && id > 0 {
|
||||
// Looks like a numeric ID
|
||||
request.Id = id
|
||||
} else if isIPAddress(specifier) {
|
||||
// Looks like an IP address
|
||||
request.IpAddresses = []string{specifier}
|
||||
} else {
|
||||
// Treat as hostname/name
|
||||
request.Name = specifier
|
||||
}
|
||||
|
||||
response, err := client.ListNodes(ctx, request)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to lookup node: %w", err)
|
||||
}
|
||||
|
||||
nodes := response.GetNodes()
|
||||
if len(nodes) == 0 {
|
||||
return fmt.Errorf("node not found")
|
||||
}
|
||||
|
||||
if len(nodes) > 1 {
|
||||
var nodeInfo []string
|
||||
for _, node := range nodes {
|
||||
nodeInfo = append(nodeInfo, fmt.Sprintf("ID=%d name=%s", node.GetId(), node.GetName()))
|
||||
}
|
||||
return fmt.Errorf("multiple nodes found matching '%s': %s", specifier, strings.Join(nodeInfo, ", "))
|
||||
}
|
||||
|
||||
// Exactly one match - this is what we want
|
||||
nodeID = nodes[0].GetId()
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return nodeID, nil
|
||||
}
|
||||
|
||||
// isIPAddress checks if a string looks like an IP address
|
||||
func isIPAddress(s string) bool {
|
||||
// Try parsing as IP address (both IPv4 and IPv6)
|
||||
if net.ParseIP(s) != nil {
|
||||
return true
|
||||
}
|
||||
// Try parsing as CIDR
|
||||
if _, _, err := net.ParseCIDR(s); err == nil {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
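As an illustration (not part of the diff), here is a minimal sketch of how a command in the same `cli` package could consume `GetNodeIdentifier`. The `resolve` subcommand is hypothetical, and the `SuccessOutput(result, override, format)` signature is assumed from the surrounding call sites and tests.

```go
// Hypothetical subcommand (illustrative only) showing how the --node flag is
// turned into a numeric node ID via the smart lookup above. Assumes a --node
// string flag is registered on the command elsewhere in the package.
var resolveNodeExampleCmd = &cobra.Command{
	Use:   "resolve",
	Short: "Resolve --node to a node ID (illustrative sketch)",
	Run: func(cmd *cobra.Command, args []string) {
		output := GetOutputFlag(cmd)

		// GetNodeIdentifier accepts an ID, name, hostname or IP address and
		// resolves it through a single gRPC ListNodes call.
		nodeID, err := GetNodeIdentifier(cmd)
		if err != nil {
			log.Fatal().Err(err).Msg("could not resolve node")
		}

		SuccessOutput(
			map[string]uint64{"node_id": nodeID},
			fmt.Sprintf("Resolved to node ID %d", nodeID),
			output,
		)
	},
}
```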
|
||||
|
||||
// GetUserIdentifier returns the user ID using smart lookup via gRPC ListUsers call
|
||||
func GetUserIdentifier(cmd *cobra.Command) (uint64, error) {
|
||||
userFlag, _ := cmd.Flags().GetString("user")
|
||||
nameFlag, _ := cmd.Flags().GetString("name")
|
||||
|
||||
var specifier string
|
||||
|
||||
// Determine which flag was used (prefer --user, fall back to legacy flags)
|
||||
if userFlag != "" {
|
||||
specifier = userFlag
|
||||
} else if nameFlag != "" {
|
||||
specifier = nameFlag
|
||||
} else {
|
||||
return 0, fmt.Errorf("--user flag is required")
|
||||
}
|
||||
|
||||
// Use smart lookup via gRPC
|
||||
return lookupUserBySpecifier(specifier)
|
||||
}
|
||||
|
||||
// lookupUserBySpecifier performs smart lookup of a user by ID, name, or email
|
||||
func lookupUserBySpecifier(specifier string) (uint64, error) {
|
||||
var userID uint64
|
||||
|
||||
err := WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error {
|
||||
request := &v1.ListUsersRequest{}
|
||||
|
||||
// Detect what type of specifier this is and set appropriate filter
|
||||
if id, err := strconv.ParseUint(specifier, 10, 64); err == nil && id > 0 {
|
||||
// Looks like a numeric ID
|
||||
request.Id = id
|
||||
} else if strings.Contains(specifier, "@") {
|
||||
// Looks like an email address
|
||||
request.Email = specifier
|
||||
} else {
|
||||
// Treat as username
|
||||
request.Name = specifier
|
||||
}
|
||||
|
||||
response, err := client.ListUsers(ctx, request)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to lookup user: %w", err)
|
||||
}
|
||||
|
||||
users := response.GetUsers()
|
||||
if len(users) == 0 {
|
||||
return fmt.Errorf("user not found")
|
||||
}
|
||||
|
||||
if len(users) > 1 {
|
||||
var userInfo []string
|
||||
for _, user := range users {
|
||||
userInfo = append(userInfo, fmt.Sprintf("ID=%d name=%s email=%s", user.GetId(), user.GetName(), user.GetEmail()))
|
||||
}
|
||||
return fmt.Errorf("multiple users found matching '%s': %s", specifier, strings.Join(userInfo, ", "))
|
||||
}
|
||||
|
||||
// Exactly one match - this is what we want
|
||||
userID = users[0].GetId()
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return userID, nil
|
||||
}
|
||||
|
||||
@@ -1,175 +0,0 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestHasMachineOutputFlag(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "no machine output flags",
|
||||
args: []string{"headscale", "users", "list"},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "json flag present",
|
||||
args: []string{"headscale", "users", "list", "json"},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "json-line flag present",
|
||||
args: []string{"headscale", "nodes", "list", "json-line"},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "yaml flag present",
|
||||
args: []string{"headscale", "apikeys", "list", "yaml"},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "mixed flags with json",
|
||||
args: []string{"headscale", "--config", "/tmp/config.yaml", "users", "list", "json"},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "flag as part of longer argument",
|
||||
args: []string{"headscale", "users", "create", "json-user@example.com"},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Save original os.Args
|
||||
originalArgs := os.Args
|
||||
defer func() { os.Args = originalArgs }()
|
||||
|
||||
// Set os.Args to test case
|
||||
os.Args = tt.args
|
||||
|
||||
result := HasMachineOutputFlag()
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOutput(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
result interface{}
|
||||
override string
|
||||
outputFormat string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "default format returns override",
|
||||
result: map[string]string{"test": "value"},
|
||||
override: "Human readable output",
|
||||
outputFormat: "",
|
||||
expected: "Human readable output",
|
||||
},
|
||||
{
|
||||
name: "default format with empty override",
|
||||
result: map[string]string{"test": "value"},
|
||||
override: "",
|
||||
outputFormat: "",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "json format",
|
||||
result: map[string]string{"name": "test", "id": "123"},
|
||||
override: "Human readable",
|
||||
outputFormat: "json",
|
||||
expected: "{\n\t\"id\": \"123\",\n\t\"name\": \"test\"\n}",
|
||||
},
|
||||
{
|
||||
name: "json-line format",
|
||||
result: map[string]string{"name": "test", "id": "123"},
|
||||
override: "Human readable",
|
||||
outputFormat: "json-line",
|
||||
expected: "{\"id\":\"123\",\"name\":\"test\"}",
|
||||
},
|
||||
{
|
||||
name: "yaml format",
|
||||
result: map[string]string{"name": "test", "id": "123"},
|
||||
override: "Human readable",
|
||||
outputFormat: "yaml",
|
||||
expected: "id: \"123\"\nname: test\n",
|
||||
},
|
||||
{
|
||||
name: "invalid format returns override",
|
||||
result: map[string]string{"test": "value"},
|
||||
override: "Human readable output",
|
||||
outputFormat: "invalid",
|
||||
expected: "Human readable output",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := output(tt.result, tt.override, tt.outputFormat)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOutputWithComplexData(t *testing.T) {
|
||||
// Test with more complex data structures
|
||||
complexData := struct {
|
||||
Users []struct {
|
||||
Name string `json:"name" yaml:"name"`
|
||||
ID int `json:"id" yaml:"id"`
|
||||
} `json:"users" yaml:"users"`
|
||||
}{
|
||||
Users: []struct {
|
||||
Name string `json:"name" yaml:"name"`
|
||||
ID int `json:"id" yaml:"id"`
|
||||
}{
|
||||
{Name: "user1", ID: 1},
|
||||
{Name: "user2", ID: 2},
|
||||
},
|
||||
}
|
||||
|
||||
// Test JSON output
|
||||
jsonResult := output(complexData, "override", "json")
|
||||
assert.Contains(t, jsonResult, "\"users\":")
|
||||
assert.Contains(t, jsonResult, "\"name\": \"user1\"")
|
||||
assert.Contains(t, jsonResult, "\"id\": 1")
|
||||
|
||||
// Test YAML output
|
||||
yamlResult := output(complexData, "override", "yaml")
|
||||
assert.Contains(t, yamlResult, "users:")
|
||||
assert.Contains(t, yamlResult, "name: user1")
|
||||
assert.Contains(t, yamlResult, "id: 1")
|
||||
}
|
||||
|
||||
func TestOutputWithNilData(t *testing.T) {
|
||||
// Test with nil data
|
||||
result := output(nil, "fallback", "json")
|
||||
assert.Equal(t, "null", result)
|
||||
|
||||
result = output(nil, "fallback", "yaml")
|
||||
assert.Equal(t, "null\n", result)
|
||||
|
||||
result = output(nil, "fallback", "")
|
||||
assert.Equal(t, "fallback", result)
|
||||
}
|
||||
|
||||
func TestOutputWithEmptyData(t *testing.T) {
|
||||
// Test with empty slice
|
||||
emptySlice := []string{}
|
||||
result := output(emptySlice, "fallback", "json")
|
||||
assert.Equal(t, "[]", result)
|
||||
|
||||
// Test with empty map
|
||||
emptyMap := map[string]string{}
|
||||
result = output(emptyMap, "fallback", "json")
|
||||
assert.Equal(t, "{}", result)
|
||||
}
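For reference, a minimal sketch of an output helper that would satisfy the deleted tests above, reconstructed only from the expected values; the actual implementation in the repository may differ.

```go
// outputSketch is a reconstruction from the tests above, not the real
// implementation: marshal the result according to the requested format and
// fall back to the human-readable override string otherwise.
// Assumes "encoding/json" and "gopkg.in/yaml.v3" are imported.
func outputSketch(result interface{}, override string, outputFormat string) string {
	switch outputFormat {
	case "json":
		b, err := json.MarshalIndent(result, "", "\t")
		if err != nil {
			return override // fall back to the human-readable text on error
		}
		return string(b)
	case "json-line":
		b, err := json.Marshal(result)
		if err != nil {
			return override
		}
		return string(b)
	case "yaml":
		b, err := yaml.Marshal(result)
		if err != nil {
			return override
		}
		return string(b)
	default:
		// "" and unrecognised formats return the human-readable override.
		return override
	}
}
```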
|
||||
@@ -11,10 +11,10 @@ func init() {
|
||||
|
||||
var versionCmd = &cobra.Command{
|
||||
Use: "version",
|
||||
Short: "Print the version",
|
||||
Long: "The version of headscale",
|
||||
Short: "Print the version.",
|
||||
Long: "The version of headscale.",
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
output := GetOutputFlag(cmd)
|
||||
output, _ := cmd.Flags().GetString("output")
|
||||
SuccessOutput(map[string]string{
|
||||
"version": types.Version,
|
||||
"commit": types.GitCommitHash,
|
||||
|
||||
@@ -1,45 +0,0 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestVersionCommand(t *testing.T) {
|
||||
// Test that version command exists
|
||||
assert.NotNil(t, versionCmd)
|
||||
assert.Equal(t, "version", versionCmd.Use)
|
||||
assert.Equal(t, "Print the version.", versionCmd.Short)
|
||||
assert.Equal(t, "The version of headscale.", versionCmd.Long)
|
||||
}
|
||||
|
||||
func TestVersionCommandStructure(t *testing.T) {
|
||||
// Test command is properly added to root
|
||||
found := false
|
||||
for _, cmd := range rootCmd.Commands() {
|
||||
if cmd.Use == "version" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found, "version command should be added to root command")
|
||||
}
|
||||
|
||||
func TestVersionCommandFlags(t *testing.T) {
|
||||
// Version command should inherit output flag from root as persistent flag
|
||||
outputFlag := versionCmd.Flag("output")
|
||||
if outputFlag == nil {
|
||||
// Try persistent flags from root
|
||||
outputFlag = rootCmd.PersistentFlags().Lookup("output")
|
||||
}
|
||||
assert.NotNil(t, outputFlag, "version command should have access to output flag")
|
||||
}
|
||||
|
||||
func TestVersionCommandRun(t *testing.T) {
|
||||
// Test that Run function is set
|
||||
assert.NotNil(t, versionCmd.Run)
|
||||
|
||||
// We can't easily test the actual execution without mocking SuccessOutput
|
||||
// but we can verify the function exists and has the right signature
|
||||
}
|
||||
@@ -90,6 +90,32 @@ func runTestContainer(ctx context.Context, config *RunConfig) error {
|
||||
|
||||
log.Printf("Starting test: %s", config.TestPattern)
|
||||
|
||||
// Start stats collection for container resource monitoring (if enabled)
|
||||
var statsCollector *StatsCollector
|
||||
if config.Stats {
|
||||
var err error
|
||||
statsCollector, err = NewStatsCollector()
|
||||
if err != nil {
|
||||
if config.Verbose {
|
||||
log.Printf("Warning: failed to create stats collector: %v", err)
|
||||
}
|
||||
statsCollector = nil
|
||||
}
|
||||
|
||||
if statsCollector != nil {
|
||||
defer statsCollector.Close()
|
||||
|
||||
// Start stats collection immediately - no need for complex retry logic
|
||||
// The new implementation monitors Docker events and will catch containers as they start
|
||||
if err := statsCollector.StartCollection(ctx, runID, config.Verbose); err != nil {
|
||||
if config.Verbose {
|
||||
log.Printf("Warning: failed to start stats collection: %v", err)
|
||||
}
|
||||
}
|
||||
defer statsCollector.StopCollection()
|
||||
}
|
||||
}
|
||||
|
||||
exitCode, err := streamAndWait(ctx, cli, resp.ID)
|
||||
|
||||
// Ensure all containers have finished and logs are flushed before extracting artifacts
|
||||
@@ -105,6 +131,20 @@ func runTestContainer(ctx context.Context, config *RunConfig) error {
|
||||
// Always list control files regardless of test outcome
|
||||
listControlFiles(logsDir)
|
||||
|
||||
// Print stats summary and check memory limits if enabled
|
||||
if config.Stats && statsCollector != nil {
|
||||
violations := statsCollector.PrintSummaryAndCheckLimits(config.HSMemoryLimit, config.TSMemoryLimit)
|
||||
if len(violations) > 0 {
|
||||
log.Printf("MEMORY LIMIT VIOLATIONS DETECTED:")
|
||||
log.Printf("=================================")
|
||||
for _, violation := range violations {
|
||||
log.Printf("Container %s exceeded memory limit: %.1f MB > %.1f MB",
|
||||
violation.ContainerName, violation.MaxMemoryMB, violation.LimitMB)
|
||||
}
|
||||
return fmt.Errorf("test failed: %d container(s) exceeded memory limits", len(violations))
|
||||
}
|
||||
}
|
||||
|
||||
shouldCleanup := config.CleanAfter && (!config.KeepOnFailure || exitCode == 0)
|
||||
if shouldCleanup {
|
||||
if config.Verbose {
|
||||
@@ -379,10 +419,37 @@ func getDockerSocketPath() string {
|
||||
return "/var/run/docker.sock"
|
||||
}
|
||||
|
||||
// ensureImageAvailable pulls the specified Docker image to ensure it's available.
|
||||
// checkImageAvailableLocally checks if the specified Docker image is available locally.
|
||||
func checkImageAvailableLocally(ctx context.Context, cli *client.Client, imageName string) (bool, error) {
|
||||
_, _, err := cli.ImageInspectWithRaw(ctx, imageName)
|
||||
if err != nil {
|
||||
if client.IsErrNotFound(err) {
|
||||
return false, nil
|
||||
}
|
||||
return false, fmt.Errorf("failed to inspect image %s: %w", imageName, err)
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// ensureImageAvailable checks if the image is available locally first, then pulls if needed.
|
||||
func ensureImageAvailable(ctx context.Context, cli *client.Client, imageName string, verbose bool) error {
|
||||
// First check if image is available locally
|
||||
available, err := checkImageAvailableLocally(ctx, cli, imageName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check local image availability: %w", err)
|
||||
}
|
||||
|
||||
if available {
|
||||
if verbose {
|
||||
log.Printf("Image %s is available locally", imageName)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Image not available locally, try to pull it
|
||||
if verbose {
|
||||
log.Printf("Pulling image %s...", imageName)
|
||||
log.Printf("Image %s not found locally, pulling...", imageName)
|
||||
}
|
||||
|
||||
reader, err := cli.ImagePull(ctx, imageName, image.PullOptions{})
|
||||
|
||||
@@ -190,7 +190,7 @@ func checkDockerSocket(ctx context.Context) DoctorResult {
|
||||
}
|
||||
}
|
||||
|
||||
// checkGolangImage verifies we can access the golang Docker image.
|
||||
// checkGolangImage verifies the golang Docker image is available locally or can be pulled.
|
||||
func checkGolangImage(ctx context.Context) DoctorResult {
|
||||
cli, err := createDockerClient()
|
||||
if err != nil {
|
||||
@@ -205,17 +205,40 @@ func checkGolangImage(ctx context.Context) DoctorResult {
|
||||
goVersion := detectGoVersion()
|
||||
imageName := "golang:" + goVersion
|
||||
|
||||
// Check if we can pull the image
|
||||
// First check if image is available locally
|
||||
available, err := checkImageAvailableLocally(ctx, cli, imageName)
|
||||
if err != nil {
|
||||
return DoctorResult{
|
||||
Name: "Golang Image",
|
||||
Status: "FAIL",
|
||||
Message: fmt.Sprintf("Cannot check golang image %s: %v", imageName, err),
|
||||
Suggestions: []string{
|
||||
"Check Docker daemon status",
|
||||
"Try: docker images | grep golang",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
if available {
|
||||
return DoctorResult{
|
||||
Name: "Golang Image",
|
||||
Status: "PASS",
|
||||
Message: fmt.Sprintf("Golang image %s is available locally", imageName),
|
||||
}
|
||||
}
|
||||
|
||||
// Image not available locally, try to pull it
|
||||
err = ensureImageAvailable(ctx, cli, imageName, false)
|
||||
if err != nil {
|
||||
return DoctorResult{
|
||||
Name: "Golang Image",
|
||||
Status: "FAIL",
|
||||
Message: fmt.Sprintf("Cannot pull golang image %s: %v", imageName, err),
|
||||
Message: fmt.Sprintf("Golang image %s not available locally and cannot pull: %v", imageName, err),
|
||||
Suggestions: []string{
|
||||
"Check internet connectivity",
|
||||
"Verify Docker Hub access",
|
||||
"Try: docker pull " + imageName,
|
||||
"Or run tests offline if image was pulled previously",
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -223,7 +246,7 @@ func checkGolangImage(ctx context.Context) DoctorResult {
|
||||
return DoctorResult{
|
||||
Name: "Golang Image",
|
||||
Status: "PASS",
|
||||
Message: fmt.Sprintf("Golang image %s is available", imageName),
|
||||
Message: fmt.Sprintf("Golang image %s is now available", imageName),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -24,6 +24,9 @@ type RunConfig struct {
|
||||
KeepOnFailure bool `flag:"keep-on-failure,default=false,Keep containers on test failure"`
|
||||
LogsDir string `flag:"logs-dir,default=control_logs,Control logs directory"`
|
||||
Verbose bool `flag:"verbose,default=false,Verbose output"`
|
||||
Stats bool `flag:"stats,default=false,Collect and display container resource usage statistics"`
|
||||
HSMemoryLimit float64 `flag:"hs-memory-limit,default=0,Fail test if any Headscale container exceeds this memory limit in MB (0 = disabled)"`
|
||||
TSMemoryLimit float64 `flag:"ts-memory-limit,default=0,Fail test if any Tailscale container exceeds this memory limit in MB (0 = disabled)"`
|
||||
}
|
||||
|
||||
// runIntegrationTest executes the integration test workflow.
|
||||
|
||||
468
cmd/hi/stats.go
Normal file
@@ -0,0 +1,468 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/docker/docker/api/types"
|
||||
"github.com/docker/docker/api/types/container"
|
||||
"github.com/docker/docker/api/types/events"
|
||||
"github.com/docker/docker/api/types/filters"
|
||||
"github.com/docker/docker/client"
|
||||
)
|
||||
|
||||
// ContainerStats represents statistics for a single container
|
||||
type ContainerStats struct {
|
||||
ContainerID string
|
||||
ContainerName string
|
||||
Stats []StatsSample
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// StatsSample represents a single stats measurement
|
||||
type StatsSample struct {
|
||||
Timestamp time.Time
|
||||
CPUUsage float64 // CPU usage percentage
|
||||
MemoryMB float64 // Memory usage in MB
|
||||
}
|
||||
|
||||
// StatsCollector manages collection of container statistics
|
||||
type StatsCollector struct {
|
||||
client *client.Client
|
||||
containers map[string]*ContainerStats
|
||||
stopChan chan struct{}
|
||||
wg sync.WaitGroup
|
||||
mutex sync.RWMutex
|
||||
collectionStarted bool
|
||||
}
|
||||
|
||||
// NewStatsCollector creates a new stats collector instance
|
||||
func NewStatsCollector() (*StatsCollector, error) {
|
||||
cli, err := createDockerClient()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create Docker client: %w", err)
|
||||
}
|
||||
|
||||
return &StatsCollector{
|
||||
client: cli,
|
||||
containers: make(map[string]*ContainerStats),
|
||||
stopChan: make(chan struct{}),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// StartCollection begins monitoring all containers and collecting stats for hs- and ts- containers with matching run ID
|
||||
func (sc *StatsCollector) StartCollection(ctx context.Context, runID string, verbose bool) error {
|
||||
sc.mutex.Lock()
|
||||
defer sc.mutex.Unlock()
|
||||
|
||||
if sc.collectionStarted {
|
||||
return fmt.Errorf("stats collection already started")
|
||||
}
|
||||
|
||||
sc.collectionStarted = true
|
||||
|
||||
// Start monitoring existing containers
|
||||
sc.wg.Add(1)
|
||||
go sc.monitorExistingContainers(ctx, runID, verbose)
|
||||
|
||||
// Start Docker events monitoring for new containers
|
||||
sc.wg.Add(1)
|
||||
go sc.monitorDockerEvents(ctx, runID, verbose)
|
||||
|
||||
if verbose {
|
||||
log.Printf("Started container monitoring for run ID %s", runID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// StopCollection stops all stats collection
|
||||
func (sc *StatsCollector) StopCollection() {
|
||||
// Check if already stopped without holding lock
|
||||
sc.mutex.RLock()
|
||||
if !sc.collectionStarted {
|
||||
sc.mutex.RUnlock()
|
||||
return
|
||||
}
|
||||
sc.mutex.RUnlock()
|
||||
|
||||
// Signal stop to all goroutines
|
||||
close(sc.stopChan)
|
||||
|
||||
// Wait for all goroutines to finish
|
||||
sc.wg.Wait()
|
||||
|
||||
// Mark as stopped
|
||||
sc.mutex.Lock()
|
||||
sc.collectionStarted = false
|
||||
sc.mutex.Unlock()
|
||||
}
|
||||
|
||||
// monitorExistingContainers checks for existing containers that match our criteria
|
||||
func (sc *StatsCollector) monitorExistingContainers(ctx context.Context, runID string, verbose bool) {
|
||||
defer sc.wg.Done()
|
||||
|
||||
containers, err := sc.client.ContainerList(ctx, container.ListOptions{})
|
||||
if err != nil {
|
||||
if verbose {
|
||||
log.Printf("Failed to list existing containers: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
for _, cont := range containers {
|
||||
if sc.shouldMonitorContainer(cont, runID) {
|
||||
sc.startStatsForContainer(ctx, cont.ID, cont.Names[0], verbose)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// monitorDockerEvents listens for container start events and begins monitoring relevant containers
|
||||
func (sc *StatsCollector) monitorDockerEvents(ctx context.Context, runID string, verbose bool) {
|
||||
defer sc.wg.Done()
|
||||
|
||||
filter := filters.NewArgs()
|
||||
filter.Add("type", "container")
|
||||
filter.Add("event", "start")
|
||||
|
||||
eventOptions := events.ListOptions{
|
||||
Filters: filter,
|
||||
}
|
||||
|
||||
events, errs := sc.client.Events(ctx, eventOptions)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-sc.stopChan:
|
||||
return
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case event := <-events:
|
||||
if event.Type == "container" && event.Action == "start" {
|
||||
// Get container details
|
||||
containerInfo, err := sc.client.ContainerInspect(ctx, event.ID)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Convert to types.Container format for consistency
|
||||
cont := types.Container{
|
||||
ID: containerInfo.ID,
|
||||
Names: []string{containerInfo.Name},
|
||||
Labels: containerInfo.Config.Labels,
|
||||
}
|
||||
|
||||
if sc.shouldMonitorContainer(cont, runID) {
|
||||
sc.startStatsForContainer(ctx, cont.ID, cont.Names[0], verbose)
|
||||
}
|
||||
}
|
||||
case err := <-errs:
|
||||
if verbose {
|
||||
log.Printf("Error in Docker events stream: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// shouldMonitorContainer determines if a container should be monitored
|
||||
func (sc *StatsCollector) shouldMonitorContainer(cont types.Container, runID string) bool {
|
||||
// Check if it has the correct run ID label
|
||||
if cont.Labels == nil || cont.Labels["hi.run-id"] != runID {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if it's an hs- or ts- container
|
||||
for _, name := range cont.Names {
|
||||
containerName := strings.TrimPrefix(name, "/")
|
||||
if strings.HasPrefix(containerName, "hs-") || strings.HasPrefix(containerName, "ts-") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// startStatsForContainer begins stats collection for a specific container
|
||||
func (sc *StatsCollector) startStatsForContainer(ctx context.Context, containerID, containerName string, verbose bool) {
|
||||
containerName = strings.TrimPrefix(containerName, "/")
|
||||
|
||||
sc.mutex.Lock()
|
||||
// Check if we're already monitoring this container
|
||||
if _, exists := sc.containers[containerID]; exists {
|
||||
sc.mutex.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
sc.containers[containerID] = &ContainerStats{
|
||||
ContainerID: containerID,
|
||||
ContainerName: containerName,
|
||||
Stats: make([]StatsSample, 0),
|
||||
}
|
||||
sc.mutex.Unlock()
|
||||
|
||||
if verbose {
|
||||
log.Printf("Starting stats collection for container %s (%s)", containerName, containerID[:12])
|
||||
}
|
||||
|
||||
sc.wg.Add(1)
|
||||
go sc.collectStatsForContainer(ctx, containerID, verbose)
|
||||
}
|
||||
|
||||
// collectStatsForContainer collects stats for a specific container using Docker API streaming
|
||||
func (sc *StatsCollector) collectStatsForContainer(ctx context.Context, containerID string, verbose bool) {
|
||||
defer sc.wg.Done()
|
||||
|
||||
// Use Docker API streaming stats - much more efficient than CLI
|
||||
statsResponse, err := sc.client.ContainerStats(ctx, containerID, true)
|
||||
if err != nil {
|
||||
if verbose {
|
||||
log.Printf("Failed to get stats stream for container %s: %v", containerID[:12], err)
|
||||
}
|
||||
return
|
||||
}
|
||||
defer statsResponse.Body.Close()
|
||||
|
||||
decoder := json.NewDecoder(statsResponse.Body)
|
||||
var prevStats *container.Stats
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-sc.stopChan:
|
||||
return
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
var stats container.Stats
|
||||
if err := decoder.Decode(&stats); err != nil {
|
||||
// EOF is expected when container stops or stream ends
|
||||
if err.Error() != "EOF" && verbose {
|
||||
log.Printf("Failed to decode stats for container %s: %v", containerID[:12], err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Calculate CPU percentage (only if we have previous stats)
|
||||
var cpuPercent float64
|
||||
if prevStats != nil {
|
||||
cpuPercent = calculateCPUPercent(prevStats, &stats)
|
||||
}
|
||||
|
||||
// Calculate memory usage in MB
|
||||
memoryMB := float64(stats.MemoryStats.Usage) / (1024 * 1024)
|
||||
|
||||
// Store the sample (skip first sample since CPU calculation needs previous stats)
|
||||
if prevStats != nil {
|
||||
// Get container stats reference without holding the main mutex
|
||||
var containerStats *ContainerStats
|
||||
var exists bool
|
||||
|
||||
sc.mutex.RLock()
|
||||
containerStats, exists = sc.containers[containerID]
|
||||
sc.mutex.RUnlock()
|
||||
|
||||
if exists && containerStats != nil {
|
||||
containerStats.mutex.Lock()
|
||||
containerStats.Stats = append(containerStats.Stats, StatsSample{
|
||||
Timestamp: time.Now(),
|
||||
CPUUsage: cpuPercent,
|
||||
MemoryMB: memoryMB,
|
||||
})
|
||||
containerStats.mutex.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// Save current stats for next iteration
|
||||
prevStats = &stats
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// calculateCPUPercent calculates CPU usage percentage from Docker stats
|
||||
func calculateCPUPercent(prevStats, stats *container.Stats) float64 {
|
||||
// CPU calculation based on Docker's implementation
|
||||
cpuDelta := float64(stats.CPUStats.CPUUsage.TotalUsage) - float64(prevStats.CPUStats.CPUUsage.TotalUsage)
|
||||
systemDelta := float64(stats.CPUStats.SystemUsage) - float64(prevStats.CPUStats.SystemUsage)
|
||||
|
||||
if systemDelta > 0 && cpuDelta >= 0 {
|
||||
// Calculate CPU percentage: (container CPU delta / system CPU delta) * number of CPUs * 100
|
||||
numCPUs := float64(len(stats.CPUStats.CPUUsage.PercpuUsage))
|
||||
if numCPUs == 0 {
|
||||
// Fallback: if PercpuUsage is not available, assume 1 CPU
|
||||
numCPUs = 1.0
|
||||
}
|
||||
return (cpuDelta / systemDelta) * numCPUs * 100.0
|
||||
}
|
||||
return 0.0
|
||||
}
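A quick worked example of this calculation, with made-up numbers, to show the units involved:

```go
// Worked example of the formula above with illustrative numbers: a container
// CPU delta of 50ms against a system delta of 1s on a host reporting 4 CPUs in
// PercpuUsage gives (0.05 / 1.0) * 4 * 100 = 20%. The result is relative to a
// single CPU, so it can exceed 100% on multi-core hosts.
func exampleCPUPercent() float64 {
	cpuDelta := 50_000_000.0       // container CPU time used in the interval (ns)
	systemDelta := 1_000_000_000.0 // system CPU time across all CPUs (ns)
	numCPUs := 4.0

	return (cpuDelta / systemDelta) * numCPUs * 100.0 // 20.0
}
```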
|
||||
|
||||
// ContainerStatsSummary represents summary statistics for a container
|
||||
type ContainerStatsSummary struct {
|
||||
ContainerName string
|
||||
SampleCount int
|
||||
CPU StatsSummary
|
||||
Memory StatsSummary
|
||||
}
|
||||
|
||||
// MemoryViolation represents a container that exceeded the memory limit
|
||||
type MemoryViolation struct {
|
||||
ContainerName string
|
||||
MaxMemoryMB float64
|
||||
LimitMB float64
|
||||
}
|
||||
|
||||
// StatsSummary represents min, max, and average for a metric
|
||||
type StatsSummary struct {
|
||||
Min float64
|
||||
Max float64
|
||||
Average float64
|
||||
}
|
||||
|
||||
// GetSummary returns a summary of collected statistics
|
||||
func (sc *StatsCollector) GetSummary() []ContainerStatsSummary {
|
||||
// Take snapshot of container references without holding main lock long
|
||||
sc.mutex.RLock()
|
||||
containerRefs := make([]*ContainerStats, 0, len(sc.containers))
|
||||
for _, containerStats := range sc.containers {
|
||||
containerRefs = append(containerRefs, containerStats)
|
||||
}
|
||||
sc.mutex.RUnlock()
|
||||
|
||||
summaries := make([]ContainerStatsSummary, 0, len(containerRefs))
|
||||
|
||||
for _, containerStats := range containerRefs {
|
||||
containerStats.mutex.RLock()
|
||||
stats := make([]StatsSample, len(containerStats.Stats))
|
||||
copy(stats, containerStats.Stats)
|
||||
containerName := containerStats.ContainerName
|
||||
containerStats.mutex.RUnlock()
|
||||
|
||||
if len(stats) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
summary := ContainerStatsSummary{
|
||||
ContainerName: containerName,
|
||||
SampleCount: len(stats),
|
||||
}
|
||||
|
||||
// Calculate CPU stats
|
||||
cpuValues := make([]float64, len(stats))
|
||||
memoryValues := make([]float64, len(stats))
|
||||
|
||||
for i, sample := range stats {
|
||||
cpuValues[i] = sample.CPUUsage
|
||||
memoryValues[i] = sample.MemoryMB
|
||||
}
|
||||
|
||||
summary.CPU = calculateStatsSummary(cpuValues)
|
||||
summary.Memory = calculateStatsSummary(memoryValues)
|
||||
|
||||
summaries = append(summaries, summary)
|
||||
}
|
||||
|
||||
// Sort by container name for consistent output
|
||||
sort.Slice(summaries, func(i, j int) bool {
|
||||
return summaries[i].ContainerName < summaries[j].ContainerName
|
||||
})
|
||||
|
||||
return summaries
|
||||
}
|
||||
|
||||
// calculateStatsSummary calculates min, max, and average for a slice of values
|
||||
func calculateStatsSummary(values []float64) StatsSummary {
|
||||
if len(values) == 0 {
|
||||
return StatsSummary{}
|
||||
}
|
||||
|
||||
min := values[0]
|
||||
max := values[0]
|
||||
sum := 0.0
|
||||
|
||||
for _, value := range values {
|
||||
if value < min {
|
||||
min = value
|
||||
}
|
||||
if value > max {
|
||||
max = value
|
||||
}
|
||||
sum += value
|
||||
}
|
||||
|
||||
return StatsSummary{
|
||||
Min: min,
|
||||
Max: max,
|
||||
Average: sum / float64(len(values)),
|
||||
}
|
||||
}
|
||||
|
||||
// PrintSummary prints the statistics summary to the console
|
||||
func (sc *StatsCollector) PrintSummary() {
|
||||
summaries := sc.GetSummary()
|
||||
|
||||
if len(summaries) == 0 {
|
||||
log.Printf("No container statistics collected")
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("Container Resource Usage Summary:")
|
||||
log.Printf("================================")
|
||||
|
||||
for _, summary := range summaries {
|
||||
log.Printf("Container: %s (%d samples)", summary.ContainerName, summary.SampleCount)
|
||||
log.Printf(" CPU Usage: Min: %6.2f%% Max: %6.2f%% Avg: %6.2f%%",
|
||||
summary.CPU.Min, summary.CPU.Max, summary.CPU.Average)
|
||||
log.Printf(" Memory Usage: Min: %6.1f MB Max: %6.1f MB Avg: %6.1f MB",
|
||||
summary.Memory.Min, summary.Memory.Max, summary.Memory.Average)
|
||||
log.Printf("")
|
||||
}
|
||||
}
|
||||
|
||||
// CheckMemoryLimits checks if any containers exceeded their memory limits
|
||||
func (sc *StatsCollector) CheckMemoryLimits(hsLimitMB, tsLimitMB float64) []MemoryViolation {
|
||||
if hsLimitMB <= 0 && tsLimitMB <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
summaries := sc.GetSummary()
|
||||
var violations []MemoryViolation
|
||||
|
||||
for _, summary := range summaries {
|
||||
var limitMB float64
|
||||
if strings.HasPrefix(summary.ContainerName, "hs-") {
|
||||
limitMB = hsLimitMB
|
||||
} else if strings.HasPrefix(summary.ContainerName, "ts-") {
|
||||
limitMB = tsLimitMB
|
||||
} else {
|
||||
continue // Skip containers that don't match our patterns
|
||||
}
|
||||
|
||||
if limitMB > 0 && summary.Memory.Max > limitMB {
|
||||
violations = append(violations, MemoryViolation{
|
||||
ContainerName: summary.ContainerName,
|
||||
MaxMemoryMB: summary.Memory.Max,
|
||||
LimitMB: limitMB,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return violations
|
||||
}
|
||||
|
||||
// PrintSummaryAndCheckLimits prints the statistics summary and returns memory violations if any
|
||||
func (sc *StatsCollector) PrintSummaryAndCheckLimits(hsLimitMB, tsLimitMB float64) []MemoryViolation {
|
||||
sc.PrintSummary()
|
||||
return sc.CheckMemoryLimits(hsLimitMB, tsLimitMB)
|
||||
}
|
||||
|
||||
// Close closes the stats collector and cleans up resources
|
||||
func (sc *StatsCollector) Close() error {
|
||||
sc.StopCollection()
|
||||
return sc.client.Close()
|
||||
}
|
||||
@@ -225,9 +225,11 @@ tls_cert_path: ""
|
||||
tls_key_path: ""
|
||||
|
||||
log:
|
||||
# Valid log levels: panic, fatal, error, warn, info, debug, trace
|
||||
level: info
|
||||
|
||||
# Output formatting for logs: text or json
|
||||
format: text
|
||||
level: info
|
||||
|
||||
## Policy
|
||||
# headscale supports Tailscale's ACL policies.
|
||||
|
||||
@@ -51,11 +51,11 @@ is homelabbers and self-hosters. Of course, we do not prevent people from using
|
||||
it in a commercial/professional setting and often get questions about scaling.
|
||||
|
||||
Please note that when Headscale is developed, performance is not part of the
|
||||
consideration as the main audience is considered to be users with a moddest
|
||||
consideration as the main audience is considered to be users with a modest
|
||||
amount of devices. We focus on correctness and feature parity with Tailscale
|
||||
SaaS over time.
|
||||
|
||||
To understand if you might be able to use Headscale for your usecase, I will
|
||||
To understand if you might be able to use Headscale for your use case, I will
|
||||
describe two scenarios in an effort to explain what is the central bottleneck
|
||||
of Headscale:
|
||||
|
||||
@@ -76,7 +76,7 @@ new "world map" is created for every node in the network.
|
||||
This means that under certain conditions, Headscale can likely handle 100s
|
||||
of devices (maybe more), if there is _little to no change_ happening in the
|
||||
network. For example, in Scenario 1, the process of computing the world map is
|
||||
extremly demanding due to the size of the network, but when the map has been
|
||||
extremely demanding due to the size of the network, but when the map has been
|
||||
created and the nodes are not changing, the Headscale instance will likely
|
||||
return to a very low resource usage until the next time there is an event
|
||||
requiring the new map.
|
||||
@@ -94,14 +94,14 @@ learn about the current state of the world.
|
||||
We expect that the performance will improve over time as we improve the code
|
||||
base, but it is not a focus. In general, we will never make the tradeoff to make
|
||||
things faster at the cost of less maintainable or readable code. We are a small
|
||||
team and have to optimise for maintainabillity.
|
||||
team and have to optimise for maintainability.
|
||||
|
||||
## Which database should I use?
|
||||
|
||||
We recommend the use of SQLite as the database for headscale:

- SQLite is simple to set up and easy to use
|
||||
- It scales well for all of headscale's usecases
|
||||
- It scales well for all of headscale's use cases
|
||||
- Development and testing happens primarily on SQLite
|
||||
- PostgreSQL is still supported, but is considered to be in "maintenance mode"
|
||||
|
||||
|
||||
115
docs/ref/debug.md
Normal file
@@ -0,0 +1,115 @@
|
||||
# Debugging and troubleshooting
|
||||
|
||||
Headscale and Tailscale provide debug and introspection capabilities that can be helpful when things don't work as
|
||||
expected. This page explains some debugging techniques to help pinpoint problems.
|
||||
|
||||
Please also have a look at [Tailscale's Troubleshooting guide](https://tailscale.com/kb/1023/troubleshooting). It offers
many tips and suggestions for troubleshooting common issues.
|
||||
|
||||
## Tailscale
|
||||
|
||||
The Tailscale client itself offers many commands to introspect its state as well as the state of the network:
|
||||
|
||||
- [Check local network conditions](https://tailscale.com/kb/1080/cli#netcheck): `tailscale netcheck`
|
||||
- [Get the client status](https://tailscale.com/kb/1080/cli#status): `tailscale status --json`
|
||||
- [Get DNS status](https://tailscale.com/kb/1080/cli#dns): `tailscale dns status --all`
|
||||
- Client logs: `tailscale debug daemon-logs`
|
||||
- Client netmap: `tailscale debug netmap`
|
||||
- Test DERP connection: `tailscale debug derp headscale`
|
||||
- And many more, see: `tailscale debug --help`
|
||||
|
||||
Many of the commands are helpful when trying to understand differences between Headscale and Tailscale SaaS.
|
||||
|
||||
## Headscale
|
||||
|
||||
### Application logging
|
||||
|
||||
The log levels `debug` and `trace` can be useful to get more information from Headscale.
|
||||
|
||||
```yaml hl_lines="3"
|
||||
log:
|
||||
# Valid log levels: panic, fatal, error, warn, info, debug, trace
|
||||
level: debug
|
||||
```
|
||||
|
||||
### Database logging
|
||||
|
||||
The database debug mode logs all database queries. Enable it to see how Headscale interacts with its database. This also
|
||||
requires the application log level to be set to either `debug` or `trace`.
|
||||
|
||||
```yaml hl_lines="3 7"
|
||||
database:
|
||||
# Enable debug mode. This setting requires the log.level to be set to "debug" or "trace".
|
||||
debug: false
|
||||
|
||||
log:
|
||||
# Valid log levels: panic, fatal, error, warn, info, debug, trace
|
||||
level: debug
|
||||
```
|
||||
|
||||
### Metrics and debug endpoint
|
||||
|
||||
Headscale provides a metrics and debug endpoint. It allows you to introspect different aspects, such as:
|
||||
|
||||
- Information about the Go runtime, memory usage and statistics
|
||||
- Connected nodes and pending registrations
|
||||
- Active ACLs, filters and SSH policy
|
||||
- Current DERPMap
|
||||
- Prometheus metrics
|
||||
|
||||
!!! warning "Keep the metrics and debug endpoint private"
|
||||
|
||||
The listen address and port can be configured with the `metrics_listen_addr` variable in the [configuration
|
||||
file](./configuration.md). By default it listens on localhost, port 9090.
|
||||
|
||||
Keep the metrics and debug endpoint private to your internal network and don't expose it to the Internet.
|
||||
|
||||
Query metrics via <http://localhost:9090/metrics> and get an overview of available debug information via
<http://localhost:9090/debug/>. Metrics may be queried from outside localhost, but the debug interface is subject to
additional protection even though it listens on all interfaces.
|
||||
|
||||
=== "Direct access"
|
||||
|
||||
Access the debug interface directly on the server where Headscale is installed.
|
||||
|
||||
```console
|
||||
curl http://localhost:9090/debug/
|
||||
```
|
||||
|
||||
=== "SSH port forwarding"
|
||||
|
||||
Use SSH port forwarding to forward Headscale's metrics and debug port to your device.
|
||||
|
||||
```console
|
||||
ssh <HEADSCALE_SERVER> -L 9090:localhost:9090
|
||||
```
|
||||
|
||||
Access the debug interface on your device by opening <http://localhost:9090/debug/> in your web browser.
|
||||
|
||||
=== "Via debug key"
|
||||
|
||||
The access control of the debug interface supports the use of a debug key. Traffic is accepted if the path to a
debug key is set via the environment variable `TS_DEBUG_KEY_PATH` and the debug key is sent as the value of the
`debugkey` parameter with each request.
|
||||
|
||||
```console
|
||||
openssl rand -hex 32 | tee debugkey.txt
|
||||
export TS_DEBUG_KEY_PATH=debugkey.txt
|
||||
headscale serve
|
||||
```
|
||||
|
||||
Access the debug interface on your device by opening `http://<IP_OF_HEADSCALE>:9090/debug/?debugkey=<DEBUG_KEY>` in
|
||||
your web browser. The `debugkey` parameter must be sent with every request.
|
||||
|
||||
=== "Via debug IP address"
|
||||
|
||||
The debug endpoint expects traffic from localhost. A different debug IP address may be configured by setting the
|
||||
`TS_ALLOW_DEBUG_IP` environment variable before starting Headscale. The debug IP address is ignored when the HTTP
|
||||
header `X-Forwarded-For` is present.
|
||||
|
||||
```console
|
||||
export TS_ALLOW_DEBUG_IP=192.168.0.10 # IP address of your device
|
||||
headscale serve
|
||||
```
|
||||
|
||||
Access the debug interface on your device by opening `http://<IP_OF_HEADSCALE>:9090/debug/` in your web browser.
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
Headscale supports authentication via external identity providers using OpenID Connect (OIDC). It features:
|
||||
|
||||
- Autoconfiguration via OpenID Connect Discovery Protocol
|
||||
- Auto configuration via OpenID Connect Discovery Protocol
|
||||
- [Proof Key for Code Exchange (PKCE) code verification](#enable-pkce-recommended)
|
||||
- [Authorization based on a user's domain, email address or group membership](#authorize-users-with-filters)
|
||||
- Synchronization of [standard OIDC claims](#supported-oidc-claims)
|
||||
@@ -142,7 +142,7 @@ Access Token.
|
||||
=== "Use expiration from Access Token"
|
||||
|
||||
Please keep in mind that the Access Token is typically a short-lived token that expires within a few minutes. You
|
||||
will have to configure token expiration in your identity provider to avoid frequent reauthentication.
|
||||
will have to configure token expiration in your identity provider to avoid frequent re-authentication.
|
||||
|
||||
|
||||
```yaml hl_lines="5"
|
||||
|
||||
@@ -49,7 +49,7 @@ ID | Hostname | Approved | Available | Serving (Primary)
|
||||
Approve all desired routes of a subnet router by specifying them as a comma-separated list:
|
||||
|
||||
```console
|
||||
$ headscale nodes approve-routes --node 1 --routes 10.0.0.0/8,192.168.0.0/24
|
||||
$ headscale nodes approve-routes --identifier 1 --routes 10.0.0.0/8,192.168.0.0/24
|
||||
Node updated
|
||||
```
|
||||
|
||||
@@ -175,7 +175,7 @@ ID | Hostname | Approved | Available | Serving (Primary)
|
||||
For exit nodes, it is sufficient to approve either the IPv4 or IPv6 route. The other will be approved automatically.
|
||||
|
||||
```console
|
||||
$ headscale nodes approve-routes --node 1 --routes 0.0.0.0/0
|
||||
$ headscale nodes approve-routes --identifier 1 --routes 0.0.0.0/0
|
||||
Node updated
|
||||
```
|
||||
|
||||
|
||||
@@ -112,11 +112,11 @@ docker exec -it headscale \
|
||||
|
||||
### Register a machine using a pre authenticated key
|
||||
|
||||
Generate a key using the command line:
|
||||
Generate a key using the command line for the user with ID 1:
|
||||
|
||||
```shell
|
||||
docker exec -it headscale \
|
||||
headscale preauthkeys create --user myfirstuser --reusable --expiration 24h
|
||||
headscale preauthkeys create --user 1 --reusable --expiration 24h
|
||||
```
|
||||
|
||||
This will return a pre-authenticated key that can be used to connect a node to headscale with the `tailscale up` command:
|
||||
|
||||
@@ -117,14 +117,14 @@ headscale instance. By default, the key is valid for one hour and can only be us
|
||||
=== "Native"
|
||||
|
||||
```shell
|
||||
headscale preauthkeys create --user <USER>
|
||||
headscale preauthkeys create --user <USER_ID>
|
||||
```
|
||||
|
||||
=== "Container"
|
||||
|
||||
```shell
|
||||
docker exec -it headscale \
|
||||
headscale preauthkeys create --user <USER>
|
||||
headscale preauthkeys create --user <USER_ID>
|
||||
```
|
||||
|
||||
The command returns the preauthkey on success which is used to connect a node to the headscale instance via the
|
||||
|
||||
@@ -19,7 +19,7 @@
|
||||
overlay = _: prev: let
|
||||
pkgs = nixpkgs.legacyPackages.${prev.system};
|
||||
buildGo = pkgs.buildGo124Module;
|
||||
vendorHash = "sha256-S2GnCg2dyfjIyi5gXhVEuRs5Bop2JAhZcnhg1fu4/Gg=";
|
||||
vendorHash = "sha256-83L2NMyOwKCHWqcowStJ7Ze/U9CJYhzleDRLrJNhX2g=";
|
||||
in {
|
||||
headscale = buildGo {
|
||||
pname = "headscale";
|
||||
|
||||
@@ -913,10 +913,6 @@ func (x *RenameNodeResponse) GetNode() *Node {
|
||||
type ListNodesRequest struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
User string `protobuf:"bytes,1,opt,name=user,proto3" json:"user,omitempty"`
|
||||
Id uint64 `protobuf:"varint,2,opt,name=id,proto3" json:"id,omitempty"`
|
||||
Name string `protobuf:"bytes,3,opt,name=name,proto3" json:"name,omitempty"`
|
||||
Hostname string `protobuf:"bytes,4,opt,name=hostname,proto3" json:"hostname,omitempty"`
|
||||
IpAddresses []string `protobuf:"bytes,5,rep,name=ip_addresses,json=ipAddresses,proto3" json:"ip_addresses,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
@@ -958,34 +954,6 @@ func (x *ListNodesRequest) GetUser() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *ListNodesRequest) GetId() uint64 {
|
||||
if x != nil {
|
||||
return x.Id
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (x *ListNodesRequest) GetName() string {
|
||||
if x != nil {
|
||||
return x.Name
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *ListNodesRequest) GetHostname() string {
|
||||
if x != nil {
|
||||
return x.Hostname
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *ListNodesRequest) GetIpAddresses() []string {
|
||||
if x != nil {
|
||||
return x.IpAddresses
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type ListNodesResponse struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
Nodes []*Node `protobuf:"bytes,1,rep,name=nodes,proto3" json:"nodes,omitempty"`
|
||||
@@ -1390,13 +1358,9 @@ const file_headscale_v1_node_proto_rawDesc = "" +
|
||||
"\anode_id\x18\x01 \x01(\x04R\x06nodeId\x12\x19\n" +
|
||||
"\bnew_name\x18\x02 \x01(\tR\anewName\"<\n" +
|
||||
"\x12RenameNodeResponse\x12&\n" +
|
||||
"\x04node\x18\x01 \x01(\v2\x12.headscale.v1.NodeR\x04node\"\x89\x01\n" +
|
||||
"\x04node\x18\x01 \x01(\v2\x12.headscale.v1.NodeR\x04node\"&\n" +
|
||||
"\x10ListNodesRequest\x12\x12\n" +
|
||||
"\x04user\x18\x01 \x01(\tR\x04user\x12\x0e\n" +
|
||||
"\x02id\x18\x02 \x01(\x04R\x02id\x12\x12\n" +
|
||||
"\x04name\x18\x03 \x01(\tR\x04name\x12\x1a\n" +
|
||||
"\bhostname\x18\x04 \x01(\tR\bhostname\x12!\n" +
|
||||
"\fip_addresses\x18\x05 \x03(\tR\vipAddresses\"=\n" +
|
||||
"\x04user\x18\x01 \x01(\tR\x04user\"=\n" +
|
||||
"\x11ListNodesResponse\x12(\n" +
|
||||
"\x05nodes\x18\x01 \x03(\v2\x12.headscale.v1.NodeR\x05nodes\">\n" +
|
||||
"\x0fMoveNodeRequest\x12\x17\n" +
|
||||
|
||||
@@ -187,35 +187,6 @@
|
||||
"in": "query",
|
||||
"required": false,
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"name": "id",
|
||||
"in": "query",
|
||||
"required": false,
|
||||
"type": "string",
|
||||
"format": "uint64"
|
||||
},
|
||||
{
|
||||
"name": "name",
|
||||
"in": "query",
|
||||
"required": false,
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"name": "hostname",
|
||||
"in": "query",
|
||||
"required": false,
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"name": "ipAddresses",
|
||||
"in": "query",
|
||||
"required": false,
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
},
|
||||
"collectionFormat": "multi"
|
||||
}
|
||||
],
|
||||
"tags": [
|
||||
|
||||
27
go.mod
@@ -14,7 +14,7 @@ require (
|
||||
github.com/creachadair/command v0.1.22
|
||||
github.com/creachadair/flax v0.0.5
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc
|
||||
github.com/docker/docker v28.2.2+incompatible
|
||||
github.com/docker/docker v28.3.3+incompatible
|
||||
github.com/fsnotify/fsnotify v1.9.0
|
||||
github.com/glebarez/sqlite v1.11.0
|
||||
github.com/go-gormigrate/gormigrate/v2 v2.1.4
|
||||
@@ -23,7 +23,6 @@ require (
|
||||
github.com/gorilla/mux v1.8.1
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.0
|
||||
github.com/jagottsicher/termcolor v1.0.2
|
||||
github.com/klauspost/compress v1.18.0
|
||||
github.com/oauth2-proxy/mockoidc v0.0.0-20240214162133-caebfff84d25
|
||||
github.com/ory/dockertest/v3 v3.12.0
|
||||
github.com/philip-bui/grpc-zerolog v1.0.1
|
||||
@@ -43,11 +42,11 @@ require (
|
||||
github.com/tailscale/tailsql v0.0.0-20250421235516-02f85f087b97
|
||||
github.com/tcnksm/go-latest v0.0.0-20170313132115-e3007ae9052e
|
||||
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba
|
||||
golang.org/x/crypto v0.39.0
|
||||
golang.org/x/crypto v0.40.0
|
||||
golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0
|
||||
golang.org/x/net v0.41.0
|
||||
golang.org/x/net v0.42.0
|
||||
golang.org/x/oauth2 v0.30.0
|
||||
golang.org/x/sync v0.15.0
|
||||
golang.org/x/sync v0.16.0
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20250603155806-513f23925822
|
||||
google.golang.org/grpc v1.73.0
|
||||
google.golang.org/protobuf v1.36.6
|
||||
@@ -55,7 +54,7 @@ require (
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
gorm.io/driver/postgres v1.6.0
|
||||
gorm.io/gorm v1.30.0
|
||||
tailscale.com v1.84.2
|
||||
tailscale.com v1.84.3
|
||||
zgo.at/zcache/v2 v2.2.0
|
||||
zombiezen.com/go/postgrestest v1.0.1
|
||||
)
|
||||
@@ -166,6 +165,7 @@ require (
|
||||
github.com/jmespath/go-jmespath v0.4.0 // indirect
|
||||
github.com/jsimonetti/rtnetlink v1.4.1 // indirect
|
||||
github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect
|
||||
github.com/klauspost/compress v1.18.0 // indirect
|
||||
github.com/kr/pretty v0.3.1 // indirect
|
||||
github.com/kr/text v0.2.0 // indirect
|
||||
github.com/lib/pq v1.10.9 // indirect
|
||||
@@ -231,14 +231,19 @@ require (
|
||||
go.opentelemetry.io/otel/trace v1.36.0 // indirect
|
||||
go.uber.org/multierr v1.11.0 // indirect
|
||||
go4.org/mem v0.0.0-20240501181205-ae6ca9944745 // indirect
|
||||
golang.org/x/mod v0.25.0 // indirect
|
||||
golang.org/x/sys v0.33.0 // indirect
|
||||
golang.org/x/term v0.32.0 // indirect
|
||||
golang.org/x/text v0.26.0 // indirect
|
||||
golang.org/x/mod v0.26.0 // indirect
|
||||
golang.org/x/sys v0.34.0 // indirect
|
||||
golang.org/x/term v0.33.0 // indirect
|
||||
golang.org/x/text v0.27.0 // indirect
|
||||
golang.org/x/time v0.10.0 // indirect
|
||||
golang.org/x/tools v0.33.0 // indirect
|
||||
golang.org/x/tools v0.35.0 // indirect
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
|
||||
golang.zx2c4.com/wireguard/windows v0.5.3 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822 // indirect
|
||||
gvisor.dev/gvisor v0.0.0-20250205023644-9414b50a5633 // indirect
|
||||
)
|
||||
|
||||
tool (
|
||||
golang.org/x/tools/cmd/stringer
|
||||
tailscale.com/cmd/viewer
|
||||
)
|
||||
|
||||
40
go.sum
@@ -148,8 +148,8 @@ github.com/djherbis/times v1.6.0 h1:w2ctJ92J8fBvWPxugmXIv7Nz7Q3iDMKNx9v5ocVH20c=
|
||||
github.com/djherbis/times v1.6.0/go.mod h1:gOHeRAz2h+VJNZ5Gmc/o7iD9k4wW7NMVqieYCY99oc0=
|
||||
github.com/docker/cli v28.1.1+incompatible h1:eyUemzeI45DY7eDPuwUcmDyDj1pM98oD5MdSpiItp8k=
|
||||
github.com/docker/cli v28.1.1+incompatible/go.mod h1:JLrzqnKDaYBop7H2jaqPtU4hHvMKP+vjCwu2uszcLI8=
|
||||
github.com/docker/docker v28.2.2+incompatible h1:CjwRSksz8Yo4+RmQ339Dp/D2tGO5JxwYeqtMOEe0LDw=
|
||||
github.com/docker/docker v28.2.2+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
|
||||
github.com/docker/docker v28.3.3+incompatible h1:Dypm25kh4rmk49v1eiVbsAtpAsYURjYkaKubwuBdxEI=
|
||||
github.com/docker/docker v28.3.3+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
|
||||
github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj1Br63c=
|
||||
github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc=
|
||||
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
|
||||
@@ -555,8 +555,8 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U
|
||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
|
||||
golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM=
|
||||
golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U=
|
||||
golang.org/x/crypto v0.40.0 h1:r4x+VvoG5Fm+eJcxMaY8CQM7Lb0l1lsmjGBQ6s8BfKM=
|
||||
golang.org/x/crypto v0.40.0/go.mod h1:Qr1vMER5WyS2dfPHAlsOj01wgLbsyWtFn/aY+5+ZdxY=
|
||||
golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0 h1:R84qjqJb5nVJMxqWYb3np9L5ZsaDtB+a39EqjV0JSUM=
|
||||
golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0/go.mod h1:S9Xr4PYopiDyqSyp5NjCrhFrqg6A5zA2E/iPHPhqnS8=
|
||||
golang.org/x/exp/typeparams v0.0.0-20240314144324-c7f7c6466f7f h1:phY1HzDcf18Aq9A8KkmRtY9WvOFIxN8wgfvy6Zm1DV8=
|
||||
@@ -567,8 +567,8 @@ golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
||||
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/mod v0.25.0 h1:n7a+ZbQKQA/Ysbyb0/6IbB1H/X41mKgbhfv7AfG/44w=
|
||||
golang.org/x/mod v0.25.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww=
|
||||
golang.org/x/mod v0.26.0 h1:EGMPT//Ezu+ylkCijjPc+f4Aih7sZvaAr+O3EHBxvZg=
|
||||
golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
@@ -577,8 +577,8 @@ golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v
|
||||
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
||||
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
|
||||
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
|
||||
golang.org/x/net v0.41.0 h1:vBTly1HeNPEn3wtREYfy4GZ/NECgw2Cnl+nK6Nz3uvw=
|
||||
golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA=
|
||||
golang.org/x/net v0.42.0 h1:jzkYrhi3YQWD6MLBJcsklgQsoAcw89EcZbJw8Z614hs=
|
||||
golang.org/x/net v0.42.0/go.mod h1:FF1RA5d3u7nAYA4z2TkclSCKh68eSXtiFwcWQpPXdt8=
|
||||
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
|
||||
golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
@@ -587,8 +587,8 @@ golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJ
|
||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.15.0 h1:KWH3jNZsfyT6xfAfKiz6MRNmd46ByHDYaZ7KSkCtdW8=
|
||||
golang.org/x/sync v0.15.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
|
||||
golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
@@ -615,8 +615,8 @@ golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
|
||||
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA=
|
||||
golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210220032956-6a3ed077a48d/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210615171337-6886f2dfbf5b/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
@@ -624,8 +624,8 @@ golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuX
|
||||
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
|
||||
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
|
||||
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
|
||||
golang.org/x/term v0.32.0 h1:DR4lr0TjUs3epypdhTOkMmuF5CDFJ/8pOnbzMZPQ7bg=
|
||||
golang.org/x/term v0.32.0/go.mod h1:uZG1FhGx848Sqfsq4/DlJr3xGGsYMu/L5GW4abiaEPQ=
|
||||
golang.org/x/term v0.33.0 h1:NuFncQrRcaRvVmgRkvM3j/F00gWIAlcmlB8ACEKmGIg=
|
||||
golang.org/x/term v0.33.0/go.mod h1:s18+ql9tYWp1IfpV9DmCtQDDSRBUjKaw9M1eAv5UeF0=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||
@@ -633,8 +633,8 @@ golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
||||
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
||||
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
|
||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/text v0.26.0 h1:P42AVeLghgTYr4+xUnTRKDMqpar+PtX7KWuNQL21L8M=
|
||||
golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA=
|
||||
golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4=
|
||||
golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU=
|
||||
golang.org/x/time v0.10.0 h1:3usCWA8tQn0L8+hFJQNgzpWbd89begxN66o1Ojdn5L4=
|
||||
golang.org/x/time v0.10.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
@@ -643,8 +643,8 @@ golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roY
|
||||
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
|
||||
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
|
||||
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
|
||||
golang.org/x/tools v0.33.0 h1:4qz2S3zmRxbGIhDIAgjxvFutSvH5EfnsYrRBj0UI0bc=
|
||||
golang.org/x/tools v0.33.0/go.mod h1:CIJMaWEY88juyUfo7UbgPqbC8rU2OqfAV1h2Qp0oMYI=
|
||||
golang.org/x/tools v0.35.0 h1:mBffYraMEf7aa0sB+NuKnuCy8qI/9Bughn8dC2Gu5r0=
|
||||
golang.org/x/tools v0.35.0/go.mod h1:NKdj5HkL/73byiZSJjqJgKn3ep7KjFkBOkR/Hps3VPw=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
@@ -712,8 +712,8 @@ modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
|
||||
modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM=
|
||||
software.sslmate.com/src/go-pkcs12 v0.4.0 h1:H2g08FrTvSFKUj+D309j1DPfk5APnIdAQAB8aEykJ5k=
|
||||
software.sslmate.com/src/go-pkcs12 v0.4.0/go.mod h1:Qiz0EyvDRJjjxGyUQa2cCNZn/wMyzrRJ/qcDXOQazLI=
|
||||
tailscale.com v1.84.2 h1:v6aM4RWUgYiV52LRAx6ET+dlGnvO/5lnqPXb7/pMnR0=
|
||||
tailscale.com v1.84.2/go.mod h1:6/S63NMAhmncYT/1zIPDJkvCuZwMw+JnUuOfSPNazpo=
|
||||
tailscale.com v1.84.3 h1:Ur9LMedSgicwbqpy5xn7t49G8490/s6rqAJOk5Q5AYE=
|
||||
tailscale.com v1.84.3/go.mod h1:6/S63NMAhmncYT/1zIPDJkvCuZwMw+JnUuOfSPNazpo=
|
||||
zgo.at/zcache/v2 v2.2.0 h1:K29/IPjMniZfveYE+IRXfrl11tMzHkIPuyGrfVZ2fGo=
|
||||
zgo.at/zcache/v2 v2.2.0/go.mod h1:gyCeoLVo01QjDZynjime8xUGHHMbsLiPyUTBpDGd4Gk=
|
||||
zombiezen.com/go/postgrestest v1.0.1 h1:aXoADQAJmZDU3+xilYVut0pHhgc0sF8ZspPW9gFNwP4=
|
||||
|
||||
130
hscontrol/app.go
@@ -28,14 +28,15 @@ import (
|
||||
derpServer "github.com/juanfont/headscale/hscontrol/derp/server"
|
||||
"github.com/juanfont/headscale/hscontrol/dns"
|
||||
"github.com/juanfont/headscale/hscontrol/mapper"
|
||||
"github.com/juanfont/headscale/hscontrol/notifier"
|
||||
"github.com/juanfont/headscale/hscontrol/state"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/types/change"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
zerolog "github.com/philip-bui/grpc-zerolog"
|
||||
"github.com/pkg/profile"
|
||||
zl "github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/sasha-s/go-deadlock"
|
||||
"golang.org/x/crypto/acme"
|
||||
"golang.org/x/crypto/acme/autocert"
|
||||
"golang.org/x/sync/errgroup"
|
||||
@@ -64,6 +65,19 @@ var (
|
||||
)
|
||||
)
|
||||
|
||||
var (
|
||||
debugDeadlock = envknob.Bool("HEADSCALE_DEBUG_DEADLOCK")
|
||||
debugDeadlockTimeout = envknob.RegisterDuration("HEADSCALE_DEBUG_DEADLOCK_TIMEOUT")
|
||||
)
|
||||
|
||||
func init() {
|
||||
deadlock.Opts.Disable = !debugDeadlock
|
||||
if debugDeadlock {
|
||||
deadlock.Opts.DeadlockTimeout = debugDeadlockTimeout()
|
||||
deadlock.Opts.PrintAllCurrentGoroutines = true
|
||||
}
|
||||
}
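
For context, a self-contained sketch (not headscale code) of what the two debug knobs above switch on, assuming the github.com/sasha-s/go-deadlock API used in this hunk; the mutex below is only a stand-in:

package main

import (
	"time"

	"github.com/sasha-s/go-deadlock"
)

func main() {
	// Mirrors the init() above: detection stays disabled unless the
	// HEADSCALE_DEBUG_DEADLOCK knob turns it on.
	deadlock.Opts.Disable = false
	deadlock.Opts.DeadlockTimeout = 30 * time.Second
	deadlock.Opts.PrintAllCurrentGoroutines = true

	// deadlock.Mutex is a drop-in replacement for sync.Mutex; waiting on it
	// for longer than DeadlockTimeout is reported as a potential deadlock.
	var mu deadlock.Mutex
	mu.Lock()
	time.Sleep(10 * time.Millisecond)
	mu.Unlock()
}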
|
||||
|
||||
const (
|
||||
AuthPrefix = "Bearer "
|
||||
updateInterval = 5 * time.Second
|
||||
@@ -82,9 +96,8 @@ type Headscale struct {
|
||||
|
||||
// Things that generate changes
|
||||
extraRecordMan *dns.ExtraRecordsMan
|
||||
mapper *mapper.Mapper
|
||||
nodeNotifier *notifier.Notifier
|
||||
authProvider AuthProvider
|
||||
mapBatcher mapper.Batcher
|
||||
|
||||
pollNetMapStreamWG sync.WaitGroup
|
||||
}
|
||||
@@ -118,7 +131,6 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
|
||||
cfg: cfg,
|
||||
noisePrivateKey: noisePrivateKey,
|
||||
pollNetMapStreamWG: sync.WaitGroup{},
|
||||
nodeNotifier: notifier.NewNotifier(cfg),
|
||||
state: s,
|
||||
}
|
||||
|
||||
@@ -136,12 +148,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
|
||||
return
|
||||
}
|
||||
|
||||
// Send policy update notifications if needed
|
||||
if policyChanged {
|
||||
ctx := types.NotifyCtx(context.Background(), "ephemeral-gc-policy", node.Hostname)
|
||||
app.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
||||
}
|
||||
|
||||
app.Change(policyChanged)
|
||||
log.Debug().Uint64("node.id", ni.Uint64()).Msgf("deleted ephemeral node")
|
||||
})
|
||||
app.ephemeralGC = ephemeralGC
|
||||
@@ -153,10 +160,9 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
|
||||
defer cancel()
|
||||
oidcProvider, err := NewAuthProviderOIDC(
|
||||
ctx,
|
||||
&app,
|
||||
cfg.ServerURL,
|
||||
&cfg.OIDC,
|
||||
app.state,
|
||||
app.nodeNotifier,
|
||||
)
|
||||
if err != nil {
|
||||
if cfg.OIDC.OnlyStartIfOIDCIsAvailable {
|
||||
@@ -262,16 +268,18 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
|
||||
return
|
||||
|
||||
case <-expireTicker.C:
|
||||
var update types.StateUpdate
|
||||
var expiredNodeChanges []change.ChangeSet
|
||||
var changed bool
|
||||
|
||||
lastExpiryCheck, update, changed = h.state.ExpireExpiredNodes(lastExpiryCheck)
|
||||
lastExpiryCheck, expiredNodeChanges, changed = h.state.ExpireExpiredNodes(lastExpiryCheck)
|
||||
|
||||
if changed {
|
||||
log.Trace().Interface("nodes", update.ChangePatches).Msgf("expiring nodes")
|
||||
log.Trace().Interface("changes", expiredNodeChanges).Msgf("expiring nodes")
|
||||
|
||||
ctx := types.NotifyCtx(context.Background(), "expire-expired", "na")
|
||||
h.nodeNotifier.NotifyAll(ctx, update)
|
||||
// Send the changes directly since they're already in the new format
|
||||
for _, nodeChange := range expiredNodeChanges {
|
||||
h.Change(nodeChange)
|
||||
}
|
||||
}
|
||||
|
||||
case <-derpTickerChan:
|
||||
@@ -282,11 +290,7 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
|
||||
derpMap.Regions[region.RegionID] = ®ion
|
||||
}
|
||||
|
||||
ctx := types.NotifyCtx(context.Background(), "derpmap-update", "na")
|
||||
h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
|
||||
Type: types.StateDERPUpdated,
|
||||
DERPMap: derpMap,
|
||||
})
|
||||
h.Change(change.DERPSet)
|
||||
|
||||
case records, ok := <-extraRecordsUpdate:
|
||||
if !ok {
|
||||
@@ -294,19 +298,16 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
|
||||
}
|
||||
h.cfg.TailcfgDNSConfig.ExtraRecords = records
|
||||
|
||||
ctx := types.NotifyCtx(context.Background(), "dns-extrarecord", "all")
|
||||
// TODO(kradalby): We can probably do better than sending a full update here,
|
||||
// but for now this will ensure that all of the nodes get the new records.
|
||||
h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
||||
h.Change(change.ExtraRecordsSet)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
|
||||
req interface{},
|
||||
req any,
|
||||
info *grpc.UnaryServerInfo,
|
||||
handler grpc.UnaryHandler,
|
||||
) (interface{}, error) {
|
||||
) (any, error) {
|
||||
// Check if the request is coming from the on-server client.
|
||||
// This is not secure, but it is kept for maintainability
// with the "legacy" database-based client
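
The interface{} to any change in the interceptor signature above is purely cosmetic: any has been a built-in alias for interface{} since Go 1.18, so the signature is unchanged. A minimal illustration:

// any and interface{} are the same type (alias added in Go 1.18),
// so both of these compile and are interchangeable.
func echoOld(v interface{}) interface{} { return v }
func echoNew(v any) any                 { return v }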
|
||||
@@ -484,58 +485,6 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {
|
||||
return router
|
||||
}
|
||||
|
||||
// // TODO(kradalby): Do a variant of this, and polman which only updates the node that has changed.
|
||||
// // Maybe we should attempt a new in memory state and not go via the DB?
|
||||
// // Maybe this should be implemented as an event bus?
|
||||
// // A bool is returned indicating if a full update was sent to all nodes
|
||||
// func usersChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *notifier.Notifier) error {
|
||||
// users, err := db.ListUsers()
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
|
||||
// changed, err := polMan.SetUsers(users)
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
|
||||
// if changed {
|
||||
// ctx := types.NotifyCtx(context.Background(), "acl-users-change", "all")
|
||||
// notif.NotifyAll(ctx, types.UpdateFull())
|
||||
// }
|
||||
|
||||
// return nil
|
||||
// }
|
||||
|
||||
// // TODO(kradalby): Do a variant of this, and polman which only updates the node that has changed.
|
||||
// // Maybe we should attempt a new in memory state and not go via the DB?
|
||||
// // Maybe this should be implemented as an event bus?
|
||||
// // A bool is returned indicating if a full update was sent to all nodes
|
||||
// func nodesChangedHook(
|
||||
// db *db.HSDatabase,
|
||||
// polMan policy.PolicyManager,
|
||||
// notif *notifier.Notifier,
|
||||
// ) (bool, error) {
|
||||
// nodes, err := db.ListNodes()
|
||||
// if err != nil {
|
||||
// return false, err
|
||||
// }
|
||||
|
||||
// filterChanged, err := polMan.SetNodes(nodes)
|
||||
// if err != nil {
|
||||
// return false, err
|
||||
// }
|
||||
|
||||
// if filterChanged {
|
||||
// ctx := types.NotifyCtx(context.Background(), "acl-nodes-change", "all")
|
||||
// notif.NotifyAll(ctx, types.UpdateFull())
|
||||
|
||||
// return true, nil
|
||||
// }
|
||||
|
||||
// return false, nil
|
||||
// }
|
||||
|
||||
// Serve launches the HTTP and gRPC servers serving Headscale and the API.
|
||||
func (h *Headscale) Serve() error {
|
||||
capver.CanOldCodeBeCleanedUp()
|
||||
@@ -562,8 +511,9 @@ func (h *Headscale) Serve() error {
|
||||
Str("minimum_version", capver.TailscaleVersion(capver.MinSupportedCapabilityVersion)).
|
||||
Msg("Clients with a lower minimum version will be rejected")
|
||||
|
||||
// Fetch an initial DERP Map before we start serving
|
||||
h.mapper = mapper.NewMapper(h.state, h.cfg, h.nodeNotifier)
|
||||
h.mapBatcher = mapper.NewBatcherAndMapper(h.cfg, h.state)
|
||||
h.mapBatcher.Start()
|
||||
defer h.mapBatcher.Close()
|
||||
|
||||
// TODO(kradalby): fix state part.
|
||||
if h.cfg.DERP.ServerEnabled {
|
||||
@@ -838,8 +788,12 @@ func (h *Headscale) Serve() error {
|
||||
log.Info().
|
||||
Msg("ACL policy successfully reloaded, notifying nodes of change")
|
||||
|
||||
ctx := types.NotifyCtx(context.Background(), "acl-sighup", "na")
|
||||
h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
||||
err = h.state.AutoApproveNodes()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("failed to approve routes after new policy")
|
||||
}
|
||||
|
||||
h.Change(change.PolicySet)
|
||||
}
|
||||
default:
|
||||
info := func(msg string) { log.Info().Msg(msg) }
|
||||
@@ -865,7 +819,6 @@ func (h *Headscale) Serve() error {
|
||||
}
|
||||
|
||||
info("closing node notifier")
|
||||
h.nodeNotifier.Close()
|
||||
|
||||
info("waiting for netmap stream to close")
|
||||
h.pollNetMapStreamWG.Wait()
|
||||
@@ -1047,3 +1000,10 @@ func readOrCreatePrivateKey(path string) (*key.MachinePrivate, error) {
|
||||
|
||||
return &machineKey, nil
|
||||
}
|
||||
|
||||
// Change is used to send changes to nodes.
|
||||
// All changes should be enqueued here; empty changesets are
// ignored automatically.
|
||||
func (h *Headscale) Change(c change.ChangeSet) {
|
||||
h.mapBatcher.AddWork(c)
|
||||
}
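
As a hedged usage sketch (exampleDelete is illustrative, not part of this diff): callers turn a state mutation into a change.ChangeSet and hand it to Change, which forwards it to the map batcher.

func exampleDelete(h *Headscale, node *types.Node) error {
	// state.DeleteNode returns the ChangeSet describing the deletion,
	// as seen in the auth.go hunk of this same comparison.
	c, err := h.state.DeleteNode(node)
	if err != nil {
		return err
	}

	// Empty ChangeSets are dropped by the batcher, so enqueueing
	// unconditionally is safe.
	h.Change(c)

	return nil
}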
|
||||
|
||||
@@ -10,6 +10,8 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/types/change"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
@@ -32,6 +34,21 @@ func (h *Headscale) handleRegister(
|
||||
}
|
||||
|
||||
if node != nil {
|
||||
// If an existing node is trying to register with an auth key,
|
||||
// we need to validate the auth key even for existing nodes
|
||||
if regReq.Auth != nil && regReq.Auth.AuthKey != "" {
|
||||
resp, err := h.handleRegisterWithAuthKey(regReq, machineKey)
|
||||
if err != nil {
|
||||
// Preserve HTTPError types so they can be handled properly by the HTTP layer
|
||||
var httpErr HTTPError
|
||||
if errors.As(err, &httpErr) {
|
||||
return nil, httpErr
|
||||
}
|
||||
return nil, fmt.Errorf("handling register with auth key for existing node: %w", err)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
resp, err := h.handleExistingNode(node, regReq, machineKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("handling existing node: %w", err)
|
||||
@@ -47,6 +64,11 @@ func (h *Headscale) handleRegister(
|
||||
if regReq.Auth != nil && regReq.Auth.AuthKey != "" {
|
||||
resp, err := h.handleRegisterWithAuthKey(regReq, machineKey)
|
||||
if err != nil {
|
||||
// Preserve HTTPError types so they can be handled properly by the HTTP layer
|
||||
var httpErr HTTPError
|
||||
if errors.As(err, &httpErr) {
|
||||
return nil, httpErr
|
||||
}
|
||||
return nil, fmt.Errorf("handling register with auth key: %w", err)
|
||||
}
|
||||
|
||||
@@ -66,11 +88,13 @@ func (h *Headscale) handleExistingNode(
|
||||
regReq tailcfg.RegisterRequest,
|
||||
machineKey key.MachinePublic,
|
||||
) (*tailcfg.RegisterResponse, error) {
|
||||
|
||||
if node.MachineKey != machineKey {
|
||||
return nil, NewHTTPError(http.StatusUnauthorized, "node exists with different machine key", nil)
|
||||
}
|
||||
|
||||
expired := node.IsExpired()
|
||||
|
||||
if !expired && !regReq.Expiry.IsZero() {
|
||||
requestExpiry := regReq.Expiry
|
||||
|
||||
@@ -82,42 +106,26 @@ func (h *Headscale) handleExistingNode(
|
||||
// If the request expiry is in the past, we consider it a logout.
|
||||
if requestExpiry.Before(time.Now()) {
|
||||
if node.IsEphemeral() {
|
||||
policyChanged, err := h.state.DeleteNode(node)
|
||||
c, err := h.state.DeleteNode(node)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("deleting ephemeral node: %w", err)
|
||||
}
|
||||
|
||||
// Send policy update notifications if needed
|
||||
if policyChanged {
|
||||
ctx := types.NotifyCtx(context.Background(), "auth-logout-ephemeral-policy", "na")
|
||||
h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
||||
} else {
|
||||
ctx := types.NotifyCtx(context.Background(), "logout-ephemeral", "na")
|
||||
h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerRemoved(node.ID))
|
||||
}
|
||||
h.Change(c)
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
n, policyChanged, err := h.state.SetNodeExpiry(node.ID, requestExpiry)
|
||||
_, c, err := h.state.SetNodeExpiry(node.ID, requestExpiry)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("setting node expiry: %w", err)
|
||||
}
|
||||
|
||||
// Send policy update notifications if needed
|
||||
if policyChanged {
|
||||
ctx := types.NotifyCtx(context.Background(), "auth-expiry-policy", "na")
|
||||
h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
||||
} else {
|
||||
ctx := types.NotifyCtx(context.Background(), "logout-expiry", "na")
|
||||
h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdateExpire(node.ID, requestExpiry), node.ID)
|
||||
h.Change(c)
|
||||
}
|
||||
|
||||
return nodeToRegisterResponse(n), nil
|
||||
}
|
||||
|
||||
return nodeToRegisterResponse(node), nil
|
||||
return nodeToRegisterResponse(node), nil
|
||||
}
|
||||
|
||||
func nodeToRegisterResponse(node *types.Node) *tailcfg.RegisterResponse {
|
||||
@@ -168,7 +176,7 @@ func (h *Headscale) handleRegisterWithAuthKey(
|
||||
regReq tailcfg.RegisterRequest,
|
||||
machineKey key.MachinePublic,
|
||||
) (*tailcfg.RegisterResponse, error) {
|
||||
node, changed, err := h.state.HandleNodeFromPreAuthKey(
|
||||
node, changed, policyChanged, err := h.state.HandleNodeFromPreAuthKey(
|
||||
regReq,
|
||||
machineKey,
|
||||
)
|
||||
@@ -184,6 +192,12 @@ func (h *Headscale) handleRegisterWithAuthKey(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If node is nil, it means an ephemeral node was deleted during logout
|
||||
if node == nil {
|
||||
h.Change(changed)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// This is a bit of a back and forth, but we have a bit of a chicken and egg
|
||||
// dependency here.
|
||||
// Because the way the policy manager works, we need to have the node
|
||||
@@ -195,23 +209,22 @@ func (h *Headscale) handleRegisterWithAuthKey(
|
||||
// ensure we send an update.
|
||||
// This works, but might be another good candidate for doing some sort of
|
||||
// eventbus.
|
||||
routesChanged := h.state.AutoApproveRoutes(node)
|
||||
// TODO(kradalby): This might need to run as part of the batcher,
// now that we don't update the node/pol here anymore.
|
||||
routeChange := h.state.AutoApproveRoutes(node)
|
||||
if _, _, err := h.state.SaveNode(node); err != nil {
|
||||
return nil, fmt.Errorf("saving auto approved routes to node: %w", err)
|
||||
}
|
||||
|
||||
if routesChanged {
|
||||
ctx := types.NotifyCtx(context.Background(), "node updated", node.Hostname)
|
||||
h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerChanged(node.ID))
|
||||
} else if changed {
|
||||
ctx := types.NotifyCtx(context.Background(), "node created", node.Hostname)
|
||||
h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
||||
} else {
|
||||
// Existing node re-registering without route changes
|
||||
// Still need to notify peers about the node being active again
|
||||
// Use UpdateFull to ensure all peers get complete peer maps
|
||||
ctx := types.NotifyCtx(context.Background(), "node re-registered", node.Hostname)
|
||||
h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
||||
if routeChange && changed.Empty() {
|
||||
changed = change.NodeAdded(node.ID)
|
||||
}
|
||||
h.Change(changed)
|
||||
|
||||
// If policy changed due to node registration, send a separate policy change
|
||||
if policyChanged {
|
||||
policyChange := change.PolicyChange()
|
||||
h.Change(policyChange)
|
||||
}
|
||||
|
||||
return &tailcfg.RegisterResponse{
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
package capver
|
||||
|
||||
//go:generate go run ../../tools/capver/main.go
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"sort"
|
||||
@@ -10,7 +12,7 @@ import (
|
||||
"tailscale.com/util/set"
|
||||
)
|
||||
|
||||
const MinSupportedCapabilityVersion tailcfg.CapabilityVersion = 88
|
||||
const MinSupportedCapabilityVersion tailcfg.CapabilityVersion = 90
|
||||
|
||||
// CanOldCodeBeCleanedUp is intended to be called on startup to see if
|
||||
// there is old code that can be cleaned up, entries should contain
|
||||
|
||||
@@ -1,14 +1,10 @@
|
||||
package capver
|
||||
|
||||
// Generated DO NOT EDIT
|
||||
//Generated DO NOT EDIT
|
||||
|
||||
import "tailscale.com/tailcfg"
|
||||
|
||||
var tailscaleToCapVer = map[string]tailcfg.CapabilityVersion{
|
||||
"v1.60.0": 87,
|
||||
"v1.60.1": 87,
|
||||
"v1.62.0": 88,
|
||||
"v1.62.1": 88,
|
||||
"v1.64.0": 90,
|
||||
"v1.64.1": 90,
|
||||
"v1.64.2": 90,
|
||||
@@ -36,18 +32,21 @@ var tailscaleToCapVer = map[string]tailcfg.CapabilityVersion{
|
||||
"v1.80.3": 113,
|
||||
"v1.82.0": 115,
|
||||
"v1.82.5": 115,
|
||||
"v1.84.0": 116,
|
||||
"v1.84.1": 116,
|
||||
"v1.84.2": 116,
|
||||
}
|
||||
|
||||
|
||||
var capVerToTailscaleVer = map[tailcfg.CapabilityVersion]string{
|
||||
87: "v1.60.0",
|
||||
88: "v1.62.0",
|
||||
90: "v1.64.0",
|
||||
95: "v1.66.0",
|
||||
97: "v1.68.0",
|
||||
102: "v1.70.0",
|
||||
104: "v1.72.0",
|
||||
106: "v1.74.0",
|
||||
109: "v1.78.0",
|
||||
113: "v1.80.0",
|
||||
115: "v1.82.0",
|
||||
90: "v1.64.0",
|
||||
95: "v1.66.0",
|
||||
97: "v1.68.0",
|
||||
102: "v1.70.0",
|
||||
104: "v1.72.0",
|
||||
106: "v1.74.0",
|
||||
109: "v1.78.0",
|
||||
113: "v1.80.0",
|
||||
115: "v1.82.0",
|
||||
116: "v1.84.0",
|
||||
}
|
||||
|
||||
@@ -13,11 +13,10 @@ func TestTailscaleLatestMajorMinor(t *testing.T) {
|
||||
stripV bool
|
||||
expected []string
|
||||
}{
|
||||
{3, false, []string{"v1.78", "v1.80", "v1.82"}},
|
||||
{2, true, []string{"1.80", "1.82"}},
|
||||
{3, false, []string{"v1.80", "v1.82", "v1.84"}},
|
||||
{2, true, []string{"1.82", "1.84"}},
|
||||
// Lazy way to see all supported versions
|
||||
{10, true, []string{
|
||||
"1.64",
|
||||
"1.66",
|
||||
"1.68",
|
||||
"1.70",
|
||||
@@ -27,6 +26,7 @@ func TestTailscaleLatestMajorMinor(t *testing.T) {
|
||||
"1.78",
|
||||
"1.80",
|
||||
"1.82",
|
||||
"1.84",
|
||||
}},
|
||||
{0, false, nil},
|
||||
}
|
||||
@@ -46,7 +46,6 @@ func TestCapVerMinimumTailscaleVersion(t *testing.T) {
|
||||
input tailcfg.CapabilityVersion
|
||||
expected string
|
||||
}{
|
||||
{88, "v1.62.0"},
|
||||
{90, "v1.64.0"},
|
||||
{95, "v1.66.0"},
|
||||
{106, "v1.74.0"},
|
||||
|
||||
@@ -496,7 +496,7 @@ func NewHeadscaleDatabase(
|
||||
ID: "202407191627",
|
||||
Migrate: func(tx *gorm.DB) error {
|
||||
// Fix an issue where the automigration in GORM expected a constraint to
|
||||
// exists that didnt, and add the one it wanted.
|
||||
// exists that didn't, and add the one it wanted.
|
||||
// Fixes https://github.com/juanfont/headscale/issues/2351
|
||||
if cfg.Type == types.DatabasePostgres {
|
||||
err := tx.Exec(`
|
||||
@@ -934,7 +934,7 @@ AND auth_key_id NOT IN (
|
||||
},
|
||||
// From this point, the following rules must be followed:
|
||||
// - NEVER use gorm.AutoMigrate, write the exact migration steps needed
|
||||
// - AutoMigrate depends on the struct staying exactly the same, which it wont over time.
|
||||
// - AutoMigrate depends on the struct staying exactly the same, which it won't over time.
|
||||
// - Never write migrations that requires foreign keys to be disabled.
|
||||
},
|
||||
)
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"sort"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -362,8 +361,8 @@ func TestSQLiteMigrationAndDataValidation(t *testing.T) {
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(expectedKeys, keys, cmp.Comparer(func(a, b []string) bool {
|
||||
sort.Sort(sort.StringSlice(a))
|
||||
sort.Sort(sort.StringSlice(b))
|
||||
slices.Sort(a)
|
||||
slices.Sort(b)
|
||||
return slices.Equal(a, b)
|
||||
}), cmpopts.IgnoreFields(types.PreAuthKey{}, "User", "CreatedAt", "Reusable", "Ephemeral", "Used", "Expiration")); diff != "" {
|
||||
t.Errorf("TestSQLiteMigrationAndDataValidation() pre-auth key tags migration mismatch (-want +got):\n%s", diff)
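
The comparer above sorts both slices in place before comparing them. A standalone sketch of the same order-insensitive comparison, using only the standard library slices package and cloning so the caller's data is left untouched:

package main

import (
	"fmt"
	"slices"
)

func equalIgnoringOrder(a, b []string) bool {
	a, b = slices.Clone(a), slices.Clone(b) // don't mutate the inputs
	slices.Sort(a)
	slices.Sort(b)

	return slices.Equal(a, b)
}

func main() {
	fmt.Println(equalIgnoringOrder([]string{"tag:b", "tag:a"}, []string{"tag:a", "tag:b"})) // true
}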
|
||||
|
||||
@@ -7,15 +7,19 @@ import (
|
||||
"net/netip"
|
||||
"slices"
|
||||
"sort"
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/types/change"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/rs/zerolog/log"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/ptr"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -39,9 +43,7 @@ var (
|
||||
// If no peer IDs are given, all peers are returned.
|
||||
// If at least one peer ID is given, only these peer nodes will be returned.
|
||||
func (hsdb *HSDatabase) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) {
|
||||
return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) {
|
||||
return ListPeers(rx, nodeID, peerIDs...)
|
||||
})
|
||||
return ListPeers(hsdb.DB, nodeID, peerIDs...)
|
||||
}
|
||||
|
||||
// ListPeers returns peers of node, regardless of any Policy or if the node is expired.
|
||||
@@ -66,9 +68,7 @@ func ListPeers(tx *gorm.DB, nodeID types.NodeID, peerIDs ...types.NodeID) (types
|
||||
// ListNodes queries the database for either all nodes if no parameters are given
|
||||
// or for the given nodes if at least one node ID is given as parameter.
|
||||
func (hsdb *HSDatabase) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) {
|
||||
return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) {
|
||||
return ListNodes(rx, nodeIDs...)
|
||||
})
|
||||
return ListNodes(hsdb.DB, nodeIDs...)
|
||||
}
|
||||
|
||||
// ListNodes queries the database for either all nodes if no parameters are given
|
||||
@@ -120,9 +120,7 @@ func getNode(tx *gorm.DB, uid types.UserID, name string) (*types.Node, error) {
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) GetNodeByID(id types.NodeID) (*types.Node, error) {
|
||||
return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) {
|
||||
return GetNodeByID(rx, id)
|
||||
})
|
||||
return GetNodeByID(hsdb.DB, id)
|
||||
}
|
||||
|
||||
// GetNodeByID finds a Node by ID and returns the Node struct.
|
||||
@@ -140,9 +138,7 @@ func GetNodeByID(tx *gorm.DB, id types.NodeID) (*types.Node, error) {
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) GetNodeByMachineKey(machineKey key.MachinePublic) (*types.Node, error) {
|
||||
return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) {
|
||||
return GetNodeByMachineKey(rx, machineKey)
|
||||
})
|
||||
return GetNodeByMachineKey(hsdb.DB, machineKey)
|
||||
}
|
||||
|
||||
// GetNodeByMachineKey finds a Node by its MachineKey and returns the Node struct.
|
||||
@@ -163,9 +159,7 @@ func GetNodeByMachineKey(
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) GetNodeByNodeKey(nodeKey key.NodePublic) (*types.Node, error) {
|
||||
return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) {
|
||||
return GetNodeByNodeKey(rx, nodeKey)
|
||||
})
|
||||
return GetNodeByNodeKey(hsdb.DB, nodeKey)
|
||||
}
|
||||
|
||||
// GetNodeByNodeKey finds a Node by its NodeKey and returns the Node struct.
|
||||
@@ -352,8 +346,8 @@ func (hsdb *HSDatabase) HandleNodeFromAuthPath(
|
||||
registrationMethod string,
|
||||
ipv4 *netip.Addr,
|
||||
ipv6 *netip.Addr,
|
||||
) (*types.Node, bool, error) {
|
||||
var newNode bool
|
||||
) (*types.Node, change.ChangeSet, error) {
|
||||
var nodeChange change.ChangeSet
|
||||
node, err := Write(hsdb.DB, func(tx *gorm.DB) (*types.Node, error) {
|
||||
if reg, ok := hsdb.regCache.Get(registrationID); ok {
|
||||
if node, _ := GetNodeByNodeKey(tx, reg.Node.NodeKey); node == nil {
|
||||
@@ -405,7 +399,7 @@ func (hsdb *HSDatabase) HandleNodeFromAuthPath(
|
||||
}
|
||||
close(reg.Registered)
|
||||
|
||||
newNode = true
|
||||
nodeChange = change.NodeAdded(node.ID)
|
||||
|
||||
return node, err
|
||||
} else {
|
||||
@@ -415,6 +409,8 @@ func (hsdb *HSDatabase) HandleNodeFromAuthPath(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nodeChange = change.KeyExpiry(node.ID)
|
||||
|
||||
return node, nil
|
||||
}
|
||||
}
|
||||
@@ -422,7 +418,7 @@ func (hsdb *HSDatabase) HandleNodeFromAuthPath(
|
||||
return nil, ErrNodeNotFoundRegistrationCache
|
||||
})
|
||||
|
||||
return node, newNode, err
|
||||
return node, nodeChange, err
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) RegisterNode(node types.Node, ipv4 *netip.Addr, ipv6 *netip.Addr) (*types.Node, error) {
|
||||
@@ -448,6 +444,7 @@ func RegisterNode(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Ad
|
||||
if oldNode != nil && oldNode.UserID == node.UserID {
|
||||
node.ID = oldNode.ID
|
||||
node.GivenName = oldNode.GivenName
|
||||
node.ApprovedRoutes = oldNode.ApprovedRoutes
|
||||
ipv4 = oldNode.IPv4
|
||||
ipv6 = oldNode.IPv6
|
||||
}
|
||||
@@ -594,17 +591,18 @@ func ensureUniqueGivenName(
|
||||
// containing the expired nodes, and a boolean indicating if any nodes were found.
|
||||
func ExpireExpiredNodes(tx *gorm.DB,
|
||||
lastCheck time.Time,
|
||||
) (time.Time, types.StateUpdate, bool) {
|
||||
) (time.Time, []change.ChangeSet, bool) {
|
||||
// use the time of the start of the function to ensure we
|
||||
// don't miss some nodes by returning it _after_ we have
|
||||
// checked everything.
|
||||
started := time.Now()
|
||||
|
||||
expired := make([]*tailcfg.PeerChange, 0)
|
||||
var updates []change.ChangeSet
|
||||
|
||||
nodes, err := ListNodes(tx)
|
||||
if err != nil {
|
||||
return time.Unix(0, 0), types.StateUpdate{}, false
|
||||
return time.Unix(0, 0), nil, false
|
||||
}
|
||||
for _, node := range nodes {
|
||||
if node.IsExpired() && node.Expiry.After(lastCheck) {
|
||||
@@ -612,14 +610,15 @@ func ExpireExpiredNodes(tx *gorm.DB,
|
||||
NodeID: tailcfg.NodeID(node.ID),
|
||||
KeyExpiry: node.Expiry,
|
||||
})
|
||||
updates = append(updates, change.KeyExpiry(node.ID))
|
||||
}
|
||||
}
|
||||
|
||||
if len(expired) > 0 {
|
||||
return started, types.UpdatePeerPatch(expired...), true
|
||||
return started, updates, true
|
||||
}
|
||||
|
||||
return started, types.StateUpdate{}, false
|
||||
return started, nil, false
|
||||
}
|
||||
|
||||
// EphemeralGarbageCollector is a garbage collector that will delete nodes after
|
||||
@@ -732,3 +731,114 @@ func (e *EphemeralGarbageCollector) Start() {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) CreateNodeForTest(user *types.User, hostname ...string) *types.Node {
|
||||
if !testing.Testing() {
|
||||
panic("CreateNodeForTest can only be called during tests")
|
||||
}
|
||||
|
||||
if user == nil {
|
||||
panic("CreateNodeForTest requires a valid user")
|
||||
}
|
||||
|
||||
nodeName := "testnode"
|
||||
if len(hostname) > 0 && hostname[0] != "" {
|
||||
nodeName = hostname[0]
|
||||
}
|
||||
|
||||
// Create a preauth key for the node
|
||||
pak, err := hsdb.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("failed to create preauth key for test node: %v", err))
|
||||
}
|
||||
|
||||
nodeKey := key.NewNode()
|
||||
machineKey := key.NewMachine()
|
||||
discoKey := key.NewDisco()
|
||||
|
||||
node := &types.Node{
|
||||
MachineKey: machineKey.Public(),
|
||||
NodeKey: nodeKey.Public(),
|
||||
DiscoKey: discoKey.Public(),
|
||||
Hostname: nodeName,
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: ptr.To(pak.ID),
|
||||
}
|
||||
|
||||
err = hsdb.DB.Save(node).Error
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("failed to create test node: %v", err))
|
||||
}
|
||||
|
||||
return node
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) CreateRegisteredNodeForTest(user *types.User, hostname ...string) *types.Node {
|
||||
if !testing.Testing() {
|
||||
panic("CreateRegisteredNodeForTest can only be called during tests")
|
||||
}
|
||||
|
||||
node := hsdb.CreateNodeForTest(user, hostname...)
|
||||
|
||||
err := hsdb.DB.Transaction(func(tx *gorm.DB) error {
|
||||
_, err := RegisterNode(tx, *node, nil, nil)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("failed to register test node: %v", err))
|
||||
}
|
||||
|
||||
registeredNode, err := hsdb.GetNodeByID(node.ID)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("failed to get registered test node: %v", err))
|
||||
}
|
||||
|
||||
return registeredNode
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) CreateNodesForTest(user *types.User, count int, hostnamePrefix ...string) []*types.Node {
|
||||
if !testing.Testing() {
|
||||
panic("CreateNodesForTest can only be called during tests")
|
||||
}
|
||||
|
||||
if user == nil {
|
||||
panic("CreateNodesForTest requires a valid user")
|
||||
}
|
||||
|
||||
prefix := "testnode"
|
||||
if len(hostnamePrefix) > 0 && hostnamePrefix[0] != "" {
|
||||
prefix = hostnamePrefix[0]
|
||||
}
|
||||
|
||||
nodes := make([]*types.Node, count)
|
||||
for i := range count {
|
||||
hostname := prefix + "-" + strconv.Itoa(i)
|
||||
nodes[i] = hsdb.CreateNodeForTest(user, hostname)
|
||||
}
|
||||
|
||||
return nodes
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) CreateRegisteredNodesForTest(user *types.User, count int, hostnamePrefix ...string) []*types.Node {
|
||||
if !testing.Testing() {
|
||||
panic("CreateRegisteredNodesForTest can only be called during tests")
|
||||
}
|
||||
|
||||
if user == nil {
|
||||
panic("CreateRegisteredNodesForTest requires a valid user")
|
||||
}
|
||||
|
||||
prefix := "testnode"
|
||||
if len(hostnamePrefix) > 0 && hostnamePrefix[0] != "" {
|
||||
prefix = hostnamePrefix[0]
|
||||
}
|
||||
|
||||
nodes := make([]*types.Node, count)
|
||||
for i := range count {
|
||||
hostname := prefix + "-" + strconv.Itoa(i)
|
||||
nodes[i] = hsdb.CreateRegisteredNodeForTest(user, hostname)
|
||||
}
|
||||
|
||||
return nodes
|
||||
}
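
A minimal illustration of the guard pattern used by the *ForTest helpers above: testing.Testing() (Go 1.21+) reports whether the binary was built by "go test", so the helpers panic immediately if they ever leak into a production build. The extracted helper below is hypothetical, not part of the diff.

package db

import "testing"

// mustBeTestBinary is a hypothetical extraction of the repeated guard.
func mustBeTestBinary(name string) {
	if !testing.Testing() {
		panic(name + " can only be called during tests")
	}
}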
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"math/big"
|
||||
"net/netip"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -26,82 +25,36 @@ import (
|
||||
)
|
||||
|
||||
func (s *Suite) TestGetNode(c *check.C) {
|
||||
user, err := db.CreateUser(types.User{Name: "test"})
|
||||
c.Assert(err, check.IsNil)
|
||||
user := db.CreateUserForTest("test")
|
||||
|
||||
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.getNode(types.UserID(user.ID), "testnode")
|
||||
_, err := db.getNode(types.UserID(user.ID), "testnode")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
nodeKey := key.NewNode()
|
||||
machineKey := key.NewMachine()
|
||||
|
||||
node := &types.Node{
|
||||
ID: 0,
|
||||
MachineKey: machineKey.Public(),
|
||||
NodeKey: nodeKey.Public(),
|
||||
Hostname: "testnode",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: ptr.To(pak.ID),
|
||||
}
|
||||
trx := db.DB.Save(node)
|
||||
c.Assert(trx.Error, check.IsNil)
|
||||
node := db.CreateNodeForTest(user, "testnode")
|
||||
|
||||
_, err = db.getNode(types.UserID(user.ID), "testnode")
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(node.Hostname, check.Equals, "testnode")
|
||||
}
|
||||
|
||||
func (s *Suite) TestGetNodeByID(c *check.C) {
|
||||
user, err := db.CreateUser(types.User{Name: "test"})
|
||||
c.Assert(err, check.IsNil)
|
||||
user := db.CreateUserForTest("test")
|
||||
|
||||
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetNodeByID(0)
|
||||
_, err := db.GetNodeByID(0)
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
nodeKey := key.NewNode()
|
||||
machineKey := key.NewMachine()
|
||||
node := db.CreateNodeForTest(user, "testnode")
|
||||
|
||||
node := types.Node{
|
||||
ID: 0,
|
||||
MachineKey: machineKey.Public(),
|
||||
NodeKey: nodeKey.Public(),
|
||||
Hostname: "testnode",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: ptr.To(pak.ID),
|
||||
}
|
||||
trx := db.DB.Save(&node)
|
||||
c.Assert(trx.Error, check.IsNil)
|
||||
|
||||
_, err = db.GetNodeByID(0)
|
||||
retrievedNode, err := db.GetNodeByID(node.ID)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(retrievedNode.Hostname, check.Equals, "testnode")
|
||||
}
|
||||
|
||||
func (s *Suite) TestHardDeleteNode(c *check.C) {
|
||||
user, err := db.CreateUser(types.User{Name: "test"})
|
||||
c.Assert(err, check.IsNil)
|
||||
user := db.CreateUserForTest("test")
|
||||
node := db.CreateNodeForTest(user, "testnode3")
|
||||
|
||||
nodeKey := key.NewNode()
|
||||
machineKey := key.NewMachine()
|
||||
|
||||
node := types.Node{
|
||||
ID: 0,
|
||||
MachineKey: machineKey.Public(),
|
||||
NodeKey: nodeKey.Public(),
|
||||
Hostname: "testnode3",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
}
|
||||
trx := db.DB.Save(&node)
|
||||
c.Assert(trx.Error, check.IsNil)
|
||||
|
||||
err = db.DeleteNode(&node)
|
||||
err := db.DeleteNode(node)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.getNode(types.UserID(user.ID), "testnode3")
|
||||
@@ -109,42 +62,21 @@ func (s *Suite) TestHardDeleteNode(c *check.C) {
|
||||
}
|
||||
|
||||
func (s *Suite) TestListPeers(c *check.C) {
|
||||
user, err := db.CreateUser(types.User{Name: "test"})
|
||||
c.Assert(err, check.IsNil)
|
||||
user := db.CreateUserForTest("test")
|
||||
|
||||
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetNodeByID(0)
|
||||
_, err := db.GetNodeByID(0)
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
for index := range 11 {
|
||||
nodeKey := key.NewNode()
|
||||
machineKey := key.NewMachine()
|
||||
nodes := db.CreateNodesForTest(user, 11, "testnode")
|
||||
|
||||
node := types.Node{
|
||||
ID: types.NodeID(index),
|
||||
MachineKey: machineKey.Public(),
|
||||
NodeKey: nodeKey.Public(),
|
||||
Hostname: "testnode" + strconv.Itoa(index),
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: ptr.To(pak.ID),
|
||||
}
|
||||
trx := db.DB.Save(&node)
|
||||
c.Assert(trx.Error, check.IsNil)
|
||||
}
|
||||
|
||||
node0ByID, err := db.GetNodeByID(0)
|
||||
firstNode := nodes[0]
|
||||
peersOfFirstNode, err := db.ListPeers(firstNode.ID)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
peersOfNode0, err := db.ListPeers(node0ByID.ID)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(len(peersOfNode0), check.Equals, 9)
|
||||
c.Assert(peersOfNode0[0].Hostname, check.Equals, "testnode2")
|
||||
c.Assert(peersOfNode0[5].Hostname, check.Equals, "testnode7")
|
||||
c.Assert(peersOfNode0[8].Hostname, check.Equals, "testnode10")
|
||||
c.Assert(len(peersOfFirstNode), check.Equals, 10)
|
||||
c.Assert(peersOfFirstNode[0].Hostname, check.Equals, "testnode-1")
|
||||
c.Assert(peersOfFirstNode[5].Hostname, check.Equals, "testnode-6")
|
||||
c.Assert(peersOfFirstNode[9].Hostname, check.Equals, "testnode-10")
|
||||
}
|
||||
|
||||
func (s *Suite) TestExpireNode(c *check.C) {
|
||||
@@ -807,13 +739,13 @@ func TestListPeers(t *testing.T) {
|
||||
// No parameter means no filter, should return all peers
|
||||
nodes, err = db.ListPeers(1)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, nodes, 1)
|
||||
assert.Equal(t, 1, len(nodes))
|
||||
assert.Equal(t, "test2", nodes[0].Hostname)
|
||||
|
||||
// Empty node list should return all peers
|
||||
nodes, err = db.ListPeers(1, types.NodeIDs{}...)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, nodes, 1)
|
||||
assert.Equal(t, 1, len(nodes))
|
||||
assert.Equal(t, "test2", nodes[0].Hostname)
|
||||
|
||||
// No match in IDs should return empty list and no error
|
||||
@@ -824,13 +756,13 @@ func TestListPeers(t *testing.T) {
|
||||
// Partial match in IDs
|
||||
nodes, err = db.ListPeers(1, types.NodeIDs{2, 3}...)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, nodes, 1)
|
||||
assert.Equal(t, 1, len(nodes))
|
||||
assert.Equal(t, "test2", nodes[0].Hostname)
|
||||
|
||||
// Several matched IDs, but node ID is still filtered out
|
||||
nodes, err = db.ListPeers(1, types.NodeIDs{1, 2, 3}...)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, nodes, 1)
|
||||
assert.Equal(t, 1, len(nodes))
|
||||
assert.Equal(t, "test2", nodes[0].Hostname)
|
||||
}
|
||||
|
||||
@@ -892,14 +824,14 @@ func TestListNodes(t *testing.T) {
|
||||
// No parameter means no filter, should return all nodes
|
||||
nodes, err = db.ListNodes()
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, nodes, 2)
|
||||
assert.Equal(t, 2, len(nodes))
|
||||
assert.Equal(t, "test1", nodes[0].Hostname)
|
||||
assert.Equal(t, "test2", nodes[1].Hostname)
|
||||
|
||||
// Empty node list should return all nodes
|
||||
nodes, err = db.ListNodes(types.NodeIDs{}...)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, nodes, 2)
|
||||
assert.Equal(t, 2, len(nodes))
|
||||
assert.Equal(t, "test1", nodes[0].Hostname)
|
||||
assert.Equal(t, "test2", nodes[1].Hostname)
|
||||
|
||||
@@ -911,13 +843,13 @@ func TestListNodes(t *testing.T) {
|
||||
// Partial match in IDs
|
||||
nodes, err = db.ListNodes(types.NodeIDs{2, 3}...)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, nodes, 1)
|
||||
assert.Equal(t, 1, len(nodes))
|
||||
assert.Equal(t, "test2", nodes[0].Hostname)
|
||||
|
||||
// Several matched IDs
|
||||
nodes, err = db.ListNodes(types.NodeIDs{1, 2, 3}...)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, nodes, 2)
|
||||
assert.Equal(t, 2, len(nodes))
|
||||
assert.Equal(t, "test1", nodes[0].Hostname)
|
||||
assert.Equal(t, "test2", nodes[1].Hostname)
|
||||
}
|
||||
|
||||
@@ -109,9 +109,7 @@ func ListPreAuthKeysByUser(tx *gorm.DB, uid types.UserID) ([]types.PreAuthKey, e
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) GetPreAuthKey(key string) (*types.PreAuthKey, error) {
|
||||
return Read(hsdb.DB, func(rx *gorm.DB) (*types.PreAuthKey, error) {
|
||||
return GetPreAuthKey(rx, key)
|
||||
})
|
||||
return GetPreAuthKey(hsdb.DB, key)
|
||||
}
|
||||
|
||||
// GetPreAuthKey returns a PreAuthKey for a given key. The caller is responsible
|
||||
@@ -155,11 +153,8 @@ func UsePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
|
||||
|
||||
// MarkExpirePreAuthKey marks a PreAuthKey as expired.
|
||||
func ExpirePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
|
||||
if err := tx.Model(&k).Update("Expiration", time.Now()).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
now := time.Now()
|
||||
return tx.Model(&types.PreAuthKey{}).Where("id = ?", k.ID).Update("expiration", now).Error
|
||||
}
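
The rewrite above targets the row by primary key on an empty model value instead of passing the populated struct, so the UPDATE no longer depends on which fields of k happen to be set. A hedged sketch of the same GORM pattern (expireKeyByID is illustrative, not part of the diff):

func expireKeyByID(tx *gorm.DB, k *types.PreAuthKey) error {
	// Roughly: UPDATE ... SET expiration = ? WHERE id = ?
	return tx.Model(&types.PreAuthKey{}).
		Where("id = ?", k.ID).
		Update("expiration", time.Now()).Error
}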
|
||||
|
||||
func generateKey() (string, error) {
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
@@ -57,7 +57,7 @@ func (*Suite) TestPreAuthKeyACLTags(c *check.C) {
|
||||
listedPaks, err := db.ListPreAuthKeys(types.UserID(user.ID))
|
||||
c.Assert(err, check.IsNil)
|
||||
gotTags := listedPaks[0].Proto().GetAclTags()
|
||||
sort.Sort(sort.StringSlice(gotTags))
|
||||
slices.Sort(gotTags)
|
||||
c.Assert(gotTags, check.DeepEquals, tags)
|
||||
}
|
||||
|
||||
|
||||
@@ -3,6 +3,8 @@ package db
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
@@ -110,9 +112,7 @@ func RenameUser(tx *gorm.DB, uid types.UserID, newName string) error {
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) GetUserByID(uid types.UserID) (*types.User, error) {
|
||||
return Read(hsdb.DB, func(rx *gorm.DB) (*types.User, error) {
|
||||
return GetUserByID(rx, uid)
|
||||
})
|
||||
return GetUserByID(hsdb.DB, uid)
|
||||
}
|
||||
|
||||
func GetUserByID(tx *gorm.DB, uid types.UserID) (*types.User, error) {
|
||||
@@ -146,9 +146,7 @@ func GetUserByOIDCIdentifier(tx *gorm.DB, id string) (*types.User, error) {
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) ListUsers(where ...*types.User) ([]types.User, error) {
|
||||
return Read(hsdb.DB, func(rx *gorm.DB) ([]types.User, error) {
|
||||
return ListUsers(rx, where...)
|
||||
})
|
||||
return ListUsers(hsdb.DB, where...)
|
||||
}
|
||||
|
||||
// ListUsers gets all the existing users.
|
||||
@@ -217,3 +215,40 @@ func AssignNodeToUser(tx *gorm.DB, nodeID types.NodeID, uid types.UserID) error
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) CreateUserForTest(name ...string) *types.User {
|
||||
if !testing.Testing() {
|
||||
panic("CreateUserForTest can only be called during tests")
|
||||
}
|
||||
|
||||
userName := "testuser"
|
||||
if len(name) > 0 && name[0] != "" {
|
||||
userName = name[0]
|
||||
}
|
||||
|
||||
user, err := hsdb.CreateUser(types.User{Name: userName})
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("failed to create test user: %v", err))
|
||||
}
|
||||
|
||||
return user
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) CreateUsersForTest(count int, namePrefix ...string) []*types.User {
|
||||
if !testing.Testing() {
|
||||
panic("CreateUsersForTest can only be called during tests")
|
||||
}
|
||||
|
||||
prefix := "testuser"
|
||||
if len(namePrefix) > 0 && namePrefix[0] != "" {
|
||||
prefix = namePrefix[0]
|
||||
}
|
||||
|
||||
users := make([]*types.User, count)
|
||||
for i := range count {
|
||||
name := prefix + "-" + strconv.Itoa(i)
|
||||
users[i] = hsdb.CreateUserForTest(name)
|
||||
}
|
||||
|
||||
return users
|
||||
}
|
||||
|
||||
@@ -11,8 +11,7 @@ import (
|
||||
)
|
||||
|
||||
func (s *Suite) TestCreateAndDestroyUser(c *check.C) {
|
||||
user, err := db.CreateUser(types.User{Name: "test"})
|
||||
c.Assert(err, check.IsNil)
|
||||
user := db.CreateUserForTest("test")
|
||||
c.Assert(user.Name, check.Equals, "test")
|
||||
|
||||
users, err := db.ListUsers()
|
||||
@@ -30,8 +29,7 @@ func (s *Suite) TestDestroyUserErrors(c *check.C) {
|
||||
err := db.DestroyUser(9998)
|
||||
c.Assert(err, check.Equals, ErrUserNotFound)
|
||||
|
||||
user, err := db.CreateUser(types.User{Name: "test"})
|
||||
c.Assert(err, check.IsNil)
|
||||
user := db.CreateUserForTest("test")
|
||||
|
||||
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
@@ -64,8 +62,7 @@ func (s *Suite) TestDestroyUserErrors(c *check.C) {
|
||||
}
|
||||
|
||||
func (s *Suite) TestRenameUser(c *check.C) {
|
||||
userTest, err := db.CreateUser(types.User{Name: "test"})
|
||||
c.Assert(err, check.IsNil)
|
||||
userTest := db.CreateUserForTest("test")
|
||||
c.Assert(userTest.Name, check.Equals, "test")
|
||||
|
||||
users, err := db.ListUsers()
|
||||
@@ -86,8 +83,7 @@ func (s *Suite) TestRenameUser(c *check.C) {
|
||||
err = db.RenameUser(99988, "test")
|
||||
c.Assert(err, check.Equals, ErrUserNotFound)
|
||||
|
||||
userTest2, err := db.CreateUser(types.User{Name: "test2"})
|
||||
c.Assert(err, check.IsNil)
|
||||
userTest2 := db.CreateUserForTest("test2")
|
||||
c.Assert(userTest2.Name, check.Equals, "test2")
|
||||
|
||||
want := "UNIQUE constraint failed"
|
||||
@@ -98,11 +94,8 @@ func (s *Suite) TestRenameUser(c *check.C) {
|
||||
}
|
||||
|
||||
func (s *Suite) TestSetMachineUser(c *check.C) {
|
||||
oldUser, err := db.CreateUser(types.User{Name: "old"})
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
newUser, err := db.CreateUser(types.User{Name: "new"})
|
||||
c.Assert(err, check.IsNil)
|
||||
oldUser := db.CreateUserForTest("old")
|
||||
newUser := db.CreateUserForTest("new")
|
||||
|
||||
pak, err := db.CreatePreAuthKey(types.UserID(oldUser.ID), false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
@@ -17,10 +17,6 @@ import (
|
||||
func (h *Headscale) debugHTTPServer() *http.Server {
|
||||
debugMux := http.NewServeMux()
|
||||
debug := tsweb.Debugger(debugMux)
|
||||
debug.Handle("notifier", "Connected nodes in notifier", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(h.nodeNotifier.String()))
|
||||
}))
|
||||
debug.Handle("config", "Current configuration", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
config, err := json.MarshalIndent(h.cfg, "", " ")
|
||||
if err != nil {
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"maps"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
@@ -72,9 +73,7 @@ func mergeDERPMaps(derpMaps []*tailcfg.DERPMap) *tailcfg.DERPMap {
|
||||
}
|
||||
|
||||
for _, derpMap := range derpMaps {
|
||||
for id, region := range derpMap.Regions {
|
||||
result.Regions[id] = region
|
||||
}
|
||||
maps.Copy(result.Regions, derpMap.Regions)
|
||||
}
|
||||
|
||||
return &result
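
For reference, the maps.Copy call above comes from the standard library maps package (Go 1.21+): keys from the source overwrite the same keys in the destination, which gives the "last DERP map wins per region ID" merge semantics. A self-contained example:

package main

import (
	"fmt"
	"maps"
)

func main() {
	dst := map[int]string{1: "base", 2: "base"}
	src := map[int]string{2: "override", 3: "extra"}

	maps.Copy(dst, src) // copies src into dst, overwriting shared keys

	fmt.Println(dst) // map[1:base 2:override 3:extra]
}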
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:generate buf generate --template ../buf.gen.yaml -o .. ../proto
|
||||
|
||||
// nolint
|
||||
package hscontrol
|
||||
|
||||
@@ -27,6 +29,7 @@ import (
|
||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||
"github.com/juanfont/headscale/hscontrol/state"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/types/change"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
)
|
||||
|
||||
@@ -56,12 +59,14 @@ func (api headscaleV1APIServer) CreateUser(
|
||||
return nil, status.Errorf(codes.Internal, "failed to create user: %s", err)
|
||||
}
|
||||
|
||||
// Send policy update notifications if needed
|
||||
|
||||
c := change.UserAdded(types.UserID(user.ID))
|
||||
if policyChanged {
|
||||
ctx := types.NotifyCtx(context.Background(), "grpc-user-created", user.Name)
|
||||
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
||||
c.Change = change.Policy
|
||||
}
|
||||
|
||||
api.h.Change(c)
|
||||
|
||||
return &v1.CreateUserResponse{User: user.Proto()}, nil
|
||||
}
|
||||
|
||||
@@ -81,8 +86,7 @@ func (api headscaleV1APIServer) RenameUser(
|
||||
|
||||
// Send policy update notifications if needed
|
||||
if policyChanged {
|
||||
ctx := types.NotifyCtx(context.Background(), "grpc-user-renamed", request.GetNewName())
|
||||
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
||||
api.h.Change(change.PolicyChange())
|
||||
}
|
||||
|
||||
newUser, err := api.h.state.GetUserByName(request.GetNewName())
|
||||
@@ -107,6 +111,8 @@ func (api headscaleV1APIServer) DeleteUser(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
api.h.Change(change.UserRemoved(types.UserID(user.ID)))
|
||||
|
||||
return &v1.DeleteUserResponse{}, nil
|
||||
}
|
||||
|
||||
@@ -246,7 +252,7 @@ func (api headscaleV1APIServer) RegisterNode(
|
||||
return nil, fmt.Errorf("looking up user: %w", err)
|
||||
}
|
||||
|
||||
node, _, err := api.h.state.HandleNodeFromAuthPath(
|
||||
node, nodeChange, err := api.h.state.HandleNodeFromAuthPath(
|
||||
registrationId,
|
||||
types.UserID(user.ID),
|
||||
nil,
|
||||
@@ -267,22 +273,13 @@ func (api headscaleV1APIServer) RegisterNode(
|
||||
// ensure we send an update.
|
||||
// This works, but might be another good candidate for doing some sort of
|
||||
// eventbus.
|
||||
routesChanged := api.h.state.AutoApproveRoutes(node)
|
||||
_, policyChanged, err := api.h.state.SaveNode(node)
|
||||
_ = api.h.state.AutoApproveRoutes(node)
|
||||
_, _, err = api.h.state.SaveNode(node)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("saving auto approved routes to node: %w", err)
|
||||
}
|
||||
|
||||
// Send policy update notifications if needed (from SaveNode or route changes)
|
||||
if policyChanged {
|
||||
ctx := types.NotifyCtx(context.Background(), "grpc-nodes-change", "all")
|
||||
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
||||
}
|
||||
|
||||
if routesChanged {
|
||||
ctx = types.NotifyCtx(context.Background(), "web-node-login", node.Hostname)
|
||||
api.h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerChanged(node.ID))
|
||||
}
|
||||
api.h.Change(nodeChange)
|
||||
|
||||
return &v1.RegisterNodeResponse{Node: node.Proto()}, nil
|
||||
}
|
||||
@@ -300,7 +297,7 @@ func (api headscaleV1APIServer) GetNode(
|
||||
|
||||
// Populate the online field based on
|
||||
// currently connected nodes.
|
||||
resp.Online = api.h.nodeNotifier.IsConnected(node.ID)
|
||||
resp.Online = api.h.mapBatcher.IsConnected(node.ID)
|
||||
|
||||
return &v1.GetNodeResponse{Node: resp}, nil
|
||||
}
|
||||
@@ -316,21 +313,14 @@ func (api headscaleV1APIServer) SetTags(
|
||||
}
|
||||
}
|
||||
|
||||
node, policyChanged, err := api.h.state.SetNodeTags(types.NodeID(request.GetNodeId()), request.GetTags())
|
||||
node, nodeChange, err := api.h.state.SetNodeTags(types.NodeID(request.GetNodeId()), request.GetTags())
|
||||
if err != nil {
|
||||
return &v1.SetTagsResponse{
|
||||
Node: nil,
|
||||
}, status.Error(codes.InvalidArgument, err.Error())
|
||||
}
|
||||
|
||||
// Send policy update notifications if needed
|
||||
if policyChanged {
|
||||
ctx := types.NotifyCtx(context.Background(), "grpc-node-tags", node.Hostname)
|
||||
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
||||
}
|
||||
|
||||
ctx = types.NotifyCtx(ctx, "cli-settags", node.Hostname)
|
||||
api.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID)
|
||||
api.h.Change(nodeChange)
|
||||
|
||||
log.Trace().
|
||||
Str("node", node.Hostname).
|
||||
@@ -362,23 +352,19 @@ func (api headscaleV1APIServer) SetApprovedRoutes(
|
||||
tsaddr.SortPrefixes(routes)
|
||||
routes = slices.Compact(routes)
|
||||
|
||||
node, policyChanged, err := api.h.state.SetApprovedRoutes(types.NodeID(request.GetNodeId()), routes)
|
||||
node, nodeChange, err := api.h.state.SetApprovedRoutes(types.NodeID(request.GetNodeId()), routes)
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.InvalidArgument, err.Error())
|
||||
}
|
||||
|
||||
// Send policy update notifications if needed
|
||||
if policyChanged {
|
||||
ctx := types.NotifyCtx(context.Background(), "grpc-routes-approved", node.Hostname)
|
||||
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
||||
}
|
||||
routeChange := api.h.state.SetNodeRoutes(node.ID, node.SubnetRoutes()...)
|
||||
|
||||
if api.h.state.SetNodeRoutes(node.ID, node.SubnetRoutes()...) {
|
||||
ctx := types.NotifyCtx(ctx, "poll-primary-change", node.Hostname)
|
||||
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
||||
} else {
|
||||
ctx = types.NotifyCtx(ctx, "cli-approveroutes", node.Hostname)
|
||||
api.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID)
|
||||
// Always propagate node changes from SetApprovedRoutes
|
||||
api.h.Change(nodeChange)
|
||||
|
||||
// If routes changed, propagate those changes too
|
||||
if !routeChange.Empty() {
|
||||
api.h.Change(routeChange)
|
||||
}
|
||||
|
||||
proto := node.Proto()
|
||||
@@ -409,19 +395,12 @@ func (api headscaleV1APIServer) DeleteNode(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
policyChanged, err := api.h.state.DeleteNode(node)
|
||||
nodeChange, err := api.h.state.DeleteNode(node)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Send policy update notifications if needed
|
||||
if policyChanged {
|
||||
ctx := types.NotifyCtx(context.Background(), "grpc-node-deleted", node.Hostname)
|
||||
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
||||
}
|
||||
|
||||
ctx = types.NotifyCtx(ctx, "cli-deletenode", node.Hostname)
|
||||
api.h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerRemoved(node.ID))
|
||||
api.h.Change(nodeChange)
|
||||
|
||||
return &v1.DeleteNodeResponse{}, nil
|
||||
}
|
||||
@@ -432,25 +411,13 @@ func (api headscaleV1APIServer) ExpireNode(
|
||||
) (*v1.ExpireNodeResponse, error) {
|
||||
now := time.Now()
|
||||
|
||||
node, policyChanged, err := api.h.state.SetNodeExpiry(types.NodeID(request.GetNodeId()), now)
|
||||
node, nodeChange, err := api.h.state.SetNodeExpiry(types.NodeID(request.GetNodeId()), now)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Send policy update notifications if needed
|
||||
if policyChanged {
|
||||
ctx := types.NotifyCtx(context.Background(), "grpc-node-expired", node.Hostname)
|
||||
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
||||
}
|
||||
|
||||
ctx = types.NotifyCtx(ctx, "cli-expirenode-self", node.Hostname)
|
||||
api.h.nodeNotifier.NotifyByNodeID(
|
||||
ctx,
|
||||
types.UpdateSelf(node.ID),
|
||||
node.ID)
|
||||
|
||||
ctx = types.NotifyCtx(ctx, "cli-expirenode-peers", node.Hostname)
|
||||
api.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdateExpire(node.ID, now), node.ID)
|
||||
// TODO(kradalby): Ensure that both the selfupdate and peer updates are sent
|
||||
api.h.Change(nodeChange)
|
||||
|
||||
log.Trace().
|
||||
Str("node", node.Hostname).
|
||||
@@ -464,22 +431,13 @@ func (api headscaleV1APIServer) RenameNode(
|
||||
ctx context.Context,
|
||||
request *v1.RenameNodeRequest,
|
||||
) (*v1.RenameNodeResponse, error) {
|
||||
node, policyChanged, err := api.h.state.RenameNode(types.NodeID(request.GetNodeId()), request.GetNewName())
|
||||
node, nodeChange, err := api.h.state.RenameNode(types.NodeID(request.GetNodeId()), request.GetNewName())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Send policy update notifications if needed
|
||||
if policyChanged {
|
||||
ctx := types.NotifyCtx(context.Background(), "grpc-node-renamed", node.Hostname)
|
||||
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
||||
}
|
||||
|
||||
ctx = types.NotifyCtx(ctx, "cli-renamenode-self", node.Hostname)
|
||||
api.h.nodeNotifier.NotifyByNodeID(ctx, types.UpdateSelf(node.ID), node.ID)
|
||||
|
||||
ctx = types.NotifyCtx(ctx, "cli-renamenode-peers", node.Hostname)
|
||||
api.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID)
|
||||
// TODO(kradalby): investigate if we need selfupdate
|
||||
api.h.Change(nodeChange)
|
||||
|
||||
log.Trace().
|
||||
Str("node", node.Hostname).
|
||||
@@ -493,87 +451,48 @@ func (api headscaleV1APIServer) ListNodes(
|
||||
ctx context.Context,
|
||||
request *v1.ListNodesRequest,
|
||||
) (*v1.ListNodesResponse, error) {
|
||||
var nodes types.Nodes
|
||||
var err error
|
||||
// TODO(kradalby): it looks like this can be simplified a lot,
|
||||
// the filtering of nodes by user, vs nodes as a whole can
|
||||
// probably be done once.
|
||||
// TODO(kradalby): This should be done in one tx.
|
||||
|
||||
isLikelyConnected := api.h.nodeNotifier.LikelyConnectedMap()
|
||||
IsConnected := api.h.mapBatcher.ConnectedMap()
|
||||
if request.GetUser() != "" {
|
||||
user, err := api.h.state.GetUserByName(request.GetUser())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Start with all nodes and apply filters
|
||||
nodes, err = api.h.state.ListNodes()
|
||||
nodes, err := api.h.state.ListNodesByUser(types.UserID(user.ID))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
response := nodesToProto(api.h.state, IsConnected, nodes)
|
||||
return &v1.ListNodesResponse{Nodes: response}, nil
|
||||
}
|
||||
|
||||
nodes, err := api.h.state.ListNodes()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Apply filters based on request
|
||||
nodes = api.filterNodes(nodes, request)
|
||||
|
||||
sort.Slice(nodes, func(i, j int) bool {
|
||||
return nodes[i].ID < nodes[j].ID
|
||||
})
|
||||
|
||||
response := nodesToProto(api.h.state, isLikelyConnected, nodes)
|
||||
response := nodesToProto(api.h.state, IsConnected, nodes)
|
||||
return &v1.ListNodesResponse{Nodes: response}, nil
|
||||
}
|
||||
|
||||
// filterNodes applies the filters from ListNodesRequest to the node list
|
||||
func (api headscaleV1APIServer) filterNodes(nodes types.Nodes, request *v1.ListNodesRequest) types.Nodes {
|
||||
var filtered types.Nodes
|
||||
|
||||
for _, node := range nodes {
|
||||
// Filter by user
|
||||
if request.GetUser() != "" && node.User.Name != request.GetUser() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Filter by ID (backward compatibility)
|
||||
if request.GetId() != 0 && uint64(node.ID) != request.GetId() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Filter by name (exact match)
|
||||
if request.GetName() != "" && node.Hostname != request.GetName() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Filter by hostname (alias for name)
|
||||
if request.GetHostname() != "" && node.Hostname != request.GetHostname() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Filter by IP addresses
|
||||
if len(request.GetIpAddresses()) > 0 {
|
||||
hasMatchingIP := false
|
||||
for _, requestIP := range request.GetIpAddresses() {
|
||||
for _, nodeIP := range node.IPs() {
|
||||
if nodeIP.String() == requestIP {
|
||||
hasMatchingIP = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if hasMatchingIP {
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasMatchingIP {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// If we get here, node matches all filters
|
||||
filtered = append(filtered, node)
|
||||
}
|
||||
|
||||
return filtered
|
||||
}
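// Sketch (not part of the diff): the filters above are ANDed, so a request
// that names both a user and an IP address only matches nodes satisfying
// both. The field and client names are assumed from the generated v1
// getters used in filterNodes.
func listUserNodeByIP(ctx context.Context, client v1.HeadscaleServiceClient) ([]*v1.Node, error) {
	resp, err := client.ListNodes(ctx, &v1.ListNodesRequest{
		User:        "alice",
		IpAddresses: []string{"100.64.0.1"},
	})
	if err != nil {
		return nil, err
	}

	return resp.GetNodes(), nil
}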
|
||||
|
||||
func nodesToProto(state *state.State, isLikelyConnected *xsync.MapOf[types.NodeID, bool], nodes types.Nodes) []*v1.Node {
|
||||
func nodesToProto(state *state.State, IsConnected *xsync.MapOf[types.NodeID, bool], nodes types.Nodes) []*v1.Node {
|
||||
response := make([]*v1.Node, len(nodes))
|
||||
for index, node := range nodes {
|
||||
resp := node.Proto()
|
||||
|
||||
// Populate the online field based on
|
||||
// currently connected nodes.
|
||||
if val, ok := isLikelyConnected.Load(node.ID); ok && val {
|
||||
if val, ok := IsConnected.Load(node.ID); ok && val {
|
||||
resp.Online = true
|
||||
}
|
||||
|
||||
@@ -595,24 +514,14 @@ func (api headscaleV1APIServer) MoveNode(
|
||||
ctx context.Context,
|
||||
request *v1.MoveNodeRequest,
|
||||
) (*v1.MoveNodeResponse, error) {
|
||||
node, policyChanged, err := api.h.state.AssignNodeToUser(types.NodeID(request.GetNodeId()), types.UserID(request.GetUser()))
|
||||
node, nodeChange, err := api.h.state.AssignNodeToUser(types.NodeID(request.GetNodeId()), types.UserID(request.GetUser()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Send policy update notifications if needed
|
||||
if policyChanged {
|
||||
ctx := types.NotifyCtx(context.Background(), "grpc-node-moved", node.Hostname)
|
||||
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
||||
}
|
||||
|
||||
ctx = types.NotifyCtx(ctx, "cli-movenode-self", node.Hostname)
|
||||
api.h.nodeNotifier.NotifyByNodeID(
|
||||
ctx,
|
||||
types.UpdateSelf(node.ID),
|
||||
node.ID)
|
||||
ctx = types.NotifyCtx(ctx, "cli-movenode", node.Hostname)
|
||||
api.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID)
|
||||
// TODO(kradalby): Ensure the policy is also sent
|
||||
// TODO(kradalby): ensure that both the selfupdate and peer updates are sent
|
||||
api.h.Change(nodeChange)
|
||||
|
||||
return &v1.MoveNodeResponse{Node: node.Proto()}, nil
|
||||
}
|
||||
@@ -793,8 +702,7 @@ func (api headscaleV1APIServer) SetPolicy(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctx := types.NotifyCtx(context.Background(), "acl-update", "na")
|
||||
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
||||
api.h.Change(change.PolicyChange())
|
||||
}
|
||||
|
||||
response := &v1.SetPolicyResponse{
|
||||
|
||||
hscontrol/mapper/batcher.go (new file, 155 lines)
@@ -0,0 +1,155 @@

|
||||
package mapper
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/state"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/types/change"
|
||||
"github.com/puzpuzpuz/xsync/v4"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/ptr"
|
||||
)
|
||||
|
||||
type batcherFunc func(cfg *types.Config, state *state.State) Batcher
|
||||
|
||||
// Batcher defines the common interface for all batcher implementations.
|
||||
type Batcher interface {
|
||||
Start()
|
||||
Close()
|
||||
AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool, version tailcfg.CapabilityVersion) error
|
||||
RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool)
|
||||
IsConnected(id types.NodeID) bool
|
||||
ConnectedMap() *xsync.Map[types.NodeID, bool]
|
||||
AddWork(c change.ChangeSet)
|
||||
MapResponseFromChange(id types.NodeID, c change.ChangeSet) (*tailcfg.MapResponse, error)
|
||||
}
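// Sketch (not part of the diff): a compile-time assertion that the
// lock-free implementation in batcher_lockfree.go satisfies this
// interface; every method above is defined on *LockFreeBatcher.
var _ Batcher = (*LockFreeBatcher)(nil)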
|
||||
|
||||
func NewBatcher(batchTime time.Duration, workers int, mapper *mapper) *LockFreeBatcher {
|
||||
return &LockFreeBatcher{
|
||||
mapper: mapper,
|
||||
workers: workers,
|
||||
tick: time.NewTicker(batchTime),
|
||||
|
||||
// The size of this channel is arbitrarily chosen; the sizing should be revisited.
|
||||
workCh: make(chan work, workers*200),
|
||||
nodes: xsync.NewMap[types.NodeID, *nodeConn](),
|
||||
connected: xsync.NewMap[types.NodeID, *time.Time](),
|
||||
pendingChanges: xsync.NewMap[types.NodeID, []change.ChangeSet](),
|
||||
}
|
||||
}
|
||||
|
||||
// NewBatcherAndMapper creates a Batcher implementation.
|
||||
func NewBatcherAndMapper(cfg *types.Config, state *state.State) Batcher {
|
||||
m := newMapper(cfg, state)
|
||||
b := NewBatcher(cfg.Tuning.BatchChangeDelay, cfg.Tuning.BatcherWorkers, m)
|
||||
m.batcher = b
|
||||
return b
|
||||
}
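// Sketch (not part of the diff): typical wiring as suggested by the
// constructors above, written from the point of view of a caller outside
// this package. cfg and st stand in for an already-loaded types.Config and
// state.State.
func startBatcher(cfg *types.Config, st *state.State) mapper.Batcher {
	b := mapper.NewBatcherAndMapper(cfg, st)
	b.Start() // remember to call b.Close() on shutdown

	return b
}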
|
||||
|
||||
// nodeConnection interface for different connection implementations.
|
||||
type nodeConnection interface {
|
||||
nodeID() types.NodeID
|
||||
version() tailcfg.CapabilityVersion
|
||||
send(data *tailcfg.MapResponse) error
|
||||
}
|
||||
|
||||
// generateMapResponse generates a [tailcfg.MapResponse] for the given NodeID that is based on the provided [change.ChangeSet].
|
||||
func generateMapResponse(nodeID types.NodeID, version tailcfg.CapabilityVersion, mapper *mapper, c change.ChangeSet) (*tailcfg.MapResponse, error) {
|
||||
if c.Empty() {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Validate inputs before processing
|
||||
if nodeID == 0 {
|
||||
return nil, fmt.Errorf("invalid nodeID: %d", nodeID)
|
||||
}
|
||||
|
||||
if mapper == nil {
|
||||
return nil, fmt.Errorf("mapper is nil for nodeID %d", nodeID)
|
||||
}
|
||||
|
||||
var mapResp *tailcfg.MapResponse
|
||||
var err error
|
||||
|
||||
switch c.Change {
|
||||
case change.DERP:
|
||||
mapResp, err = mapper.derpMapResponse(nodeID)
|
||||
|
||||
case change.NodeCameOnline, change.NodeWentOffline:
|
||||
if c.IsSubnetRouter {
|
||||
// TODO(kradalby): This can potentially be a peer update of the old and new subnet router.
|
||||
mapResp, err = mapper.fullMapResponse(nodeID, version)
|
||||
} else {
|
||||
mapResp, err = mapper.peerChangedPatchResponse(nodeID, []*tailcfg.PeerChange{
|
||||
{
|
||||
NodeID: c.NodeID.NodeID(),
|
||||
Online: ptr.To(c.Change == change.NodeCameOnline),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
case change.NodeNewOrUpdate:
|
||||
mapResp, err = mapper.fullMapResponse(nodeID, version)
|
||||
|
||||
case change.NodeRemove:
|
||||
mapResp, err = mapper.peerRemovedResponse(nodeID, c.NodeID)
|
||||
|
||||
default:
|
||||
// The following will always hit this:
|
||||
// change.Full, change.Policy
|
||||
mapResp, err = mapper.fullMapResponse(nodeID, version)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generating map response for nodeID %d: %w", nodeID, err)
|
||||
}
|
||||
|
||||
// TODO(kradalby): Is this necessary?
|
||||
// Validate the generated map response - only check for nil response
|
||||
// Note: mapResp.Node can be nil for peer updates, which is valid
|
||||
if mapResp == nil && c.Change != change.DERP && c.Change != change.NodeRemove {
|
||||
return nil, fmt.Errorf("generated nil map response for nodeID %d change %s", nodeID, c.Change.String())
|
||||
}
|
||||
|
||||
return mapResp, nil
|
||||
}
|
||||
|
||||
// handleNodeChange generates and sends a [tailcfg.MapResponse] for a given node and [change.ChangeSet].
|
||||
func handleNodeChange(nc nodeConnection, mapper *mapper, c change.ChangeSet) error {
|
||||
if nc == nil {
|
||||
return fmt.Errorf("nodeConnection is nil")
|
||||
}
|
||||
|
||||
nodeID := nc.nodeID()
|
||||
data, err := generateMapResponse(nodeID, nc.version(), mapper, c)
|
||||
if err != nil {
|
||||
return fmt.Errorf("generating map response for node %d: %w", nodeID, err)
|
||||
}
|
||||
|
||||
if data == nil {
|
||||
// No data to send is valid for some change types
|
||||
return nil
|
||||
}
|
||||
|
||||
// Send the map response
|
||||
if err := nc.send(data); err != nil {
|
||||
return fmt.Errorf("sending map response to node %d: %w", nodeID, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// workResult represents the result of processing a change.
|
||||
type workResult struct {
|
||||
mapResponse *tailcfg.MapResponse
|
||||
err error
|
||||
}
|
||||
|
||||
// work represents a unit of work to be processed by workers.
|
||||
type work struct {
|
||||
c change.ChangeSet
|
||||
nodeID types.NodeID
|
||||
resultCh chan<- workResult // optional channel for synchronous operations
|
||||
}
|
||||
hscontrol/mapper/batcher_lockfree.go (new file, 491 lines)
@@ -0,0 +1,491 @@
|
||||
package mapper
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/types/change"
|
||||
"github.com/puzpuzpuz/xsync/v4"
|
||||
"github.com/rs/zerolog/log"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/ptr"
|
||||
)
|
||||
|
||||
// LockFreeBatcher uses atomic operations and concurrent maps to eliminate mutex contention.
|
||||
type LockFreeBatcher struct {
|
||||
tick *time.Ticker
|
||||
mapper *mapper
|
||||
workers int
|
||||
|
||||
// Lock-free concurrent maps
|
||||
nodes *xsync.Map[types.NodeID, *nodeConn]
|
||||
connected *xsync.Map[types.NodeID, *time.Time]
|
||||
|
||||
// Work queue channel
|
||||
workCh chan work
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
// Batching state
|
||||
pendingChanges *xsync.Map[types.NodeID, []change.ChangeSet]
|
||||
batchMutex sync.RWMutex
|
||||
|
||||
// Metrics
|
||||
totalNodes atomic.Int64
|
||||
totalUpdates atomic.Int64
|
||||
workQueuedCount atomic.Int64
|
||||
workProcessed atomic.Int64
|
||||
workErrors atomic.Int64
|
||||
}
|
||||
|
||||
// AddNode registers a new node connection with the batcher and sends an initial map response.
|
||||
// It creates or updates the node's connection data, validates the initial map generation,
|
||||
// and notifies other nodes that this node has come online.
|
||||
// TODO(kradalby): See if we can move the isRouter argument somewhere else.
|
||||
func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool, version tailcfg.CapabilityVersion) error {
|
||||
// First validate that we can generate initial map before doing anything else
|
||||
fullSelfChange := change.FullSelf(id)
|
||||
|
||||
// TODO(kradalby): This should not be generated here, but rather in MapResponseFromChange.
|
||||
// This currently means that the goroutine for the node connection will do the processing
|
||||
// which means that we might have uncontrolled concurrency.
|
||||
// When we use MapResponseFromChange, it will be processed by the same worker pool, causing
|
||||
// it to be processed in a more controlled manner.
|
||||
initialMap, err := generateMapResponse(id, version, b.mapper, fullSelfChange)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate initial map for node %d: %w", id, err)
|
||||
}
|
||||
|
||||
// Only after validation succeeds, create or update node connection
|
||||
newConn := newNodeConn(id, c, version, b.mapper)
|
||||
|
||||
var conn *nodeConn
|
||||
if existing, loaded := b.nodes.LoadOrStore(id, newConn); loaded {
|
||||
// Update existing connection
|
||||
existing.updateConnection(c, version)
|
||||
conn = existing
|
||||
} else {
|
||||
b.totalNodes.Add(1)
|
||||
conn = newConn
|
||||
}
|
||||
|
||||
// Mark as connected only after validation succeeds
|
||||
b.connected.Store(id, nil) // nil = connected
|
||||
|
||||
log.Info().Uint64("node.id", id.Uint64()).Bool("isRouter", isRouter).Msg("Node connected to batcher")
|
||||
|
||||
// Send the validated initial map
|
||||
if initialMap != nil {
|
||||
if err := conn.send(initialMap); err != nil {
|
||||
// Clean up the connection state on send failure
|
||||
b.nodes.Delete(id)
|
||||
b.connected.Delete(id)
|
||||
return fmt.Errorf("failed to send initial map to node %d: %w", id, err)
|
||||
}
|
||||
|
||||
// Notify other nodes that this node came online
|
||||
b.addWork(change.ChangeSet{NodeID: id, Change: change.NodeCameOnline, IsSubnetRouter: isRouter})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
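// Sketch (not part of the diff): how a poll handler might register a node
// and drain its channel. The channel capacity and capability version are
// illustrative, not prescribed by this code.
func serveNode(b Batcher, id types.NodeID) error {
	ch := make(chan *tailcfg.MapResponse, 16)
	if err := b.AddNode(id, ch, false, 106); err != nil {
		return err
	}
	defer b.RemoveNode(id, ch, false)

	for resp := range ch {
		_ = resp // a real handler would serialize this onto the HTTP stream
	}

	return nil
}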
|
||||
|
||||
// RemoveNode disconnects a node from the batcher, marking it as offline and cleaning up its state.
|
||||
// It validates the connection channel matches the current one, closes the connection,
|
||||
// and notifies other nodes that this node has gone offline.
|
||||
func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool) {
|
||||
// Check if this is the current connection and mark it as closed
|
||||
if existing, ok := b.nodes.Load(id); ok {
|
||||
if !existing.matchesChannel(c) {
|
||||
log.Debug().Uint64("node.id", id.Uint64()).Msg("RemoveNode called for non-current connection, ignoring")
|
||||
return // Not the current connection, not an error
|
||||
}
|
||||
|
||||
// Mark the connection as closed to prevent further sends
|
||||
if connData := existing.connData.Load(); connData != nil {
|
||||
connData.closed.Store(true)
|
||||
}
|
||||
}
|
||||
|
||||
log.Info().Uint64("node.id", id.Uint64()).Bool("isRouter", isRouter).Msg("Node disconnected from batcher, marking as offline")
|
||||
|
||||
// Remove node and mark disconnected atomically
|
||||
b.nodes.Delete(id)
|
||||
b.connected.Store(id, ptr.To(time.Now()))
|
||||
b.totalNodes.Add(-1)
|
||||
|
||||
// Notify other nodes that this node went offline
|
||||
b.addWork(change.ChangeSet{NodeID: id, Change: change.NodeWentOffline, IsSubnetRouter: isRouter})
|
||||
}
|
||||
|
||||
// AddWork queues a change to be processed by the batcher.
|
||||
// Critical changes are processed immediately, while others are batched for efficiency.
|
||||
func (b *LockFreeBatcher) AddWork(c change.ChangeSet) {
|
||||
b.addWork(c)
|
||||
}
|
||||
|
||||
func (b *LockFreeBatcher) Start() {
|
||||
b.ctx, b.cancel = context.WithCancel(context.Background())
|
||||
go b.doWork()
|
||||
}
|
||||
|
||||
func (b *LockFreeBatcher) Close() {
|
||||
if b.cancel != nil {
|
||||
b.cancel()
|
||||
}
|
||||
close(b.workCh)
|
||||
}
|
||||
|
||||
func (b *LockFreeBatcher) doWork() {
|
||||
log.Debug().Msg("batcher doWork loop started")
|
||||
defer log.Debug().Msg("batcher doWork loop stopped")
|
||||
|
||||
for i := range b.workers {
|
||||
go b.worker(i + 1)
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-b.tick.C:
|
||||
// Process batched changes
|
||||
b.processBatchedChanges()
|
||||
case <-b.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (b *LockFreeBatcher) worker(workerID int) {
|
||||
log.Debug().Int("workerID", workerID).Msg("batcher worker started")
|
||||
defer log.Debug().Int("workerID", workerID).Msg("batcher worker stopped")
|
||||
|
||||
for {
|
||||
select {
|
||||
case w, ok := <-b.workCh:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
startTime := time.Now()
|
||||
b.workProcessed.Add(1)
|
||||
|
||||
// If the resultCh is set, it means that this is a work request
|
||||
// where there is a blocking function waiting for the map that
|
||||
// is being generated.
|
||||
// This is used for synchronous map generation.
|
||||
if w.resultCh != nil {
|
||||
var result workResult
|
||||
if nc, exists := b.nodes.Load(w.nodeID); exists {
|
||||
result.mapResponse, result.err = generateMapResponse(nc.nodeID(), nc.version(), b.mapper, w.c)
|
||||
if result.err != nil {
|
||||
b.workErrors.Add(1)
|
||||
log.Error().Err(result.err).
|
||||
Int("workerID", workerID).
|
||||
Uint64("node.id", w.nodeID.Uint64()).
|
||||
Str("change", w.c.Change.String()).
|
||||
Msg("failed to generate map response for synchronous work")
|
||||
}
|
||||
} else {
|
||||
result.err = fmt.Errorf("node %d not found", w.nodeID)
|
||||
b.workErrors.Add(1)
|
||||
log.Error().Err(result.err).
|
||||
Int("workerID", workerID).
|
||||
Uint64("node.id", w.nodeID.Uint64()).
|
||||
Msg("node not found for synchronous work")
|
||||
}
|
||||
|
||||
// Send result
|
||||
select {
|
||||
case w.resultCh <- result:
|
||||
case <-b.ctx.Done():
|
||||
return
|
||||
}
|
||||
|
||||
duration := time.Since(startTime)
|
||||
if duration > 100*time.Millisecond {
|
||||
log.Warn().
|
||||
Int("workerID", workerID).
|
||||
Uint64("node.id", w.nodeID.Uint64()).
|
||||
Str("change", w.c.Change.String()).
|
||||
Dur("duration", duration).
|
||||
Msg("slow synchronous work processing")
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// If resultCh is nil, this is an asynchronous work request
|
||||
// that should be processed and sent to the node instead of
|
||||
// returned to the caller.
|
||||
if nc, exists := b.nodes.Load(w.nodeID); exists {
|
||||
// Check if this connection is still active before processing
|
||||
if connData := nc.connData.Load(); connData != nil && connData.closed.Load() {
|
||||
log.Debug().
|
||||
Int("workerID", workerID).
|
||||
Uint64("node.id", w.nodeID.Uint64()).
|
||||
Str("change", w.c.Change.String()).
|
||||
Msg("skipping work for closed connection")
|
||||
continue
|
||||
}
|
||||
|
||||
err := nc.change(w.c)
|
||||
if err != nil {
|
||||
b.workErrors.Add(1)
|
||||
log.Error().Err(err).
|
||||
Int("workerID", workerID).
|
||||
Uint64("node.id", w.c.NodeID.Uint64()).
|
||||
Str("change", w.c.Change.String()).
|
||||
Msg("failed to apply change")
|
||||
}
|
||||
} else {
|
||||
log.Debug().
|
||||
Int("workerID", workerID).
|
||||
Uint64("node.id", w.nodeID.Uint64()).
|
||||
Str("change", w.c.Change.String()).
|
||||
Msg("node not found for asynchronous work - node may have disconnected")
|
||||
}
|
||||
|
||||
duration := time.Since(startTime)
|
||||
if duration > 100*time.Millisecond {
|
||||
log.Warn().
|
||||
Int("workerID", workerID).
|
||||
Uint64("node.id", w.nodeID.Uint64()).
|
||||
Str("change", w.c.Change.String()).
|
||||
Dur("duration", duration).
|
||||
Msg("slow asynchronous work processing")
|
||||
}
|
||||
|
||||
case <-b.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (b *LockFreeBatcher) addWork(c change.ChangeSet) {
|
||||
// For critical changes that need immediate processing, send directly
|
||||
if b.shouldProcessImmediately(c) {
|
||||
if c.SelfUpdateOnly {
|
||||
b.queueWork(work{c: c, nodeID: c.NodeID, resultCh: nil})
|
||||
return
|
||||
}
|
||||
b.nodes.Range(func(nodeID types.NodeID, _ *nodeConn) bool {
|
||||
if c.NodeID == nodeID && !c.AlsoSelf() {
|
||||
return true
|
||||
}
|
||||
b.queueWork(work{c: c, nodeID: nodeID, resultCh: nil})
|
||||
return true
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// For non-critical changes, add to batch
|
||||
b.addToBatch(c)
|
||||
}
|
||||
|
||||
// queueWork safely queues work
|
||||
func (b *LockFreeBatcher) queueWork(w work) {
|
||||
b.workQueuedCount.Add(1)
|
||||
|
||||
select {
|
||||
case b.workCh <- w:
|
||||
// Successfully queued
|
||||
case <-b.ctx.Done():
|
||||
// Batcher is shutting down
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// shouldProcessImmediately determines if a change should bypass batching
|
||||
func (b *LockFreeBatcher) shouldProcessImmediately(c change.ChangeSet) bool {
|
||||
// Process these changes immediately to avoid delaying critical functionality
|
||||
switch c.Change {
|
||||
case change.Full, change.NodeRemove, change.NodeCameOnline, change.NodeWentOffline, change.Policy:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// addToBatch adds a change to the pending batch
|
||||
func (b *LockFreeBatcher) addToBatch(c change.ChangeSet) {
|
||||
b.batchMutex.Lock()
|
||||
defer b.batchMutex.Unlock()
|
||||
|
||||
if c.SelfUpdateOnly {
|
||||
changes, _ := b.pendingChanges.LoadOrStore(c.NodeID, []change.ChangeSet{})
|
||||
changes = append(changes, c)
|
||||
b.pendingChanges.Store(c.NodeID, changes)
|
||||
return
|
||||
}
|
||||
|
||||
b.nodes.Range(func(nodeID types.NodeID, _ *nodeConn) bool {
|
||||
if c.NodeID == nodeID && !c.AlsoSelf() {
|
||||
return true
|
||||
}
|
||||
|
||||
changes, _ := b.pendingChanges.LoadOrStore(nodeID, []change.ChangeSet{})
|
||||
changes = append(changes, c)
|
||||
b.pendingChanges.Store(nodeID, changes)
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// processBatchedChanges processes all pending batched changes
|
||||
func (b *LockFreeBatcher) processBatchedChanges() {
|
||||
b.batchMutex.Lock()
|
||||
defer b.batchMutex.Unlock()
|
||||
|
||||
if b.pendingChanges == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Process all pending changes
|
||||
b.pendingChanges.Range(func(nodeID types.NodeID, changes []change.ChangeSet) bool {
|
||||
if len(changes) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
// Send all batched changes for this node
|
||||
for _, c := range changes {
|
||||
b.queueWork(work{c: c, nodeID: nodeID, resultCh: nil})
|
||||
}
|
||||
|
||||
// Clear the pending changes for this node
|
||||
b.pendingChanges.Delete(nodeID)
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// IsConnected reports whether a node is connected, using a lock-free read.
|
||||
func (b *LockFreeBatcher) IsConnected(id types.NodeID) bool {
|
||||
if val, ok := b.connected.Load(id); ok {
|
||||
// nil means connected
|
||||
return val == nil
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ConnectedMap returns a lock-free map of all connected nodes.
|
||||
func (b *LockFreeBatcher) ConnectedMap() *xsync.Map[types.NodeID, bool] {
|
||||
ret := xsync.NewMap[types.NodeID, bool]()
|
||||
|
||||
b.connected.Range(func(id types.NodeID, val *time.Time) bool {
|
||||
// nil means connected
|
||||
ret.Store(id, val == nil)
|
||||
return true
|
||||
})
|
||||
|
||||
return ret
|
||||
}
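// Sketch (not part of the diff): mirrors how the gRPC ListNodes handler
// earlier in this diff consumes ConnectedMap to populate the Online field.
func onlineNodeIDs(b Batcher) []types.NodeID {
	var ids []types.NodeID
	b.ConnectedMap().Range(func(id types.NodeID, online bool) bool {
		if online {
			ids = append(ids, id)
		}

		return true
	})

	return ids
}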
|
||||
|
||||
// MapResponseFromChange queues work to generate a map response and waits for the result.
|
||||
// This allows synchronous map generation using the same worker pool.
|
||||
func (b *LockFreeBatcher) MapResponseFromChange(id types.NodeID, c change.ChangeSet) (*tailcfg.MapResponse, error) {
|
||||
resultCh := make(chan workResult, 1)
|
||||
|
||||
// Queue the work with a result channel using the safe queueing method
|
||||
b.queueWork(work{c: c, nodeID: id, resultCh: resultCh})
|
||||
|
||||
// Wait for the result
|
||||
select {
|
||||
case result := <-resultCh:
|
||||
return result.mapResponse, result.err
|
||||
case <-b.ctx.Done():
|
||||
return nil, fmt.Errorf("batcher shutting down while generating map response for node %d", id)
|
||||
}
|
||||
}
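// Sketch (not part of the diff): synchronous generation, for example from a
// debug endpoint that wants the same full map a node would receive when it
// first connects (AddNode uses change.FullSelf the same way).
func debugFullMap(b Batcher, id types.NodeID) (*tailcfg.MapResponse, error) {
	return b.MapResponseFromChange(id, change.FullSelf(id))
}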
|
||||
|
||||
// connectionData holds the channel and connection parameters.
|
||||
type connectionData struct {
|
||||
c chan<- *tailcfg.MapResponse
|
||||
version tailcfg.CapabilityVersion
|
||||
closed atomic.Bool // Track if this connection has been closed
|
||||
}
|
||||
|
||||
// nodeConn describes the node connection and its associated data.
|
||||
type nodeConn struct {
|
||||
id types.NodeID
|
||||
mapper *mapper
|
||||
|
||||
// Atomic pointer to connection data - allows lock-free updates
|
||||
connData atomic.Pointer[connectionData]
|
||||
|
||||
updateCount atomic.Int64
|
||||
}
|
||||
|
||||
func newNodeConn(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion, mapper *mapper) *nodeConn {
|
||||
nc := &nodeConn{
|
||||
id: id,
|
||||
mapper: mapper,
|
||||
}
|
||||
|
||||
// Initialize connection data
|
||||
data := &connectionData{
|
||||
c: c,
|
||||
version: version,
|
||||
}
|
||||
nc.connData.Store(data)
|
||||
|
||||
return nc
|
||||
}
|
||||
|
||||
// updateConnection atomically updates connection parameters.
|
||||
func (nc *nodeConn) updateConnection(c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion) {
|
||||
newData := &connectionData{
|
||||
c: c,
|
||||
version: version,
|
||||
}
|
||||
nc.connData.Store(newData)
|
||||
}
|
||||
|
||||
// matchesChannel checks if the given channel matches current connection.
|
||||
func (nc *nodeConn) matchesChannel(c chan<- *tailcfg.MapResponse) bool {
|
||||
data := nc.connData.Load()
|
||||
if data == nil {
|
||||
return false
|
||||
}
|
||||
// Compare channel pointers directly
|
||||
return data.c == c
|
||||
}
|
||||
|
||||
// version atomically reads the connection's capability version.
|
||||
func (nc *nodeConn) version() tailcfg.CapabilityVersion {
|
||||
data := nc.connData.Load()
|
||||
if data == nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
return data.version
|
||||
}
|
||||
|
||||
func (nc *nodeConn) nodeID() types.NodeID {
|
||||
return nc.id
|
||||
}
|
||||
|
||||
func (nc *nodeConn) change(c change.ChangeSet) error {
|
||||
return handleNodeChange(nc, nc.mapper, c)
|
||||
}
|
||||
|
||||
// send sends data to the node's channel.
|
||||
// The node will pick it up and send it to the HTTP handler.
|
||||
func (nc *nodeConn) send(data *tailcfg.MapResponse) error {
|
||||
connData := nc.connData.Load()
|
||||
if connData == nil {
|
||||
return fmt.Errorf("node %d: no connection data", nc.id)
|
||||
}
|
||||
|
||||
// Check if connection has been closed
|
||||
if connData.closed.Load() {
|
||||
return fmt.Errorf("node %d: connection closed", nc.id)
|
||||
}
|
||||
|
||||
// TODO(kradalby): We might need some sort of timeout here if the client is not reading
|
||||
// the channel. That might mean that we are sending to a node that has gone offline, but
|
||||
// the channel is still open.
|
||||
connData.c <- data
|
||||
nc.updateCount.Add(1)
|
||||
return nil
|
||||
}
|
||||
hscontrol/mapper/batcher_test.go (new file, 1977 lines; diff suppressed because it is too large)
hscontrol/mapper/builder.go (new file, 259 lines)
@@ -0,0 +1,259 @@
|
||||
package mapper
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/policy"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/views"
|
||||
"tailscale.com/util/multierr"
|
||||
)
|
||||
|
||||
// MapResponseBuilder provides a fluent interface for building tailcfg.MapResponse
|
||||
type MapResponseBuilder struct {
|
||||
resp *tailcfg.MapResponse
|
||||
mapper *mapper
|
||||
nodeID types.NodeID
|
||||
capVer tailcfg.CapabilityVersion
|
||||
errs []error
|
||||
}
|
||||
|
||||
// NewMapResponseBuilder creates a new builder with basic fields set
|
||||
func (m *mapper) NewMapResponseBuilder(nodeID types.NodeID) *MapResponseBuilder {
|
||||
now := time.Now()
|
||||
return &MapResponseBuilder{
|
||||
resp: &tailcfg.MapResponse{
|
||||
KeepAlive: false,
|
||||
ControlTime: &now,
|
||||
},
|
||||
mapper: m,
|
||||
nodeID: nodeID,
|
||||
errs: nil,
|
||||
}
|
||||
}
|
||||
|
||||
// addError adds an error to the builder's error list
|
||||
func (b *MapResponseBuilder) addError(err error) {
|
||||
if err != nil {
|
||||
b.errs = append(b.errs, err)
|
||||
}
|
||||
}
|
||||
|
||||
// hasErrors returns true if the builder has accumulated any errors
|
||||
func (b *MapResponseBuilder) hasErrors() bool {
|
||||
return len(b.errs) > 0
|
||||
}
|
||||
|
||||
// WithCapabilityVersion sets the capability version for the response
|
||||
func (b *MapResponseBuilder) WithCapabilityVersion(capVer tailcfg.CapabilityVersion) *MapResponseBuilder {
|
||||
b.capVer = capVer
|
||||
return b
|
||||
}
|
||||
|
||||
// WithSelfNode adds the requesting node to the response
|
||||
func (b *MapResponseBuilder) WithSelfNode() *MapResponseBuilder {
|
||||
node, err := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if err != nil {
|
||||
b.addError(err)
|
||||
return b
|
||||
}
|
||||
|
||||
_, matchers := b.mapper.state.Filter()
|
||||
tailnode, err := tailNode(
|
||||
node.View(), b.capVer, b.mapper.state,
|
||||
func(id types.NodeID) []netip.Prefix {
|
||||
return policy.ReduceRoutes(node.View(), b.mapper.state.GetNodePrimaryRoutes(id), matchers)
|
||||
},
|
||||
b.mapper.cfg)
|
||||
if err != nil {
|
||||
b.addError(err)
|
||||
return b
|
||||
}
|
||||
|
||||
b.resp.Node = tailnode
|
||||
return b
|
||||
}
|
||||
|
||||
// WithDERPMap adds the DERP map to the response
|
||||
func (b *MapResponseBuilder) WithDERPMap() *MapResponseBuilder {
|
||||
b.resp.DERPMap = b.mapper.state.DERPMap()
|
||||
return b
|
||||
}
|
||||
|
||||
// WithDomain adds the domain configuration
|
||||
func (b *MapResponseBuilder) WithDomain() *MapResponseBuilder {
|
||||
b.resp.Domain = b.mapper.cfg.Domain()
|
||||
return b
|
||||
}
|
||||
|
||||
// WithCollectServicesDisabled sets the collect services flag to false
|
||||
func (b *MapResponseBuilder) WithCollectServicesDisabled() *MapResponseBuilder {
|
||||
b.resp.CollectServices.Set(false)
|
||||
return b
|
||||
}
|
||||
|
||||
// WithDebugConfig adds debug configuration
|
||||
// It disables log tailing if the mapper's LogTail is not enabled
|
||||
func (b *MapResponseBuilder) WithDebugConfig() *MapResponseBuilder {
|
||||
b.resp.Debug = &tailcfg.Debug{
|
||||
DisableLogTail: !b.mapper.cfg.LogTail.Enabled,
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// WithSSHPolicy adds SSH policy configuration for the requesting node
|
||||
func (b *MapResponseBuilder) WithSSHPolicy() *MapResponseBuilder {
|
||||
node, err := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if err != nil {
|
||||
b.addError(err)
|
||||
return b
|
||||
}
|
||||
|
||||
sshPolicy, err := b.mapper.state.SSHPolicy(node.View())
|
||||
if err != nil {
|
||||
b.addError(err)
|
||||
return b
|
||||
}
|
||||
|
||||
b.resp.SSHPolicy = sshPolicy
|
||||
return b
|
||||
}
|
||||
|
||||
// WithDNSConfig adds DNS configuration for the requesting node
|
||||
func (b *MapResponseBuilder) WithDNSConfig() *MapResponseBuilder {
|
||||
node, err := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if err != nil {
|
||||
b.addError(err)
|
||||
return b
|
||||
}
|
||||
|
||||
b.resp.DNSConfig = generateDNSConfig(b.mapper.cfg, node)
|
||||
return b
|
||||
}
|
||||
|
||||
// WithUserProfiles adds user profiles for the requesting node and given peers
|
||||
func (b *MapResponseBuilder) WithUserProfiles(peers types.Nodes) *MapResponseBuilder {
|
||||
node, err := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if err != nil {
|
||||
b.addError(err)
|
||||
return b
|
||||
}
|
||||
|
||||
b.resp.UserProfiles = generateUserProfiles(node, peers)
|
||||
return b
|
||||
}
|
||||
|
||||
// WithPacketFilters adds packet filter rules based on policy
|
||||
func (b *MapResponseBuilder) WithPacketFilters() *MapResponseBuilder {
|
||||
node, err := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if err != nil {
|
||||
b.addError(err)
|
||||
return b
|
||||
}
|
||||
|
||||
filter, _ := b.mapper.state.Filter()
|
||||
|
||||
// CapVer 81: 2023-11-17: MapResponse.PacketFilters (incremental packet filter updates)
|
||||
// Currently, we do not send incremental packet filters; however, using the
|
||||
// new PacketFilters field and "base" allows us to send a full update when we
|
||||
// have to send an empty list, avoiding the hack in the else block.
|
||||
b.resp.PacketFilters = map[string][]tailcfg.FilterRule{
|
||||
"base": policy.ReduceFilterRules(node.View(), filter),
|
||||
}
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
// WithPeers adds full peer list with policy filtering (for full map response)
|
||||
func (b *MapResponseBuilder) WithPeers(peers types.Nodes) *MapResponseBuilder {
|
||||
|
||||
tailPeers, err := b.buildTailPeers(peers)
|
||||
if err != nil {
|
||||
b.addError(err)
|
||||
return b
|
||||
}
|
||||
|
||||
b.resp.Peers = tailPeers
|
||||
return b
|
||||
}
|
||||
|
||||
// WithPeerChanges adds changed peers with policy filtering (for incremental updates)
|
||||
func (b *MapResponseBuilder) WithPeerChanges(peers types.Nodes) *MapResponseBuilder {
|
||||
|
||||
tailPeers, err := b.buildTailPeers(peers)
|
||||
if err != nil {
|
||||
b.addError(err)
|
||||
return b
|
||||
}
|
||||
|
||||
b.resp.PeersChanged = tailPeers
|
||||
return b
|
||||
}
|
||||
|
||||
// buildTailPeers converts types.Nodes to []tailcfg.Node with policy filtering and sorting
|
||||
func (b *MapResponseBuilder) buildTailPeers(peers types.Nodes) ([]*tailcfg.Node, error) {
|
||||
node, err := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
filter, matchers := b.mapper.state.Filter()
|
||||
|
||||
// If there are filter rules present, see if there are any nodes that cannot
|
||||
// access each other at all and remove them from the peers.
|
||||
var changedViews views.Slice[types.NodeView]
|
||||
if len(filter) > 0 {
|
||||
changedViews = policy.ReduceNodes(node.View(), peers.ViewSlice(), matchers)
|
||||
} else {
|
||||
changedViews = peers.ViewSlice()
|
||||
}
|
||||
|
||||
tailPeers, err := tailNodes(
|
||||
changedViews, b.capVer, b.mapper.state,
|
||||
func(id types.NodeID) []netip.Prefix {
|
||||
return policy.ReduceRoutes(node.View(), b.mapper.state.GetNodePrimaryRoutes(id), matchers)
|
||||
},
|
||||
b.mapper.cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Peers is always returned sorted by Node.ID.
|
||||
sort.SliceStable(tailPeers, func(x, y int) bool {
|
||||
return tailPeers[x].ID < tailPeers[y].ID
|
||||
})
|
||||
|
||||
return tailPeers, nil
|
||||
}
|
||||
|
||||
// WithPeerChangedPatch adds peer change patches
|
||||
func (b *MapResponseBuilder) WithPeerChangedPatch(changes []*tailcfg.PeerChange) *MapResponseBuilder {
|
||||
b.resp.PeersChangedPatch = changes
|
||||
return b
|
||||
}
|
||||
|
||||
// WithPeersRemoved adds removed peer IDs
|
||||
func (b *MapResponseBuilder) WithPeersRemoved(removedIDs ...types.NodeID) *MapResponseBuilder {
|
||||
|
||||
var tailscaleIDs []tailcfg.NodeID
|
||||
for _, id := range removedIDs {
|
||||
tailscaleIDs = append(tailscaleIDs, id.NodeID())
|
||||
}
|
||||
b.resp.PeersRemoved = tailscaleIDs
|
||||
return b
|
||||
}
|
||||
|
||||
// Build finalizes the response and returns marshaled bytes
|
||||
func (b *MapResponseBuilder) Build(messages ...string) (*tailcfg.MapResponse, error) {
|
||||
if len(b.errs) > 0 {
|
||||
return nil, multierr.New(b.errs...)
|
||||
}
|
||||
if debugDumpMapResponsePath != "" {
|
||||
writeDebugMapResponse(b.resp, b.nodeID)
|
||||
}
|
||||
|
||||
return b.resp, nil
|
||||
}
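// Sketch (not part of the diff): a minimal patch-only response assembled
// with the builder, mirroring peerChangedPatchResponse elsewhere in this
// change set.
func (m *mapper) examplePatchResponse(nodeID types.NodeID, patch []*tailcfg.PeerChange) (*tailcfg.MapResponse, error) {
	return m.NewMapResponseBuilder(nodeID).
		WithPeerChangedPatch(patch).
		Build()
}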
|
||||
hscontrol/mapper/builder_test.go (new file, 347 lines)
@@ -0,0 +1,347 @@
|
||||
package mapper
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/state"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
func TestMapResponseBuilder_Basic(t *testing.T) {
|
||||
cfg := &types.Config{
|
||||
BaseDomain: "example.com",
|
||||
LogTail: types.LogTailConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
}
|
||||
|
||||
mockState := &state.State{}
|
||||
m := &mapper{
|
||||
cfg: cfg,
|
||||
state: mockState,
|
||||
}
|
||||
|
||||
nodeID := types.NodeID(1)
|
||||
|
||||
builder := m.NewMapResponseBuilder(nodeID)
|
||||
|
||||
// Test basic builder creation
|
||||
assert.NotNil(t, builder)
|
||||
assert.Equal(t, nodeID, builder.nodeID)
|
||||
assert.NotNil(t, builder.resp)
|
||||
assert.False(t, builder.resp.KeepAlive)
|
||||
assert.NotNil(t, builder.resp.ControlTime)
|
||||
assert.WithinDuration(t, time.Now(), *builder.resp.ControlTime, time.Second)
|
||||
}
|
||||
|
||||
func TestMapResponseBuilder_WithCapabilityVersion(t *testing.T) {
|
||||
cfg := &types.Config{}
|
||||
mockState := &state.State{}
|
||||
m := &mapper{
|
||||
cfg: cfg,
|
||||
state: mockState,
|
||||
}
|
||||
|
||||
nodeID := types.NodeID(1)
|
||||
capVer := tailcfg.CapabilityVersion(42)
|
||||
|
||||
builder := m.NewMapResponseBuilder(nodeID).
|
||||
WithCapabilityVersion(capVer)
|
||||
|
||||
assert.Equal(t, capVer, builder.capVer)
|
||||
assert.False(t, builder.hasErrors())
|
||||
}
|
||||
|
||||
func TestMapResponseBuilder_WithDomain(t *testing.T) {
|
||||
domain := "test.example.com"
|
||||
cfg := &types.Config{
|
||||
ServerURL: "https://test.example.com",
|
||||
BaseDomain: domain,
|
||||
}
|
||||
|
||||
mockState := &state.State{}
|
||||
m := &mapper{
|
||||
cfg: cfg,
|
||||
state: mockState,
|
||||
}
|
||||
|
||||
nodeID := types.NodeID(1)
|
||||
|
||||
builder := m.NewMapResponseBuilder(nodeID).
|
||||
WithDomain()
|
||||
|
||||
assert.Equal(t, domain, builder.resp.Domain)
|
||||
assert.False(t, builder.hasErrors())
|
||||
}
|
||||
|
||||
func TestMapResponseBuilder_WithCollectServicesDisabled(t *testing.T) {
|
||||
cfg := &types.Config{}
|
||||
mockState := &state.State{}
|
||||
m := &mapper{
|
||||
cfg: cfg,
|
||||
state: mockState,
|
||||
}
|
||||
|
||||
nodeID := types.NodeID(1)
|
||||
|
||||
builder := m.NewMapResponseBuilder(nodeID).
|
||||
WithCollectServicesDisabled()
|
||||
|
||||
value, isSet := builder.resp.CollectServices.Get()
|
||||
assert.True(t, isSet)
|
||||
assert.False(t, value)
|
||||
assert.False(t, builder.hasErrors())
|
||||
}
|
||||
|
||||
func TestMapResponseBuilder_WithDebugConfig(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
logTailEnabled bool
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "LogTail enabled",
|
||||
logTailEnabled: true,
|
||||
expected: false, // DisableLogTail should be false when LogTail is enabled
|
||||
},
|
||||
{
|
||||
name: "LogTail disabled",
|
||||
logTailEnabled: false,
|
||||
expected: true, // DisableLogTail should be true when LogTail is disabled
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg := &types.Config{
|
||||
LogTail: types.LogTailConfig{
|
||||
Enabled: tt.logTailEnabled,
|
||||
},
|
||||
}
|
||||
mockState := &state.State{}
|
||||
m := &mapper{
|
||||
cfg: cfg,
|
||||
state: mockState,
|
||||
}
|
||||
|
||||
nodeID := types.NodeID(1)
|
||||
|
||||
builder := m.NewMapResponseBuilder(nodeID).
|
||||
WithDebugConfig()
|
||||
|
||||
require.NotNil(t, builder.resp.Debug)
|
||||
assert.Equal(t, tt.expected, builder.resp.Debug.DisableLogTail)
|
||||
assert.False(t, builder.hasErrors())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMapResponseBuilder_WithPeerChangedPatch(t *testing.T) {
|
||||
cfg := &types.Config{}
|
||||
mockState := &state.State{}
|
||||
m := &mapper{
|
||||
cfg: cfg,
|
||||
state: mockState,
|
||||
}
|
||||
|
||||
nodeID := types.NodeID(1)
|
||||
changes := []*tailcfg.PeerChange{
|
||||
{
|
||||
NodeID: 123,
|
||||
DERPRegion: 1,
|
||||
},
|
||||
{
|
||||
NodeID: 456,
|
||||
DERPRegion: 2,
|
||||
},
|
||||
}
|
||||
|
||||
builder := m.NewMapResponseBuilder(nodeID).
|
||||
WithPeerChangedPatch(changes)
|
||||
|
||||
assert.Equal(t, changes, builder.resp.PeersChangedPatch)
|
||||
assert.False(t, builder.hasErrors())
|
||||
}
|
||||
|
||||
func TestMapResponseBuilder_WithPeersRemoved(t *testing.T) {
|
||||
cfg := &types.Config{}
|
||||
mockState := &state.State{}
|
||||
m := &mapper{
|
||||
cfg: cfg,
|
||||
state: mockState,
|
||||
}
|
||||
|
||||
nodeID := types.NodeID(1)
|
||||
removedID1 := types.NodeID(123)
|
||||
removedID2 := types.NodeID(456)
|
||||
|
||||
builder := m.NewMapResponseBuilder(nodeID).
|
||||
WithPeersRemoved(removedID1, removedID2)
|
||||
|
||||
expected := []tailcfg.NodeID{
|
||||
removedID1.NodeID(),
|
||||
removedID2.NodeID(),
|
||||
}
|
||||
assert.Equal(t, expected, builder.resp.PeersRemoved)
|
||||
assert.False(t, builder.hasErrors())
|
||||
}
|
||||
|
||||
func TestMapResponseBuilder_ErrorHandling(t *testing.T) {
|
||||
cfg := &types.Config{}
|
||||
mockState := &state.State{}
|
||||
m := &mapper{
|
||||
cfg: cfg,
|
||||
state: mockState,
|
||||
}
|
||||
|
||||
nodeID := types.NodeID(1)
|
||||
|
||||
// Simulate an error in the builder
|
||||
builder := m.NewMapResponseBuilder(nodeID)
|
||||
builder.addError(assert.AnError)
|
||||
|
||||
// All subsequent calls should continue to work and accumulate errors
|
||||
result := builder.
|
||||
WithDomain().
|
||||
WithCollectServicesDisabled().
|
||||
WithDebugConfig()
|
||||
|
||||
assert.True(t, result.hasErrors())
|
||||
assert.Len(t, result.errs, 1)
|
||||
assert.Equal(t, assert.AnError, result.errs[0])
|
||||
|
||||
// Build should return the error
|
||||
data, err := result.Build("none")
|
||||
assert.Nil(t, data)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestMapResponseBuilder_ChainedCalls(t *testing.T) {
|
||||
domain := "chained.example.com"
|
||||
cfg := &types.Config{
|
||||
ServerURL: "https://chained.example.com",
|
||||
BaseDomain: domain,
|
||||
LogTail: types.LogTailConfig{
|
||||
Enabled: false,
|
||||
},
|
||||
}
|
||||
|
||||
mockState := &state.State{}
|
||||
m := &mapper{
|
||||
cfg: cfg,
|
||||
state: mockState,
|
||||
}
|
||||
|
||||
nodeID := types.NodeID(1)
|
||||
capVer := tailcfg.CapabilityVersion(99)
|
||||
|
||||
builder := m.NewMapResponseBuilder(nodeID).
|
||||
WithCapabilityVersion(capVer).
|
||||
WithDomain().
|
||||
WithCollectServicesDisabled().
|
||||
WithDebugConfig()
|
||||
|
||||
// Verify all fields are set correctly
|
||||
assert.Equal(t, capVer, builder.capVer)
|
||||
assert.Equal(t, domain, builder.resp.Domain)
|
||||
value, isSet := builder.resp.CollectServices.Get()
|
||||
assert.True(t, isSet)
|
||||
assert.False(t, value)
|
||||
assert.NotNil(t, builder.resp.Debug)
|
||||
assert.True(t, builder.resp.Debug.DisableLogTail)
|
||||
assert.False(t, builder.hasErrors())
|
||||
}
|
||||
|
||||
func TestMapResponseBuilder_MultipleWithPeersRemoved(t *testing.T) {
|
||||
cfg := &types.Config{}
|
||||
mockState := &state.State{}
|
||||
m := &mapper{
|
||||
cfg: cfg,
|
||||
state: mockState,
|
||||
}
|
||||
|
||||
nodeID := types.NodeID(1)
|
||||
removedID1 := types.NodeID(100)
|
||||
removedID2 := types.NodeID(200)
|
||||
|
||||
// Test calling WithPeersRemoved multiple times
|
||||
builder := m.NewMapResponseBuilder(nodeID).
|
||||
WithPeersRemoved(removedID1).
|
||||
WithPeersRemoved(removedID2)
|
||||
|
||||
// Second call should overwrite the first
|
||||
expected := []tailcfg.NodeID{removedID2.NodeID()}
|
||||
assert.Equal(t, expected, builder.resp.PeersRemoved)
|
||||
assert.False(t, builder.hasErrors())
|
||||
}
|
||||
|
||||
func TestMapResponseBuilder_EmptyPeerChangedPatch(t *testing.T) {
|
||||
cfg := &types.Config{}
|
||||
mockState := &state.State{}
|
||||
m := &mapper{
|
||||
cfg: cfg,
|
||||
state: mockState,
|
||||
}
|
||||
|
||||
nodeID := types.NodeID(1)
|
||||
|
||||
builder := m.NewMapResponseBuilder(nodeID).
|
||||
WithPeerChangedPatch([]*tailcfg.PeerChange{})
|
||||
|
||||
assert.Empty(t, builder.resp.PeersChangedPatch)
|
||||
assert.False(t, builder.hasErrors())
|
||||
}
|
||||
|
||||
func TestMapResponseBuilder_NilPeerChangedPatch(t *testing.T) {
|
||||
cfg := &types.Config{}
|
||||
mockState := &state.State{}
|
||||
m := &mapper{
|
||||
cfg: cfg,
|
||||
state: mockState,
|
||||
}
|
||||
|
||||
nodeID := types.NodeID(1)
|
||||
|
||||
builder := m.NewMapResponseBuilder(nodeID).
|
||||
WithPeerChangedPatch(nil)
|
||||
|
||||
assert.Nil(t, builder.resp.PeersChangedPatch)
|
||||
assert.False(t, builder.hasErrors())
|
||||
}
|
||||
|
||||
func TestMapResponseBuilder_MultipleErrors(t *testing.T) {
|
||||
cfg := &types.Config{}
|
||||
mockState := &state.State{}
|
||||
m := &mapper{
|
||||
cfg: cfg,
|
||||
state: mockState,
|
||||
}
|
||||
|
||||
nodeID := types.NodeID(1)
|
||||
|
||||
// Create a builder and add multiple errors
|
||||
builder := m.NewMapResponseBuilder(nodeID)
|
||||
builder.addError(assert.AnError)
|
||||
builder.addError(assert.AnError)
|
||||
builder.addError(nil) // This should be ignored
|
||||
|
||||
// All subsequent calls should continue to work
|
||||
result := builder.
|
||||
WithDomain().
|
||||
WithCollectServicesDisabled()
|
||||
|
||||
assert.True(t, result.hasErrors())
|
||||
assert.Len(t, result.errs, 2) // nil error should be ignored
|
||||
|
||||
// Build should return a multierr
|
||||
data, err := result.Build("none")
|
||||
assert.Nil(t, data)
|
||||
assert.Error(t, err)
|
||||
|
||||
// The error should contain information about multiple errors
|
||||
assert.Contains(t, err.Error(), "multiple errors")
|
||||
}
|
||||
@@ -1,7 +1,6 @@
|
||||
package mapper
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
@@ -10,31 +9,21 @@ import (
|
||||
"os"
|
||||
"path"
|
||||
"slices"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/notifier"
|
||||
"github.com/juanfont/headscale/hscontrol/policy"
|
||||
"github.com/juanfont/headscale/hscontrol/state"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/klauspost/compress/zstd"
|
||||
"github.com/rs/zerolog/log"
|
||||
"tailscale.com/envknob"
|
||||
"tailscale.com/smallzstd"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/dnstype"
|
||||
"tailscale.com/types/views"
|
||||
)
|
||||
|
||||
const (
|
||||
nextDNSDoHPrefix = "https://dns.nextdns.io"
|
||||
reservedResponseHeaderSize = 4
|
||||
mapperIDLength = 8
|
||||
debugMapResponsePerm = 0o755
|
||||
nextDNSDoHPrefix = "https://dns.nextdns.io"
|
||||
mapperIDLength = 8
|
||||
debugMapResponsePerm = 0o755
|
||||
)
|
||||
|
||||
var debugDumpMapResponsePath = envknob.String("HEADSCALE_DEBUG_DUMP_MAPRESPONSE_PATH")
|
||||
@@ -50,15 +39,13 @@ var debugDumpMapResponsePath = envknob.String("HEADSCALE_DEBUG_DUMP_MAPRESPONSE_
|
||||
// - Create a "minifier" that removes info not needed for the node
|
||||
// - some sort of batching, wait for 5 or 60 seconds before sending
|
||||
|
||||
type Mapper struct {
|
||||
type mapper struct {
|
||||
// Configuration
|
||||
state *state.State
|
||||
cfg *types.Config
|
||||
notif *notifier.Notifier
|
||||
state *state.State
|
||||
cfg *types.Config
|
||||
batcher Batcher
|
||||
|
||||
uid string
|
||||
created time.Time
|
||||
seq uint64
|
||||
}
|
||||
|
||||
type patch struct {
|
||||
@@ -66,41 +53,31 @@ type patch struct {
|
||||
change *tailcfg.PeerChange
|
||||
}
|
||||
|
||||
func NewMapper(
|
||||
state *state.State,
|
||||
func newMapper(
|
||||
cfg *types.Config,
|
||||
notif *notifier.Notifier,
|
||||
) *Mapper {
|
||||
uid, _ := util.GenerateRandomStringDNSSafe(mapperIDLength)
|
||||
state *state.State,
|
||||
) *mapper {
|
||||
// uid, _ := util.GenerateRandomStringDNSSafe(mapperIDLength)
|
||||
|
||||
return &Mapper{
|
||||
return &mapper{
|
||||
state: state,
|
||||
cfg: cfg,
|
||||
notif: notif,
|
||||
|
||||
uid: uid,
|
||||
created: time.Now(),
|
||||
seq: 0,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Mapper) String() string {
|
||||
return fmt.Sprintf("Mapper: { seq: %d, uid: %s, created: %s }", m.seq, m.uid, m.created)
|
||||
}
|
||||
|
||||
func generateUserProfiles(
|
||||
node types.NodeView,
|
||||
peers views.Slice[types.NodeView],
|
||||
node *types.Node,
|
||||
peers types.Nodes,
|
||||
) []tailcfg.UserProfile {
|
||||
userMap := make(map[uint]*types.User)
|
||||
ids := make([]uint, 0, peers.Len()+1)
|
||||
user := node.User()
|
||||
userMap[user.ID] = &user
|
||||
ids = append(ids, user.ID)
|
||||
for _, peer := range peers.All() {
|
||||
peerUser := peer.User()
|
||||
userMap[peerUser.ID] = &peerUser
|
||||
ids = append(ids, peerUser.ID)
|
||||
ids := make([]uint, 0, len(userMap))
|
||||
userMap[node.User.ID] = &node.User
|
||||
ids = append(ids, node.User.ID)
|
||||
for _, peer := range peers {
|
||||
userMap[peer.User.ID] = &peer.User
|
||||
ids = append(ids, peer.User.ID)
|
||||
}
|
||||
|
||||
slices.Sort(ids)
|
||||
@@ -117,7 +94,7 @@ func generateUserProfiles(
|
||||
|
||||
func generateDNSConfig(
|
||||
cfg *types.Config,
|
||||
node types.NodeView,
|
||||
node *types.Node,
|
||||
) *tailcfg.DNSConfig {
|
||||
if cfg.TailcfgDNSConfig == nil {
|
||||
return nil
|
||||
@@ -137,17 +114,16 @@ func generateDNSConfig(
|
||||
//
|
||||
// This will produce a resolver like:
|
||||
// `https://dns.nextdns.io/<nextdns-id>?device_name=node-name&device_model=linux&device_ip=100.64.0.1`
|
||||
func addNextDNSMetadata(resolvers []*dnstype.Resolver, node types.NodeView) {
|
||||
func addNextDNSMetadata(resolvers []*dnstype.Resolver, node *types.Node) {
|
||||
for _, resolver := range resolvers {
|
||||
if strings.HasPrefix(resolver.Addr, nextDNSDoHPrefix) {
|
||||
attrs := url.Values{
|
||||
"device_name": []string{node.Hostname()},
|
||||
"device_model": []string{node.Hostinfo().OS()},
|
||||
"device_name": []string{node.Hostname},
|
||||
"device_model": []string{node.Hostinfo.OS},
|
||||
}
|
||||
|
||||
nodeIPs := node.IPs()
|
||||
if len(nodeIPs) > 0 {
|
||||
attrs.Add("device_ip", nodeIPs[0].String())
|
||||
if len(node.IPs()) > 0 {
|
||||
attrs.Add("device_ip", node.IPs()[0].String())
|
||||
}
|
||||
|
||||
resolver.Addr = fmt.Sprintf("%s?%s", resolver.Addr, attrs.Encode())
|
||||
@@ -155,434 +131,151 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, node types.NodeView) {
|
||||
}
|
||||
}
|
||||
|
||||
// fullMapResponse creates a complete MapResponse for a node.
|
||||
// It is a separate function to make testing easier.
|
||||
func (m *Mapper) fullMapResponse(
|
||||
node types.NodeView,
|
||||
peers views.Slice[types.NodeView],
|
||||
// fullMapResponse returns a MapResponse for the given node.
|
||||
func (m *mapper) fullMapResponse(
|
||||
nodeID types.NodeID,
|
||||
capVer tailcfg.CapabilityVersion,
|
||||
messages ...string,
|
||||
) (*tailcfg.MapResponse, error) {
|
||||
resp, err := m.baseWithConfigMapResponse(node, capVer)
|
||||
peers, err := m.listPeers(nodeID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = appendPeerChanges(
|
||||
resp,
|
||||
true, // full change
|
||||
m.state,
|
||||
node,
|
||||
capVer,
|
||||
peers,
|
||||
m.cfg,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
return m.NewMapResponseBuilder(nodeID).
|
||||
WithCapabilityVersion(capVer).
|
||||
WithSelfNode().
|
||||
WithDERPMap().
|
||||
WithDomain().
|
||||
WithCollectServicesDisabled().
|
||||
WithDebugConfig().
|
||||
WithSSHPolicy().
|
||||
WithDNSConfig().
|
||||
WithUserProfiles(peers).
|
||||
WithPacketFilters().
|
||||
WithPeers(peers).
|
||||
Build(messages...)
|
||||
}
|
||||
|
||||
// FullMapResponse returns a MapResponse for the given node.
|
||||
func (m *Mapper) FullMapResponse(
|
||||
mapRequest tailcfg.MapRequest,
|
||||
node types.NodeView,
|
||||
messages ...string,
|
||||
) ([]byte, error) {
|
||||
peers, err := m.ListPeers(node.ID())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := m.fullMapResponse(node, peers.ViewSlice(), mapRequest.Version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return m.marshalMapResponse(mapRequest, resp, node, mapRequest.Compress, messages...)
|
||||
}
|
||||
|
||||
// ReadOnlyMapResponse returns a MapResponse for the given node.
|
||||
// Lite means that the peers have been omitted; this is intended
|
||||
// to be used to answer MapRequests with OmitPeers set to true.
|
||||
func (m *Mapper) ReadOnlyMapResponse(
|
||||
mapRequest tailcfg.MapRequest,
|
||||
node types.NodeView,
|
||||
messages ...string,
|
||||
) ([]byte, error) {
|
||||
resp, err := m.baseWithConfigMapResponse(node, mapRequest.Version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return m.marshalMapResponse(mapRequest, resp, node, mapRequest.Compress, messages...)
|
||||
}
|
||||
|
||||
func (m *Mapper) KeepAliveResponse(
|
||||
mapRequest tailcfg.MapRequest,
|
||||
node types.NodeView,
|
||||
) ([]byte, error) {
|
||||
resp := m.baseMapResponse()
|
||||
resp.KeepAlive = true
|
||||
|
||||
return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress)
|
||||
}
|
||||
|
||||
func (m *Mapper) DERPMapResponse(
|
||||
mapRequest tailcfg.MapRequest,
|
||||
node types.NodeView,
|
||||
derpMap *tailcfg.DERPMap,
|
||||
) ([]byte, error) {
|
||||
resp := m.baseMapResponse()
|
||||
resp.DERPMap = derpMap
|
||||
|
||||
return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress)
|
||||
}
|
||||
|
||||
func (m *Mapper) PeerChangedResponse(
|
||||
mapRequest tailcfg.MapRequest,
|
||||
node types.NodeView,
|
||||
changed map[types.NodeID]bool,
|
||||
patches []*tailcfg.PeerChange,
|
||||
messages ...string,
|
||||
) ([]byte, error) {
|
||||
var err error
|
||||
resp := m.baseMapResponse()
|
||||
|
||||
var removedIDs []tailcfg.NodeID
|
||||
var changedIDs []types.NodeID
|
||||
for nodeID, nodeChanged := range changed {
|
||||
if nodeChanged {
|
||||
if nodeID != node.ID() {
|
||||
changedIDs = append(changedIDs, nodeID)
|
||||
}
|
||||
} else {
|
||||
removedIDs = append(removedIDs, nodeID.NodeID())
|
||||
}
|
||||
}
|
||||
changedNodes := types.Nodes{}
|
||||
if len(changedIDs) > 0 {
|
||||
changedNodes, err = m.ListNodes(changedIDs...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
err = appendPeerChanges(
|
||||
&resp,
|
||||
false, // partial change
|
||||
m.state,
|
||||
node,
|
||||
mapRequest.Version,
|
||||
changedNodes.ViewSlice(),
|
||||
m.cfg,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp.PeersRemoved = removedIDs
|
||||
|
||||
// Sending patches as a part of a PeersChanged response
|
||||
// is technically not supposed to be done, but they are
|
||||
// applied after the PeersChanged. The patch list
|
||||
// should _only_ contain Nodes that are not in the
|
||||
// PeersChanged or PeersRemoved list and the caller
|
||||
// should filter them out.
|
||||
//
|
||||
// From tailcfg docs:
|
||||
// These are applied after Peers* above, but in practice the
|
||||
// control server should only send these on their own, without
|
||||
// the Peers* fields also set.
|
||||
if patches != nil {
|
||||
resp.PeersChangedPatch = patches
|
||||
}
|
||||
|
||||
_, matchers := m.state.Filter()
|
||||
// Add the node itself; it might have changed, and particularly
|
||||
// if there are no patches or changes, this is a self update.
|
||||
tailnode, err := tailNode(
|
||||
node, mapRequest.Version, m.state,
|
||||
func(id types.NodeID) []netip.Prefix {
|
||||
return policy.ReduceRoutes(node, m.state.GetNodePrimaryRoutes(id), matchers)
|
||||
},
|
||||
m.cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp.Node = tailnode
|
||||
|
||||
return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress, messages...)
|
||||
func (m *mapper) derpMapResponse(
|
||||
nodeID types.NodeID,
|
||||
) (*tailcfg.MapResponse, error) {
|
||||
return m.NewMapResponseBuilder(nodeID).
|
||||
WithDERPMap().
|
||||
Build()
|
||||
}
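// Illustrative sketch (not part of this change): the builder introduced in
// this refactor composes responses from small, reusable pieces, so a
// hypothetical DNS-only refresh could be expressed by chaining only the
// relevant parts. The method name dnsOnlyResponse is made up for the example.
func (m *mapper) dnsOnlyResponse(nodeID types.NodeID) (*tailcfg.MapResponse, error) {
	return m.NewMapResponseBuilder(nodeID).
		WithDNSConfig().
		Build()
}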
|
||||
|
||||
// PeerChangedPatchResponse creates a patch MapResponse with
|
||||
// incoming update from a state change.
|
||||
func (m *Mapper) PeerChangedPatchResponse(
|
||||
mapRequest tailcfg.MapRequest,
|
||||
node types.NodeView,
|
||||
func (m *mapper) peerChangedPatchResponse(
|
||||
nodeID types.NodeID,
|
||||
changed []*tailcfg.PeerChange,
|
||||
) ([]byte, error) {
|
||||
resp := m.baseMapResponse()
|
||||
resp.PeersChangedPatch = changed
|
||||
|
||||
return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress)
|
||||
}
|
||||
|
||||
func (m *Mapper) marshalMapResponse(
|
||||
mapRequest tailcfg.MapRequest,
|
||||
resp *tailcfg.MapResponse,
|
||||
node types.NodeView,
|
||||
compression string,
|
||||
messages ...string,
|
||||
) ([]byte, error) {
|
||||
atomic.AddUint64(&m.seq, 1)
|
||||
|
||||
jsonBody, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshalling map response: %w", err)
|
||||
}
|
||||
|
||||
if debugDumpMapResponsePath != "" {
|
||||
data := map[string]any{
|
||||
"Messages": messages,
|
||||
"MapRequest": mapRequest,
|
||||
"MapResponse": resp,
|
||||
}
|
||||
|
||||
responseType := "keepalive"
|
||||
|
||||
switch {
|
||||
case resp.Peers != nil && len(resp.Peers) > 0:
|
||||
responseType = "full"
|
||||
case resp.Peers == nil && resp.PeersChanged == nil && resp.PeersChangedPatch == nil && resp.DERPMap == nil && !resp.KeepAlive:
|
||||
responseType = "self"
|
||||
case resp.PeersChanged != nil && len(resp.PeersChanged) > 0:
|
||||
responseType = "changed"
|
||||
case resp.PeersChangedPatch != nil && len(resp.PeersChangedPatch) > 0:
|
||||
responseType = "patch"
|
||||
case resp.PeersRemoved != nil && len(resp.PeersRemoved) > 0:
|
||||
responseType = "removed"
|
||||
}
|
||||
|
||||
body, err := json.MarshalIndent(data, "", " ")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshalling map response: %w", err)
|
||||
}
|
||||
|
||||
perms := fs.FileMode(debugMapResponsePerm)
|
||||
mPath := path.Join(debugDumpMapResponsePath, node.Hostname())
|
||||
err = os.MkdirAll(mPath, perms)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
now := time.Now().Format("2006-01-02T15-04-05.999999999")
|
||||
|
||||
mapResponsePath := path.Join(
|
||||
mPath,
|
||||
fmt.Sprintf("%s-%s-%d-%s.json", now, m.uid, atomic.LoadUint64(&m.seq), responseType),
|
||||
)
|
||||
|
||||
log.Trace().Msgf("Writing MapResponse to %s", mapResponsePath)
|
||||
err = os.WriteFile(mapResponsePath, body, perms)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
var respBody []byte
|
||||
if compression == util.ZstdCompression {
|
||||
respBody = zstdEncode(jsonBody)
|
||||
} else {
|
||||
respBody = jsonBody
|
||||
}
|
||||
|
||||
data := make([]byte, reservedResponseHeaderSize)
|
||||
binary.LittleEndian.PutUint32(data, uint32(len(respBody)))
|
||||
data = append(data, respBody...)
|
||||
|
||||
return data, nil
|
||||
}
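// Illustrative sketch (not part of this change): how a client-side reader
// might consume the frames produced by marshalMapResponse above, assuming
// reservedResponseHeaderSize is 4. Each frame is a little-endian uint32
// length prefix followed by the (optionally zstd-compressed) JSON body.
// The helper name is hypothetical; it assumes "encoding/binary",
// "encoding/json", "io", "github.com/klauspost/compress/zstd" and
// "tailscale.com/tailcfg" are imported.
func readMapResponseFrame(r io.Reader, compressed bool) (*tailcfg.MapResponse, error) {
	// Read the 4-byte length prefix.
	var hdr [4]byte
	if _, err := io.ReadFull(r, hdr[:]); err != nil {
		return nil, err
	}

	// Read exactly the advertised number of body bytes.
	body := make([]byte, binary.LittleEndian.Uint32(hdr[:]))
	if _, err := io.ReadFull(r, body); err != nil {
		return nil, err
	}

	// Undo the optional zstd compression applied by zstdEncode.
	if compressed {
		dec, err := zstd.NewReader(nil)
		if err != nil {
			return nil, err
		}
		defer dec.Close()

		if body, err = dec.DecodeAll(body, nil); err != nil {
			return nil, err
		}
	}

	var resp tailcfg.MapResponse
	if err := json.Unmarshal(body, &resp); err != nil {
		return nil, err
	}

	return &resp, nil
}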
|
||||
|
||||
func zstdEncode(in []byte) []byte {
|
||||
encoder, ok := zstdEncoderPool.Get().(*zstd.Encoder)
|
||||
if !ok {
|
||||
panic("invalid type in sync pool")
|
||||
}
|
||||
out := encoder.EncodeAll(in, nil)
|
||||
_ = encoder.Close()
|
||||
zstdEncoderPool.Put(encoder)
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
var zstdEncoderPool = &sync.Pool{
|
||||
New: func() any {
|
||||
encoder, err := smallzstd.NewEncoder(
|
||||
nil,
|
||||
zstd.WithEncoderLevel(zstd.SpeedFastest))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return encoder
|
||||
},
|
||||
}
|
||||
|
||||
// baseMapResponse returns a tailcfg.MapResponse with
|
||||
// KeepAlive false and ControlTime set to now.
|
||||
func (m *Mapper) baseMapResponse() tailcfg.MapResponse {
|
||||
now := time.Now()
|
||||
|
||||
resp := tailcfg.MapResponse{
|
||||
KeepAlive: false,
|
||||
ControlTime: &now,
|
||||
// TODO(kradalby): Implement PingRequest?
|
||||
}
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
// baseWithConfigMapResponse returns a tailcfg.MapResponse struct
|
||||
// with the basic configuration from headscale set.
|
||||
// It is used for bigger updates, such as full and lite, not
|
||||
// incremental.
|
||||
func (m *Mapper) baseWithConfigMapResponse(
|
||||
node types.NodeView,
|
||||
capVer tailcfg.CapabilityVersion,
|
||||
) (*tailcfg.MapResponse, error) {
|
||||
resp := m.baseMapResponse()
|
||||
return m.NewMapResponseBuilder(nodeID).
|
||||
WithPeerChangedPatch(changed).
|
||||
Build()
|
||||
}
|
||||
|
||||
_, matchers := m.state.Filter()
|
||||
tailnode, err := tailNode(
|
||||
node, capVer, m.state,
|
||||
func(id types.NodeID) []netip.Prefix {
|
||||
return policy.ReduceRoutes(node, m.state.GetNodePrimaryRoutes(id), matchers)
|
||||
},
|
||||
m.cfg)
|
||||
// peerChangeResponse returns a MapResponse with changed or added nodes.
|
||||
func (m *mapper) peerChangeResponse(
|
||||
nodeID types.NodeID,
|
||||
capVer tailcfg.CapabilityVersion,
|
||||
changedNodeID types.NodeID,
|
||||
) (*tailcfg.MapResponse, error) {
|
||||
peers, err := m.listPeers(nodeID, changedNodeID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp.Node = tailnode
|
||||
|
||||
resp.DERPMap = m.state.DERPMap()
|
||||
|
||||
resp.Domain = m.cfg.Domain()
|
||||
|
||||
// Do not instruct clients to collect services we do not
|
||||
// support or do anything with them
|
||||
resp.CollectServices = "false"
|
||||
|
||||
resp.KeepAlive = false
|
||||
|
||||
resp.Debug = &tailcfg.Debug{
|
||||
DisableLogTail: !m.cfg.LogTail.Enabled,
|
||||
}
|
||||
|
||||
return &resp, nil
|
||||
return m.NewMapResponseBuilder(nodeID).
|
||||
WithCapabilityVersion(capVer).
|
||||
WithSelfNode().
|
||||
WithUserProfiles(peers).
|
||||
WithPeerChanges(peers).
|
||||
Build()
|
||||
}
|
||||
|
||||
// ListPeers returns the peers of a node, regardless of any Policy or if the node is expired.
|
||||
// peerRemovedResponse creates a MapResponse indicating that a peer has been removed.
|
||||
func (m *mapper) peerRemovedResponse(
|
||||
nodeID types.NodeID,
|
||||
removedNodeID types.NodeID,
|
||||
) (*tailcfg.MapResponse, error) {
|
||||
return m.NewMapResponseBuilder(nodeID).
|
||||
WithPeersRemoved(removedNodeID).
|
||||
Build()
|
||||
}
|
||||
|
||||
func writeDebugMapResponse(
|
||||
resp *tailcfg.MapResponse,
|
||||
nodeID types.NodeID,
|
||||
messages ...string,
|
||||
) {
|
||||
data := map[string]any{
|
||||
"Messages": messages,
|
||||
"MapResponse": resp,
|
||||
}
|
||||
|
||||
responseType := "keepalive"
|
||||
|
||||
switch {
|
||||
case len(resp.Peers) > 0:
|
||||
responseType = "full"
|
||||
case resp.Peers == nil && resp.PeersChanged == nil && resp.PeersChangedPatch == nil && resp.DERPMap == nil && !resp.KeepAlive:
|
||||
responseType = "self"
|
||||
case len(resp.PeersChanged) > 0:
|
||||
responseType = "changed"
|
||||
case len(resp.PeersChangedPatch) > 0:
|
||||
responseType = "patch"
|
||||
case len(resp.PeersRemoved) > 0:
|
||||
responseType = "removed"
|
||||
}
|
||||
|
||||
body, err := json.MarshalIndent(data, "", " ")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
perms := fs.FileMode(debugMapResponsePerm)
|
||||
mPath := path.Join(debugDumpMapResponsePath, nodeID.String())
|
||||
err = os.MkdirAll(mPath, perms)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
now := time.Now().Format("2006-01-02T15-04-05.999999999")
|
||||
|
||||
mapResponsePath := path.Join(
|
||||
mPath,
|
||||
fmt.Sprintf("%s-%s.json", now, responseType),
|
||||
)
|
||||
|
||||
log.Trace().Msgf("Writing MapResponse to %s", mapResponsePath)
|
||||
err = os.WriteFile(mapResponsePath, body, perms)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
// listPeers returns the peers of a node, regardless of any Policy or if the node is expired.
|
||||
// If no peer IDs are given, all peers are returned.
|
||||
// If at least one peer ID is given, only these peer nodes will be returned.
|
||||
func (m *Mapper) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) {
|
||||
func (m *mapper) listPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) {
|
||||
peers, err := m.state.ListPeers(nodeID, peerIDs...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// TODO(kradalby): Add back online via batcher. This was removed
|
||||
// to avoid a circular dependency between the mapper and the notifier.
|
||||
for _, peer := range peers {
|
||||
online := m.notif.IsLikelyConnected(peer.ID)
|
||||
online := m.batcher.IsConnected(peer.ID)
|
||||
peer.IsOnline = &online
|
||||
}
|
||||
|
||||
return peers, nil
|
||||
}
|
||||
|
||||
// ListNodes queries the database for either all nodes if no parameters are given
|
||||
// or for the given nodes if at least one node ID is given as a parameter.
|
||||
func (m *Mapper) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) {
|
||||
nodes, err := m.state.ListNodes(nodeIDs...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, node := range nodes {
|
||||
online := m.notif.IsLikelyConnected(node.ID)
|
||||
node.IsOnline = &online
|
||||
}
|
||||
|
||||
return nodes, nil
|
||||
}
|
||||
|
||||
// routeFilterFunc is a function that takes a node ID and returns a list of
|
||||
// netip.Prefixes that are allowed for that node. It is used to filter routes
|
||||
// from the primary route manager to the node.
|
||||
type routeFilterFunc func(id types.NodeID) []netip.Prefix
|
||||
|
||||
// appendPeerChanges mutates a tailcfg.MapResponse with all the
|
||||
// necessary changes when peers have changed.
|
||||
func appendPeerChanges(
|
||||
resp *tailcfg.MapResponse,
|
||||
|
||||
fullChange bool,
|
||||
state *state.State,
|
||||
node types.NodeView,
|
||||
capVer tailcfg.CapabilityVersion,
|
||||
changed views.Slice[types.NodeView],
|
||||
cfg *types.Config,
|
||||
) error {
|
||||
filter, matchers := state.Filter()
|
||||
|
||||
sshPolicy, err := state.SSHPolicy(node)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If there are filter rules present, see if there are any nodes that cannot
|
||||
// access each other at all and remove them from the peers.
|
||||
var reducedChanged views.Slice[types.NodeView]
|
||||
if len(filter) > 0 {
|
||||
reducedChanged = policy.ReduceNodes(node, changed, matchers)
|
||||
} else {
|
||||
reducedChanged = changed
|
||||
}
|
||||
|
||||
profiles := generateUserProfiles(node, reducedChanged)
|
||||
|
||||
dnsConfig := generateDNSConfig(cfg, node)
|
||||
|
||||
tailPeers, err := tailNodes(
|
||||
reducedChanged, capVer, state,
|
||||
func(id types.NodeID) []netip.Prefix {
|
||||
return policy.ReduceRoutes(node, state.GetNodePrimaryRoutes(id), matchers)
|
||||
},
|
||||
cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Peers is always returned sorted by Node.ID.
|
||||
sort.SliceStable(tailPeers, func(x, y int) bool {
|
||||
return tailPeers[x].ID < tailPeers[y].ID
|
||||
})
|
||||
|
||||
if fullChange {
|
||||
resp.Peers = tailPeers
|
||||
} else {
|
||||
resp.PeersChanged = tailPeers
|
||||
}
|
||||
resp.DNSConfig = dnsConfig
|
||||
resp.UserProfiles = profiles
|
||||
resp.SSHPolicy = sshPolicy
|
||||
|
||||
// CapVer 81: 2023-11-17: MapResponse.PacketFilters (incremental packet filter updates)
|
||||
// Currently, we do not send incremental packet filters; however, using the
|
||||
// new PacketFilters field and "base" allows us to send a full update when we
|
||||
// have to send an empty list, avoiding the hack in the else block.
|
||||
resp.PacketFilters = map[string][]tailcfg.FilterRule{
|
||||
"base": policy.ReduceFilterRules(node, filter),
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package mapper
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
@@ -70,7 +71,7 @@ func TestDNSConfigMapResponse(t *testing.T) {
|
||||
&types.Config{
|
||||
TailcfgDNSConfig: &dnsConfigOrig,
|
||||
},
|
||||
nodeInShared1.View(),
|
||||
nodeInShared1,
|
||||
)
|
||||
|
||||
if diff := cmp.Diff(tt.want, got, cmpopts.EquateEmpty()); diff != "" {
|
||||
@@ -126,11 +127,8 @@ func (m *mockState) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (typ
|
||||
// Filter peers by the provided IDs
|
||||
var filtered types.Nodes
|
||||
for _, peer := range m.peers {
|
||||
for _, id := range peerIDs {
|
||||
if peer.ID == id {
|
||||
filtered = append(filtered, peer)
|
||||
break
|
||||
}
|
||||
if slices.Contains(peerIDs, peer.ID) {
|
||||
filtered = append(filtered, peer)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -152,11 +150,8 @@ func (m *mockState) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) {
|
||||
// Filter nodes by the provided IDs
|
||||
var filtered types.Nodes
|
||||
for _, node := range m.nodes {
|
||||
for _, id := range nodeIDs {
|
||||
if node.ID == id {
|
||||
filtered = append(filtered, node)
|
||||
break
|
||||
}
|
||||
if slices.Contains(nodeIDs, node.ID) {
|
||||
filtered = append(filtered, node)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
47
hscontrol/mapper/utils.go
Normal file
47
hscontrol/mapper/utils.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package mapper
|
||||
|
||||
import "tailscale.com/tailcfg"
|
||||
|
||||
// mergePatch takes the current patch and a newer patch
|
||||
// and overrides any field that has changed.
|
||||
func mergePatch(currPatch, newPatch *tailcfg.PeerChange) {
|
||||
if newPatch.DERPRegion != 0 {
|
||||
currPatch.DERPRegion = newPatch.DERPRegion
|
||||
}
|
||||
|
||||
if newPatch.Cap != 0 {
|
||||
currPatch.Cap = newPatch.Cap
|
||||
}
|
||||
|
||||
if newPatch.CapMap != nil {
|
||||
currPatch.CapMap = newPatch.CapMap
|
||||
}
|
||||
|
||||
if newPatch.Endpoints != nil {
|
||||
currPatch.Endpoints = newPatch.Endpoints
|
||||
}
|
||||
|
||||
if newPatch.Key != nil {
|
||||
currPatch.Key = newPatch.Key
|
||||
}
|
||||
|
||||
if newPatch.KeySignature != nil {
|
||||
currPatch.KeySignature = newPatch.KeySignature
|
||||
}
|
||||
|
||||
if newPatch.DiscoKey != nil {
|
||||
currPatch.DiscoKey = newPatch.DiscoKey
|
||||
}
|
||||
|
||||
if newPatch.Online != nil {
|
||||
currPatch.Online = newPatch.Online
|
||||
}
|
||||
|
||||
if newPatch.LastSeen != nil {
|
||||
currPatch.LastSeen = newPatch.LastSeen
|
||||
}
|
||||
|
||||
if newPatch.KeyExpiry != nil {
|
||||
currPatch.KeyExpiry = newPatch.KeyExpiry
|
||||
}
|
||||
}
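// Illustrative sketch (not part of this change): coalescing two successive
// patches for the same peer with mergePatch. The values are made up for the
// example; non-zero fields in the newer patch win, untouched fields keep
// their previous value.
func exampleMergePatch() tailcfg.PeerChange {
	online := true
	curr := tailcfg.PeerChange{NodeID: 2, DERPRegion: 5, Online: &online}
	next := tailcfg.PeerChange{NodeID: 2, DERPRegion: 6}

	// DERPRegion is overridden to 6; Online stays true because the newer
	// patch does not carry a value for it.
	mergePatch(&curr, &next)

	return curr
}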
|
||||
@@ -221,7 +221,7 @@ func (ns *noiseServer) NoisePollNetMapHandler(
|
||||
|
||||
ns.nodeKey = nv.NodeKey()
|
||||
|
||||
sess := ns.headscale.newMapSession(req.Context(), mapRequest, writer, nv)
|
||||
sess := ns.headscale.newMapSession(req.Context(), mapRequest, writer, nv.AsStruct())
|
||||
sess.tracef("a node sending a MapRequest with Noise protocol")
|
||||
if !sess.isStreaming() {
|
||||
sess.serve()
|
||||
@@ -279,28 +279,33 @@ func (ns *noiseServer) NoiseRegistrationHandler(
|
||||
return
|
||||
}
|
||||
|
||||
respBody, err := json.Marshal(registerResponse)
|
||||
if err != nil {
|
||||
httpError(writer, err)
|
||||
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusOK)
|
||||
|
||||
if err := json.NewEncoder(writer).Encode(registerResponse); err != nil {
|
||||
log.Error().Err(err).Msg("NoiseRegistrationHandler: failed to encode RegisterResponse")
|
||||
return
|
||||
}
|
||||
|
||||
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusOK)
|
||||
writer.Write(respBody)
|
||||
// Ensure response is flushed to client
|
||||
if flusher, ok := writer.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
// getAndValidateNode retrieves the node from the database using the NodeKey
|
||||
// and validates that it matches the MachineKey from the Noise session.
|
||||
func (ns *noiseServer) getAndValidateNode(mapRequest tailcfg.MapRequest) (types.NodeView, error) {
|
||||
nv, err := ns.headscale.state.GetNodeViewByNodeKey(mapRequest.NodeKey)
|
||||
node, err := ns.headscale.state.GetNodeByNodeKey(mapRequest.NodeKey)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return types.NodeView{}, NewHTTPError(http.StatusNotFound, "node not found", nil)
|
||||
}
|
||||
return types.NodeView{}, err
|
||||
return types.NodeView{}, NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("lookup node: %s", err), nil)
|
||||
}
|
||||
|
||||
nv := node.View()
|
||||
|
||||
// Validate that the MachineKey in the Noise session matches the one associated with the NodeKey.
|
||||
if ns.machineKey != nv.MachineKey() {
|
||||
return types.NodeView{}, NewHTTPError(http.StatusNotFound, "node key in request does not match the one associated with this machine key", nil)
|
||||
|
||||
@@ -1,68 +0,0 @@
|
||||
package notifier
|
||||
|
||||
import (
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promauto"
|
||||
"tailscale.com/envknob"
|
||||
)
|
||||
|
||||
const prometheusNamespace = "headscale"
|
||||
|
||||
var debugHighCardinalityMetrics = envknob.Bool("HEADSCALE_DEBUG_HIGH_CARDINALITY_METRICS")
|
||||
|
||||
var notifierUpdateSent *prometheus.CounterVec
|
||||
|
||||
func init() {
|
||||
if debugHighCardinalityMetrics {
|
||||
notifierUpdateSent = promauto.NewCounterVec(prometheus.CounterOpts{
|
||||
Namespace: prometheusNamespace,
|
||||
Name: "notifier_update_sent_total",
|
||||
Help: "total count of update sent on nodes channel",
|
||||
}, []string{"status", "type", "trigger", "id"})
|
||||
} else {
|
||||
notifierUpdateSent = promauto.NewCounterVec(prometheus.CounterOpts{
|
||||
Namespace: prometheusNamespace,
|
||||
Name: "notifier_update_sent_total",
|
||||
Help: "total count of update sent on nodes channel",
|
||||
}, []string{"status", "type", "trigger"})
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
notifierWaitersForLock = promauto.NewGaugeVec(prometheus.GaugeOpts{
|
||||
Namespace: prometheusNamespace,
|
||||
Name: "notifier_waiters_for_lock",
|
||||
Help: "gauge of waiters for the notifier lock",
|
||||
}, []string{"type", "action"})
|
||||
notifierWaitForLock = promauto.NewHistogramVec(prometheus.HistogramOpts{
|
||||
Namespace: prometheusNamespace,
|
||||
Name: "notifier_wait_for_lock_seconds",
|
||||
Help: "histogram of time spent waiting for the notifier lock",
|
||||
Buckets: []float64{0.001, 0.01, 0.1, 0.3, 0.5, 1, 3, 5, 10},
|
||||
}, []string{"action"})
|
||||
notifierUpdateReceived = promauto.NewCounterVec(prometheus.CounterOpts{
|
||||
Namespace: prometheusNamespace,
|
||||
Name: "notifier_update_received_total",
|
||||
Help: "total count of updates received by notifier",
|
||||
}, []string{"type", "trigger"})
|
||||
notifierNodeUpdateChans = promauto.NewGauge(prometheus.GaugeOpts{
|
||||
Namespace: prometheusNamespace,
|
||||
Name: "notifier_open_channels_total",
|
||||
Help: "total count open channels in notifier",
|
||||
})
|
||||
notifierBatcherWaitersForLock = promauto.NewGaugeVec(prometheus.GaugeOpts{
|
||||
Namespace: prometheusNamespace,
|
||||
Name: "notifier_batcher_waiters_for_lock",
|
||||
Help: "gauge of waiters for the notifier batcher lock",
|
||||
}, []string{"type", "action"})
|
||||
notifierBatcherChanges = promauto.NewGaugeVec(prometheus.GaugeOpts{
|
||||
Namespace: prometheusNamespace,
|
||||
Name: "notifier_batcher_changes_pending",
|
||||
Help: "gauge of full changes pending in the notifier batcher",
|
||||
}, []string{})
|
||||
notifierBatcherPatches = promauto.NewGaugeVec(prometheus.GaugeOpts{
|
||||
Namespace: prometheusNamespace,
|
||||
Name: "notifier_batcher_patches_pending",
|
||||
Help: "gauge of patches pending in the notifier batcher",
|
||||
}, []string{})
|
||||
)
|
||||
@@ -1,488 +0,0 @@
|
||||
package notifier
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/puzpuzpuz/xsync/v4"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/sasha-s/go-deadlock"
|
||||
"tailscale.com/envknob"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/util/set"
|
||||
)
|
||||
|
||||
var (
|
||||
debugDeadlock = envknob.Bool("HEADSCALE_DEBUG_DEADLOCK")
|
||||
debugDeadlockTimeout = envknob.RegisterDuration("HEADSCALE_DEBUG_DEADLOCK_TIMEOUT")
|
||||
)
|
||||
|
||||
func init() {
|
||||
deadlock.Opts.Disable = !debugDeadlock
|
||||
if debugDeadlock {
|
||||
deadlock.Opts.DeadlockTimeout = debugDeadlockTimeout()
|
||||
deadlock.Opts.PrintAllCurrentGoroutines = true
|
||||
}
|
||||
}
|
||||
|
||||
type Notifier struct {
|
||||
l deadlock.Mutex
|
||||
nodes map[types.NodeID]chan<- types.StateUpdate
|
||||
connected *xsync.MapOf[types.NodeID, bool]
|
||||
b *batcher
|
||||
cfg *types.Config
|
||||
closed bool
|
||||
}
|
||||
|
||||
func NewNotifier(cfg *types.Config) *Notifier {
|
||||
n := &Notifier{
|
||||
nodes: make(map[types.NodeID]chan<- types.StateUpdate),
|
||||
connected: xsync.NewMapOf[types.NodeID, bool](),
|
||||
cfg: cfg,
|
||||
closed: false,
|
||||
}
|
||||
b := newBatcher(cfg.Tuning.BatchChangeDelay, n)
|
||||
n.b = b
|
||||
|
||||
go b.doWork()
|
||||
|
||||
return n
|
||||
}
|
||||
|
||||
// Close stops the batcher and closes all channels.
|
||||
func (n *Notifier) Close() {
|
||||
notifierWaitersForLock.WithLabelValues("lock", "close").Inc()
|
||||
n.l.Lock()
|
||||
defer n.l.Unlock()
|
||||
notifierWaitersForLock.WithLabelValues("lock", "close").Dec()
|
||||
|
||||
n.closed = true
|
||||
n.b.close()
|
||||
|
||||
// Close channels safely using the helper method
|
||||
for nodeID, c := range n.nodes {
|
||||
n.safeCloseChannel(nodeID, c)
|
||||
}
|
||||
|
||||
// Clear node map after closing channels
|
||||
n.nodes = make(map[types.NodeID]chan<- types.StateUpdate)
|
||||
}
|
||||
|
||||
// safeCloseChannel closes a channel, recovering from the panic if it is already closed.
|
||||
func (n *Notifier) safeCloseChannel(nodeID types.NodeID, c chan<- types.StateUpdate) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Error().
|
||||
Uint64("node.id", nodeID.Uint64()).
|
||||
Any("recover", r).
|
||||
Msg("recovered from panic when closing channel in Close()")
|
||||
}
|
||||
}()
|
||||
close(c)
|
||||
}
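// Illustrative sketch (not part of this change): the same close-with-recover
// idiom as safeCloseChannel above, written as a small generic helper that
// reports whether the close succeeded instead of logging. Name and shape are
// hypothetical.
func safeClose[T any](ch chan<- T) (ok bool) {
	defer func() {
		// Closing an already-closed channel panics; swallow it and
		// report failure instead.
		if recover() != nil {
			ok = false
		}
	}()

	close(ch)

	return true
}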
|
||||
|
||||
func (n *Notifier) tracef(nID types.NodeID, msg string, args ...any) {
|
||||
log.Trace().
|
||||
Uint64("node.id", nID.Uint64()).
|
||||
Int("open_chans", len(n.nodes)).Msgf(msg, args...)
|
||||
}
|
||||
|
||||
func (n *Notifier) AddNode(nodeID types.NodeID, c chan<- types.StateUpdate) {
|
||||
start := time.Now()
|
||||
notifierWaitersForLock.WithLabelValues("lock", "add").Inc()
|
||||
n.l.Lock()
|
||||
defer n.l.Unlock()
|
||||
notifierWaitersForLock.WithLabelValues("lock", "add").Dec()
|
||||
notifierWaitForLock.WithLabelValues("add").Observe(time.Since(start).Seconds())
|
||||
|
||||
if n.closed {
|
||||
return
|
||||
}
|
||||
|
||||
// If a channel exists, it means the node has opened a new
|
||||
// connection. Close the old channel and replace it.
|
||||
if curr, ok := n.nodes[nodeID]; ok {
|
||||
n.tracef(nodeID, "channel present, closing and replacing")
|
||||
// Use the safeCloseChannel helper in a goroutine to avoid deadlocks
|
||||
// if/when someone is waiting to send on this channel
|
||||
go func(ch chan<- types.StateUpdate) {
|
||||
n.safeCloseChannel(nodeID, ch)
|
||||
}(curr)
|
||||
}
|
||||
|
||||
n.nodes[nodeID] = c
|
||||
n.connected.Store(nodeID, true)
|
||||
|
||||
n.tracef(nodeID, "added new channel")
|
||||
notifierNodeUpdateChans.Inc()
|
||||
}
|
||||
|
||||
// RemoveNode removes a node and a given channel from the notifier.
|
||||
// It checks that the channel is the same as the one currently registered
|
||||
// and ignores the removal if it is not.
|
||||
// RemoveNode reports if the node/chan was removed.
|
||||
func (n *Notifier) RemoveNode(nodeID types.NodeID, c chan<- types.StateUpdate) bool {
|
||||
start := time.Now()
|
||||
notifierWaitersForLock.WithLabelValues("lock", "remove").Inc()
|
||||
n.l.Lock()
|
||||
defer n.l.Unlock()
|
||||
notifierWaitersForLock.WithLabelValues("lock", "remove").Dec()
|
||||
notifierWaitForLock.WithLabelValues("remove").Observe(time.Since(start).Seconds())
|
||||
|
||||
if n.closed {
|
||||
return true
|
||||
}
|
||||
|
||||
if len(n.nodes) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
// If the channel exists but does not belong
|
||||
// to the caller, ignore.
|
||||
if curr, ok := n.nodes[nodeID]; ok {
|
||||
if curr != c {
|
||||
n.tracef(nodeID, "channel has been replaced, not removing")
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
delete(n.nodes, nodeID)
|
||||
n.connected.Store(nodeID, false)
|
||||
|
||||
n.tracef(nodeID, "removed channel")
|
||||
notifierNodeUpdateChans.Dec()
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// IsConnected reports if a node is connected to headscale and has a
|
||||
// poll session open.
|
||||
func (n *Notifier) IsConnected(nodeID types.NodeID) bool {
|
||||
notifierWaitersForLock.WithLabelValues("lock", "conncheck").Inc()
|
||||
n.l.Lock()
|
||||
defer n.l.Unlock()
|
||||
notifierWaitersForLock.WithLabelValues("lock", "conncheck").Dec()
|
||||
|
||||
if val, ok := n.connected.Load(nodeID); ok {
|
||||
return val
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// IsLikelyConnected reports if a node is connected to headscale and has a
|
||||
// poll session open, but doesn't lock, so might be wrong.
|
||||
func (n *Notifier) IsLikelyConnected(nodeID types.NodeID) bool {
|
||||
if val, ok := n.connected.Load(nodeID); ok {
|
||||
return val
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// LikelyConnectedMap returns a thread-safe map of connected nodes.
|
||||
func (n *Notifier) LikelyConnectedMap() *xsync.MapOf[types.NodeID, bool] {
|
||||
return n.connected
|
||||
}
|
||||
|
||||
func (n *Notifier) NotifyAll(ctx context.Context, update types.StateUpdate) {
|
||||
n.NotifyWithIgnore(ctx, update)
|
||||
}
|
||||
|
||||
func (n *Notifier) NotifyWithIgnore(
|
||||
ctx context.Context,
|
||||
update types.StateUpdate,
|
||||
ignoreNodeIDs ...types.NodeID,
|
||||
) {
|
||||
if n.closed {
|
||||
return
|
||||
}
|
||||
|
||||
notifierUpdateReceived.WithLabelValues(update.Type.String(), types.NotifyOriginKey.Value(ctx)).Inc()
|
||||
n.b.addOrPassthrough(update)
|
||||
}
|
||||
|
||||
func (n *Notifier) NotifyByNodeID(
|
||||
ctx context.Context,
|
||||
update types.StateUpdate,
|
||||
nodeID types.NodeID,
|
||||
) {
|
||||
start := time.Now()
|
||||
notifierWaitersForLock.WithLabelValues("lock", "notify").Inc()
|
||||
n.l.Lock()
|
||||
defer n.l.Unlock()
|
||||
notifierWaitersForLock.WithLabelValues("lock", "notify").Dec()
|
||||
notifierWaitForLock.WithLabelValues("notify").Observe(time.Since(start).Seconds())
|
||||
|
||||
if n.closed {
|
||||
return
|
||||
}
|
||||
|
||||
if c, ok := n.nodes[nodeID]; ok {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Error().
|
||||
Err(ctx.Err()).
|
||||
Uint64("node.id", nodeID.Uint64()).
|
||||
Any("origin", types.NotifyOriginKey.Value(ctx)).
|
||||
Any("origin-hostname", types.NotifyHostnameKey.Value(ctx)).
|
||||
Msgf("update not sent, context cancelled")
|
||||
if debugHighCardinalityMetrics {
|
||||
notifierUpdateSent.WithLabelValues("cancelled", update.Type.String(), types.NotifyOriginKey.Value(ctx), nodeID.String()).Inc()
|
||||
} else {
|
||||
notifierUpdateSent.WithLabelValues("cancelled", update.Type.String(), types.NotifyOriginKey.Value(ctx)).Inc()
|
||||
}
|
||||
|
||||
return
|
||||
case c <- update:
|
||||
n.tracef(nodeID, "update successfully sent on chan, origin: %s, origin-hostname: %s", ctx.Value("origin"), ctx.Value("hostname"))
|
||||
if debugHighCardinalityMetrics {
|
||||
notifierUpdateSent.WithLabelValues("ok", update.Type.String(), types.NotifyOriginKey.Value(ctx), nodeID.String()).Inc()
|
||||
} else {
|
||||
notifierUpdateSent.WithLabelValues("ok", update.Type.String(), types.NotifyOriginKey.Value(ctx)).Inc()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (n *Notifier) sendAll(update types.StateUpdate) {
|
||||
start := time.Now()
|
||||
notifierWaitersForLock.WithLabelValues("lock", "send-all").Inc()
|
||||
n.l.Lock()
|
||||
defer n.l.Unlock()
|
||||
notifierWaitersForLock.WithLabelValues("lock", "send-all").Dec()
|
||||
notifierWaitForLock.WithLabelValues("send-all").Observe(time.Since(start).Seconds())
|
||||
|
||||
if n.closed {
|
||||
return
|
||||
}
|
||||
|
||||
for id, c := range n.nodes {
|
||||
// Whenever an update is sent to all nodes, there is a chance that the node
|
||||
// has disconnected and the goroutine that was supposed to consume the update
|
||||
// has shut down the channel and is waiting for the lock held here in RemoveNode.
|
||||
// This means that there is potential for a deadlock which would stop all updates
|
||||
// going out to clients. This timeout prevents that from happening by moving on to the
|
||||
// next node if the context is cancelled. After sendAll releases the lock, the add/remove
|
||||
// call will succeed and the update will go to the correct nodes on the next call.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), n.cfg.Tuning.NotifierSendTimeout)
|
||||
defer cancel()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Error().
|
||||
Err(ctx.Err()).
|
||||
Uint64("node.id", id.Uint64()).
|
||||
Msgf("update not sent, context cancelled")
|
||||
if debugHighCardinalityMetrics {
|
||||
notifierUpdateSent.WithLabelValues("cancelled", update.Type.String(), "send-all", id.String()).Inc()
|
||||
} else {
|
||||
notifierUpdateSent.WithLabelValues("cancelled", update.Type.String(), "send-all").Inc()
|
||||
}
|
||||
|
||||
return
|
||||
case c <- update:
|
||||
if debugHighCardinalityMetrics {
|
||||
notifierUpdateSent.WithLabelValues("ok", update.Type.String(), "send-all", id.String()).Inc()
|
||||
} else {
|
||||
notifierUpdateSent.WithLabelValues("ok", update.Type.String(), "send-all").Inc()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (n *Notifier) String() string {
|
||||
notifierWaitersForLock.WithLabelValues("lock", "string").Inc()
|
||||
n.l.Lock()
|
||||
defer n.l.Unlock()
|
||||
notifierWaitersForLock.WithLabelValues("lock", "string").Dec()
|
||||
|
||||
var b strings.Builder
|
||||
fmt.Fprintf(&b, "chans (%d):\n", len(n.nodes))
|
||||
|
||||
var keys []types.NodeID
|
||||
n.connected.Range(func(key types.NodeID, value bool) bool {
|
||||
keys = append(keys, key)
|
||||
return true
|
||||
})
|
||||
sort.Slice(keys, func(i, j int) bool {
|
||||
return keys[i] < keys[j]
|
||||
})
|
||||
|
||||
for _, key := range keys {
|
||||
fmt.Fprintf(&b, "\t%d: %p\n", key, n.nodes[key])
|
||||
}
|
||||
|
||||
b.WriteString("\n")
|
||||
fmt.Fprintf(&b, "connected (%d):\n", len(n.nodes))
|
||||
|
||||
for _, key := range keys {
|
||||
val, _ := n.connected.Load(key)
|
||||
fmt.Fprintf(&b, "\t%d: %t\n", key, val)
|
||||
}
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
type batcher struct {
|
||||
tick *time.Ticker
|
||||
|
||||
mu sync.Mutex
|
||||
|
||||
cancelCh chan struct{}
|
||||
|
||||
changedNodeIDs set.Slice[types.NodeID]
|
||||
nodesChanged bool
|
||||
patches map[types.NodeID]tailcfg.PeerChange
|
||||
patchesChanged bool
|
||||
|
||||
n *Notifier
|
||||
}
|
||||
|
||||
func newBatcher(batchTime time.Duration, n *Notifier) *batcher {
|
||||
return &batcher{
|
||||
tick: time.NewTicker(batchTime),
|
||||
cancelCh: make(chan struct{}),
|
||||
patches: make(map[types.NodeID]tailcfg.PeerChange),
|
||||
n: n,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *batcher) close() {
|
||||
b.cancelCh <- struct{}{}
|
||||
}
|
||||
|
||||
// addOrPassthrough adds the update to the batcher; if it is not a
|
||||
// type that is currently batched, it will be sent immediately.
|
||||
func (b *batcher) addOrPassthrough(update types.StateUpdate) {
|
||||
notifierBatcherWaitersForLock.WithLabelValues("lock", "add").Inc()
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
notifierBatcherWaitersForLock.WithLabelValues("lock", "add").Dec()
|
||||
|
||||
switch update.Type {
|
||||
case types.StatePeerChanged:
|
||||
b.changedNodeIDs.Add(update.ChangeNodes...)
|
||||
b.nodesChanged = true
|
||||
notifierBatcherChanges.WithLabelValues().Set(float64(b.changedNodeIDs.Len()))
|
||||
|
||||
case types.StatePeerChangedPatch:
|
||||
for _, newPatch := range update.ChangePatches {
|
||||
if curr, ok := b.patches[types.NodeID(newPatch.NodeID)]; ok {
|
||||
overwritePatch(&curr, newPatch)
|
||||
b.patches[types.NodeID(newPatch.NodeID)] = curr
|
||||
} else {
|
||||
b.patches[types.NodeID(newPatch.NodeID)] = *newPatch
|
||||
}
|
||||
}
|
||||
b.patchesChanged = true
|
||||
notifierBatcherPatches.WithLabelValues().Set(float64(len(b.patches)))
|
||||
|
||||
default:
|
||||
b.n.sendAll(update)
|
||||
}
|
||||
}
|
||||
|
||||
// flush sends all the accumulated changes and patches to all
|
||||
// nodes in the notifier.
|
||||
func (b *batcher) flush() {
|
||||
notifierBatcherWaitersForLock.WithLabelValues("lock", "flush").Inc()
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
notifierBatcherWaitersForLock.WithLabelValues("lock", "flush").Dec()
|
||||
|
||||
if b.nodesChanged || b.patchesChanged {
|
||||
var patches []*tailcfg.PeerChange
|
||||
// If a node is already included in a full changed-node update,
|
||||
// its pending patch can be dropped.
|
||||
for nodeID, patch := range b.patches {
|
||||
if b.changedNodeIDs.Contains(nodeID) {
|
||||
delete(b.patches, nodeID)
|
||||
} else {
|
||||
patches = append(patches, &patch)
|
||||
}
|
||||
}
|
||||
|
||||
changedNodes := b.changedNodeIDs.Slice().AsSlice()
|
||||
sort.Slice(changedNodes, func(i, j int) bool {
|
||||
return changedNodes[i] < changedNodes[j]
|
||||
})
|
||||
|
||||
if b.changedNodeIDs.Slice().Len() > 0 {
|
||||
update := types.UpdatePeerChanged(changedNodes...)
|
||||
|
||||
b.n.sendAll(update)
|
||||
}
|
||||
|
||||
if len(patches) > 0 {
|
||||
patchUpdate := types.UpdatePeerPatch(patches...)
|
||||
|
||||
b.n.sendAll(patchUpdate)
|
||||
}
|
||||
|
||||
b.changedNodeIDs = set.Slice[types.NodeID]{}
|
||||
notifierBatcherChanges.WithLabelValues().Set(0)
|
||||
b.nodesChanged = false
|
||||
b.patches = make(map[types.NodeID]tailcfg.PeerChange, len(b.patches))
|
||||
notifierBatcherPatches.WithLabelValues().Set(0)
|
||||
b.patchesChanged = false
|
||||
}
|
||||
}
|
||||
|
||||
func (b *batcher) doWork() {
|
||||
for {
|
||||
select {
|
||||
case <-b.cancelCh:
|
||||
return
|
||||
case <-b.tick.C:
|
||||
b.flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// overwritePatch takes the current patch and a newer patch
|
||||
// and overrides any field that has changed.
|
||||
func overwritePatch(currPatch, newPatch *tailcfg.PeerChange) {
|
||||
if newPatch.DERPRegion != 0 {
|
||||
currPatch.DERPRegion = newPatch.DERPRegion
|
||||
}
|
||||
|
||||
if newPatch.Cap != 0 {
|
||||
currPatch.Cap = newPatch.Cap
|
||||
}
|
||||
|
||||
if newPatch.CapMap != nil {
|
||||
currPatch.CapMap = newPatch.CapMap
|
||||
}
|
||||
|
||||
if newPatch.Endpoints != nil {
|
||||
currPatch.Endpoints = newPatch.Endpoints
|
||||
}
|
||||
|
||||
if newPatch.Key != nil {
|
||||
currPatch.Key = newPatch.Key
|
||||
}
|
||||
|
||||
if newPatch.KeySignature != nil {
|
||||
currPatch.KeySignature = newPatch.KeySignature
|
||||
}
|
||||
|
||||
if newPatch.DiscoKey != nil {
|
||||
currPatch.DiscoKey = newPatch.DiscoKey
|
||||
}
|
||||
|
||||
if newPatch.Online != nil {
|
||||
currPatch.Online = newPatch.Online
|
||||
}
|
||||
|
||||
if newPatch.LastSeen != nil {
|
||||
currPatch.LastSeen = newPatch.LastSeen
|
||||
}
|
||||
|
||||
if newPatch.KeyExpiry != nil {
|
||||
currPatch.KeyExpiry = newPatch.KeyExpiry
|
||||
}
|
||||
}
|
||||
@@ -1,342 +0,0 @@
|
||||
package notifier
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"sort"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
func TestBatcher(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
updates []types.StateUpdate
|
||||
want []types.StateUpdate
|
||||
}{
|
||||
{
|
||||
name: "full-passthrough",
|
||||
updates: []types.StateUpdate{
|
||||
{
|
||||
Type: types.StateFullUpdate,
|
||||
},
|
||||
},
|
||||
want: []types.StateUpdate{
|
||||
{
|
||||
Type: types.StateFullUpdate,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "derp-passthrough",
|
||||
updates: []types.StateUpdate{
|
||||
{
|
||||
Type: types.StateDERPUpdated,
|
||||
},
|
||||
},
|
||||
want: []types.StateUpdate{
|
||||
{
|
||||
Type: types.StateDERPUpdated,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "single-node-update",
|
||||
updates: []types.StateUpdate{
|
||||
{
|
||||
Type: types.StatePeerChanged,
|
||||
ChangeNodes: []types.NodeID{
|
||||
2,
|
||||
},
|
||||
},
|
||||
},
|
||||
want: []types.StateUpdate{
|
||||
{
|
||||
Type: types.StatePeerChanged,
|
||||
ChangeNodes: []types.NodeID{
|
||||
2,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "merge-node-update",
|
||||
updates: []types.StateUpdate{
|
||||
{
|
||||
Type: types.StatePeerChanged,
|
||||
ChangeNodes: []types.NodeID{
|
||||
2, 4,
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: types.StatePeerChanged,
|
||||
ChangeNodes: []types.NodeID{
|
||||
2, 3,
|
||||
},
|
||||
},
|
||||
},
|
||||
want: []types.StateUpdate{
|
||||
{
|
||||
Type: types.StatePeerChanged,
|
||||
ChangeNodes: []types.NodeID{
|
||||
2, 3, 4,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "single-patch-update",
|
||||
updates: []types.StateUpdate{
|
||||
{
|
||||
Type: types.StatePeerChangedPatch,
|
||||
ChangePatches: []*tailcfg.PeerChange{
|
||||
{
|
||||
NodeID: 2,
|
||||
DERPRegion: 5,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: []types.StateUpdate{
|
||||
{
|
||||
Type: types.StatePeerChangedPatch,
|
||||
ChangePatches: []*tailcfg.PeerChange{
|
||||
{
|
||||
NodeID: 2,
|
||||
DERPRegion: 5,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "merge-patch-to-same-node-update",
|
||||
updates: []types.StateUpdate{
|
||||
{
|
||||
Type: types.StatePeerChangedPatch,
|
||||
ChangePatches: []*tailcfg.PeerChange{
|
||||
{
|
||||
NodeID: 2,
|
||||
DERPRegion: 5,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: types.StatePeerChangedPatch,
|
||||
ChangePatches: []*tailcfg.PeerChange{
|
||||
{
|
||||
NodeID: 2,
|
||||
DERPRegion: 6,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: []types.StateUpdate{
|
||||
{
|
||||
Type: types.StatePeerChangedPatch,
|
||||
ChangePatches: []*tailcfg.PeerChange{
|
||||
{
|
||||
NodeID: 2,
|
||||
DERPRegion: 6,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "merge-patch-to-multiple-node-update",
|
||||
updates: []types.StateUpdate{
|
||||
{
|
||||
Type: types.StatePeerChangedPatch,
|
||||
ChangePatches: []*tailcfg.PeerChange{
|
||||
{
|
||||
NodeID: 3,
|
||||
Endpoints: []netip.AddrPort{
|
||||
netip.MustParseAddrPort("1.1.1.1:9090"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: types.StatePeerChangedPatch,
|
||||
ChangePatches: []*tailcfg.PeerChange{
|
||||
{
|
||||
NodeID: 3,
|
||||
Endpoints: []netip.AddrPort{
|
||||
netip.MustParseAddrPort("1.1.1.1:9090"),
|
||||
netip.MustParseAddrPort("2.2.2.2:8080"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: types.StatePeerChangedPatch,
|
||||
ChangePatches: []*tailcfg.PeerChange{
|
||||
{
|
||||
NodeID: 4,
|
||||
DERPRegion: 6,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: types.StatePeerChangedPatch,
|
||||
ChangePatches: []*tailcfg.PeerChange{
|
||||
{
|
||||
NodeID: 4,
|
||||
Cap: tailcfg.CapabilityVersion(54),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: []types.StateUpdate{
|
||||
{
|
||||
Type: types.StatePeerChangedPatch,
|
||||
ChangePatches: []*tailcfg.PeerChange{
|
||||
{
|
||||
NodeID: 3,
|
||||
Endpoints: []netip.AddrPort{
|
||||
netip.MustParseAddrPort("1.1.1.1:9090"),
|
||||
netip.MustParseAddrPort("2.2.2.2:8080"),
|
||||
},
|
||||
},
|
||||
{
|
||||
NodeID: 4,
|
||||
DERPRegion: 6,
|
||||
Cap: tailcfg.CapabilityVersion(54),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
n := NewNotifier(&types.Config{
|
||||
Tuning: types.Tuning{
|
||||
// We will call flush manually for the tests,
|
||||
// so do not run the worker.
|
||||
BatchChangeDelay: time.Hour,
|
||||
|
||||
// Since we do not load the config, we won't get the
|
||||
// default, so set it manually so we don't time out
|
||||
// and have flakes.
|
||||
NotifierSendTimeout: time.Second,
|
||||
},
|
||||
})
|
||||
|
||||
ch := make(chan types.StateUpdate, 30)
|
||||
defer close(ch)
|
||||
n.AddNode(1, ch)
|
||||
defer n.RemoveNode(1, ch)
|
||||
|
||||
for _, u := range tt.updates {
|
||||
n.NotifyAll(t.Context(), u)
|
||||
}
|
||||
|
||||
n.b.flush()
|
||||
|
||||
var got []types.StateUpdate
|
||||
for len(ch) > 0 {
|
||||
out := <-ch
|
||||
got = append(got, out)
|
||||
}
|
||||
|
||||
// Make the inner order stable for comparison.
|
||||
for _, u := range got {
|
||||
slices.Sort(u.ChangeNodes)
|
||||
sort.Slice(u.ChangePatches, func(i, j int) bool {
|
||||
return u.ChangePatches[i].NodeID < u.ChangePatches[j].NodeID
|
||||
})
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" {
|
||||
t.Errorf("batcher() unexpected result (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsLikelyConnectedRaceCondition tests for a race condition in IsLikelyConnected.
|
||||
// Multiple goroutines calling AddNode and RemoveNode can cause panics when trying to
|
||||
// close a channel that was already closed, which can happen when a node changes
|
||||
// network transport quickly (e.g. mobile->wifi) and reconnects whilst also disconnecting.
|
||||
func TestIsLikelyConnectedRaceCondition(t *testing.T) {
|
||||
// mock config for the notifier
|
||||
cfg := &types.Config{
|
||||
Tuning: types.Tuning{
|
||||
NotifierSendTimeout: 1 * time.Second,
|
||||
BatchChangeDelay: 1 * time.Second,
|
||||
NodeMapSessionBufferedChanSize: 30,
|
||||
},
|
||||
}
|
||||
|
||||
notifier := NewNotifier(cfg)
|
||||
defer notifier.Close()
|
||||
|
||||
nodeID := types.NodeID(1)
|
||||
updateChan := make(chan types.StateUpdate, 10)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Number of goroutines to spawn for concurrent access
|
||||
concurrentAccessors := 100
|
||||
iterations := 100
|
||||
|
||||
// Add node to notifier
|
||||
notifier.AddNode(nodeID, updateChan)
|
||||
|
||||
// Track errors
|
||||
errChan := make(chan string, concurrentAccessors*iterations)
|
||||
|
||||
// Start goroutines to cause a race
|
||||
wg.Add(concurrentAccessors)
|
||||
for i := range concurrentAccessors {
|
||||
go func(routineID int) {
|
||||
defer wg.Done()
|
||||
|
||||
for range iterations {
|
||||
// Simulate race by having some goroutines check IsLikelyConnected
|
||||
// while others add/remove the node
|
||||
switch routineID % 3 {
|
||||
case 0:
|
||||
// This goroutine checks connection status
|
||||
isConnected := notifier.IsLikelyConnected(nodeID)
|
||||
if isConnected != true && isConnected != false {
|
||||
errChan <- fmt.Sprintf("Invalid connection status: %v", isConnected)
|
||||
}
|
||||
case 1:
|
||||
// This goroutine removes the node
|
||||
notifier.RemoveNode(nodeID, updateChan)
|
||||
default:
|
||||
// This goroutine adds the node back
|
||||
notifier.AddNode(nodeID, updateChan)
|
||||
}
|
||||
|
||||
// Small random delay to increase chance of races
|
||||
time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(errChan)
|
||||
|
||||
// Collate errors
|
||||
var errors []string
|
||||
for err := range errChan {
|
||||
errors = append(errors, err)
|
||||
}
|
||||
|
||||
if len(errors) > 0 {
|
||||
t.Errorf("Detected %d race condition errors: %v", len(errors), errors)
|
||||
}
|
||||
}
|
||||
@@ -16,9 +16,8 @@ import (
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/juanfont/headscale/hscontrol/db"
|
||||
"github.com/juanfont/headscale/hscontrol/notifier"
|
||||
"github.com/juanfont/headscale/hscontrol/state"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/types/change"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/rs/zerolog/log"
|
||||
"golang.org/x/oauth2"
|
||||
@@ -56,11 +55,10 @@ type RegistrationInfo struct {
|
||||
}
|
||||
|
||||
type AuthProviderOIDC struct {
|
||||
h *Headscale
|
||||
serverURL string
|
||||
cfg *types.OIDCConfig
|
||||
state *state.State
|
||||
registrationCache *zcache.Cache[string, RegistrationInfo]
|
||||
notifier *notifier.Notifier
|
||||
|
||||
oidcProvider *oidc.Provider
|
||||
oauth2Config *oauth2.Config
|
||||
@@ -68,10 +66,9 @@ type AuthProviderOIDC struct {
|
||||
|
||||
func NewAuthProviderOIDC(
|
||||
ctx context.Context,
|
||||
h *Headscale,
|
||||
serverURL string,
|
||||
cfg *types.OIDCConfig,
|
||||
state *state.State,
|
||||
notif *notifier.Notifier,
|
||||
) (*AuthProviderOIDC, error) {
|
||||
var err error
|
||||
// grab the OIDC config if it hasn't been fetched already
|
||||
@@ -94,11 +91,10 @@ func NewAuthProviderOIDC(
|
||||
)
|
||||
|
||||
return &AuthProviderOIDC{
|
||||
h: h,
|
||||
serverURL: serverURL,
|
||||
cfg: cfg,
|
||||
state: state,
|
||||
registrationCache: registrationCache,
|
||||
notifier: notif,
|
||||
|
||||
oidcProvider: oidcProvider,
|
||||
oauth2Config: oauth2Config,
|
||||
@@ -318,8 +314,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
||||
|
||||
// Send policy update notifications if needed
|
||||
if policyChanged {
|
||||
ctx := types.NotifyCtx(context.Background(), "oidc-user-created", user.Name)
|
||||
a.notifier.NotifyAll(ctx, types.UpdateFull())
|
||||
a.h.Change(change.PolicyChange())
|
||||
}
|
||||
|
||||
// TODO(kradalby): Is this comment right?
|
||||
@@ -360,8 +355,6 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
||||
// Neither node nor machine key was found in the state cache meaning
|
||||
// that we could not reauth nor register the node.
|
||||
httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", nil))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func extractCodeAndStateParamFromRequest(
|
||||
@@ -490,12 +483,14 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
|
||||
var err error
|
||||
var newUser bool
|
||||
var policyChanged bool
|
||||
user, err = a.state.GetUserByOIDCIdentifier(claims.Identifier())
|
||||
user, err = a.h.state.GetUserByOIDCIdentifier(claims.Identifier())
|
||||
if err != nil && !errors.Is(err, db.ErrUserNotFound) {
|
||||
return nil, false, fmt.Errorf("creating or updating user: %w", err)
|
||||
}
|
||||
|
||||
// if the user is still not found, create a new empty user.
|
||||
// TODO(kradalby): This might cause us to not have an ID below which
|
||||
// is a problem.
|
||||
if user == nil {
|
||||
newUser = true
|
||||
user = &types.User{}
|
||||
@@ -504,12 +499,12 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
|
||||
user.FromClaim(claims)
|
||||
|
||||
if newUser {
|
||||
user, policyChanged, err = a.state.CreateUser(*user)
|
||||
user, policyChanged, err = a.h.state.CreateUser(*user)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("creating user: %w", err)
|
||||
}
|
||||
} else {
|
||||
_, policyChanged, err = a.state.UpdateUser(types.UserID(user.ID), func(u *types.User) error {
|
||||
_, policyChanged, err = a.h.state.UpdateUser(types.UserID(user.ID), func(u *types.User) error {
|
||||
*u = *user
|
||||
return nil
|
||||
})
|
||||
@@ -526,7 +521,7 @@ func (a *AuthProviderOIDC) handleRegistration(
|
||||
registrationID types.RegistrationID,
|
||||
expiry time.Time,
|
||||
) (bool, error) {
|
||||
node, newNode, err := a.state.HandleNodeFromAuthPath(
|
||||
node, nodeChange, err := a.h.state.HandleNodeFromAuthPath(
|
||||
registrationID,
|
||||
types.UserID(user.ID),
|
||||
&expiry,
|
||||
@@ -547,31 +542,20 @@ func (a *AuthProviderOIDC) handleRegistration(
|
||||
// ensure we send an update.
|
||||
// This works, but might be another good candidate for doing some sort of
|
||||
// eventbus.
|
||||
routesChanged := a.state.AutoApproveRoutes(node)
|
||||
_, policyChanged, err := a.state.SaveNode(node)
|
||||
_ = a.h.state.AutoApproveRoutes(node)
|
||||
_, policyChange, err := a.h.state.SaveNode(node)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("saving auto approved routes to node: %w", err)
|
||||
}
|
||||
|
||||
// Send policy update notifications if needed (from SaveNode or route changes)
|
||||
if policyChanged {
|
||||
ctx := types.NotifyCtx(context.Background(), "oidc-nodes-change", "all")
|
||||
a.notifier.NotifyAll(ctx, types.UpdateFull())
|
||||
// Policy updates are full and take precedence over node changes.
|
||||
if !policyChange.Empty() {
|
||||
a.h.Change(policyChange)
|
||||
} else {
|
||||
a.h.Change(nodeChange)
|
||||
}
|
||||
|
||||
if routesChanged {
|
||||
ctx := types.NotifyCtx(context.Background(), "oidc-expiry-self", node.Hostname)
|
||||
a.notifier.NotifyByNodeID(
|
||||
ctx,
|
||||
types.UpdateSelf(node.ID),
|
||||
node.ID,
|
||||
)
|
||||
|
||||
ctx = types.NotifyCtx(context.Background(), "oidc-expiry-peers", node.Hostname)
|
||||
a.notifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID)
|
||||
}
|
||||
|
||||
return newNode, nil
|
||||
return !nodeChange.Empty(), nil
|
||||
}
|
||||
|
||||
// TODO(kradalby):
|
||||
|
||||
@@ -113,6 +113,17 @@ func ReduceFilterRules(node types.NodeView, rules []tailcfg.FilterRule) []tailcf
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Also check approved subnet routes - nodes should have access
|
||||
// to subnets they're approved to route traffic for.
|
||||
subnetRoutes := node.SubnetRoutes()
|
||||
|
||||
for _, subnetRoute := range subnetRoutes {
|
||||
if expanded.OverlapsPrefix(subnetRoute) {
|
||||
dests = append(dests, dest)
|
||||
continue DEST_LOOP
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(dests) > 0 {
|
||||
@@ -142,16 +153,23 @@ func AutoApproveRoutes(pm PolicyManager, node *types.Node) bool {
|
||||
newApproved = append(newApproved, route)
|
||||
}
|
||||
}
|
||||
if newApproved != nil {
|
||||
newApproved = append(newApproved, node.ApprovedRoutes...)
|
||||
tsaddr.SortPrefixes(newApproved)
|
||||
newApproved = slices.Compact(newApproved)
|
||||
newApproved = lo.Filter(newApproved, func(route netip.Prefix, index int) bool {
|
||||
|
||||
// Only modify ApprovedRoutes if we have new routes to approve.
|
||||
// This prevents clearing existing approved routes when nodes
|
||||
// temporarily don't have announced routes during policy changes.
|
||||
if len(newApproved) > 0 {
|
||||
combined := append(newApproved, node.ApprovedRoutes...)
|
||||
tsaddr.SortPrefixes(combined)
|
||||
combined = slices.Compact(combined)
|
||||
combined = lo.Filter(combined, func(route netip.Prefix, index int) bool {
|
||||
return route.IsValid()
|
||||
})
|
||||
node.ApprovedRoutes = newApproved
|
||||
|
||||
return true
|
||||
// Only update if the routes actually changed
|
||||
if !slices.Equal(node.ApprovedRoutes, combined) {
|
||||
node.ApprovedRoutes = combined
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
|
||||
@@ -56,10 +56,13 @@ func (pol *Policy) compileFilterRules(
|
||||
}
|
||||
|
||||
if ips == nil {
|
||||
log.Debug().Msgf("destination resolved to nil ips: %v", dest)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, pref := range ips.Prefixes() {
|
||||
prefixes := ips.Prefixes()
|
||||
|
||||
for _, pref := range prefixes {
|
||||
for _, port := range dest.Ports {
|
||||
pr := tailcfg.NetPortRange{
|
||||
IP: pref.String(),
|
||||
@@ -103,6 +106,8 @@ func (pol *Policy) compileSSHPolicy(
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
log.Trace().Msgf("compiling SSH policy for node %q", node.Hostname())
|
||||
|
||||
var rules []*tailcfg.SSHRule
|
||||
|
||||
for index, rule := range pol.SSHs {
|
||||
@@ -137,7 +142,8 @@ func (pol *Policy) compileSSHPolicy(
|
||||
var principals []*tailcfg.SSHPrincipal
|
||||
srcIPs, err := rule.Sources.Resolve(pol, users, nodes)
|
||||
if err != nil {
|
||||
log.Trace().Err(err).Msgf("resolving source ips")
|
||||
log.Trace().Err(err).Msgf("SSH policy compilation failed resolving source ips for rule %+v", rule)
|
||||
continue // Skip this rule if we can't resolve sources
|
||||
}
|
||||
|
||||
for addr := range util.IPSetAddrIter(srcIPs) {
|
||||
|
||||
@@ -70,7 +70,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
|
||||
// TODO(kradalby): This could potentially be optimized by only clearing the
|
||||
// policies for nodes that have changed. Particularly if the only difference is
|
||||
// that nodes have been added or removed.
|
||||
defer clear(pm.sshPolicyMap)
|
||||
clear(pm.sshPolicyMap)
|
||||
|
||||
filter, err := pm.pol.compileFilterRules(pm.users, pm.nodes)
|
||||
if err != nil {
|
||||
|
||||
@@ -1730,7 +1730,7 @@ func (u SSHUser) MarshalJSON() ([]byte, error) {
|
||||
// In addition to unmarshalling, it will also validate the policy.
|
||||
// This is the only entrypoint of reading a policy from a file or other source.
|
||||
func unmarshalPolicy(b []byte) (*Policy, error) {
|
||||
if b == nil || len(b) == 0 {
|
||||
if len(b) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -412,7 +412,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
||||
`,
|
||||
wantErr: `Hostname "derp" contains an invalid IP address: "10.0/42"`,
|
||||
},
|
||||
// TODO(kradalby): Figure out why this doesnt work.
|
||||
// TODO(kradalby): Figure out why this doesn't work.
|
||||
// {
|
||||
// name: "invalid-hostname",
|
||||
// input: `
|
||||
@@ -1074,7 +1074,7 @@ func TestResolvePolicy(t *testing.T) {
|
||||
ForcedTags: []string{"tag:anything"},
|
||||
IPv4: ap("100.100.101.2"),
|
||||
},
|
||||
// not matchin pak tag
|
||||
// not matching pak tag
|
||||
{
|
||||
User: users["testuser"],
|
||||
AuthKey: &types.PreAuthKey{
|
||||
@@ -1108,7 +1108,7 @@ func TestResolvePolicy(t *testing.T) {
|
||||
ForcedTags: []string{"tag:anything"},
|
||||
IPv4: ap("100.100.101.5"),
|
||||
},
|
||||
// not matchin pak tag
|
||||
// not matching pak tag
|
||||
{
|
||||
User: users["groupuser"],
|
||||
AuthKey: &types.PreAuthKey{
|
||||
@@ -1147,7 +1147,7 @@ func TestResolvePolicy(t *testing.T) {
|
||||
ForcedTags: []string{"tag:anything"},
|
||||
IPv4: ap("100.100.101.10"),
|
||||
},
|
||||
// not matchin pak tag
|
||||
// not matching pak tag
|
||||
{
|
||||
AuthKey: &types.PreAuthKey{
|
||||
Tags: []string{"tag:alsotagged"},
|
||||
@@ -1159,7 +1159,7 @@ func TestResolvePolicy(t *testing.T) {
|
||||
ForcedTags: []string{"tag:test"},
|
||||
IPv4: ap("100.100.101.234"),
|
||||
},
|
||||
// not matchin pak tag
|
||||
// not matching pak tag
|
||||
{
|
||||
AuthKey: &types.PreAuthKey{
|
||||
Tags: []string{"tag:test"},
|
||||
|
||||
@@ -2,20 +2,20 @@ package hscontrol

import (
"context"
"encoding/binary"
"encoding/json"
"fmt"
"math/rand/v2"
"net/http"
"net/netip"
"slices"
"time"

"github.com/juanfont/headscale/hscontrol/mapper"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/types/change"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
"github.com/sasha-s/go-deadlock"
xslices "golang.org/x/exp/slices"
"tailscale.com/net/tsaddr"
"tailscale.com/tailcfg"
"tailscale.com/util/zstdframe"
)

const (
@@ -31,18 +31,17 @@ type mapSession struct {
req tailcfg.MapRequest
ctx context.Context
capVer tailcfg.CapabilityVersion
mapper *mapper.Mapper

cancelChMu deadlock.Mutex

ch chan types.StateUpdate
ch chan *tailcfg.MapResponse
cancelCh chan struct{}
cancelChOpen bool

keepAlive time.Duration
keepAliveTicker *time.Ticker

node types.NodeView
node *types.Node
w http.ResponseWriter

warnf func(string, ...any)
@@ -55,18 +54,9 @@ func (h *Headscale) newMapSession(
ctx context.Context,
req tailcfg.MapRequest,
w http.ResponseWriter,
nv types.NodeView,
node *types.Node,
) *mapSession {
warnf, infof, tracef, errf := logPollFuncView(req, nv)

var updateChan chan types.StateUpdate
if req.Stream {
// Use a buffered channel in case a node is not fully ready
// to receive a message to make sure we dont block the entire
// notifier.
updateChan = make(chan types.StateUpdate, h.cfg.Tuning.NodeMapSessionBufferedChanSize)
updateChan <- types.UpdateFull()
}
warnf, infof, tracef, errf := logPollFunc(req, node)

ka := keepAliveInterval + (time.Duration(rand.IntN(9000)) * time.Millisecond)

@@ -75,11 +65,10 @@ func (h *Headscale) newMapSession(
ctx: ctx,
req: req,
w: w,
node: nv,
node: node,
capVer: req.Version,
mapper: h.mapper,

ch: updateChan,
ch: make(chan *tailcfg.MapResponse, h.cfg.Tuning.NodeMapSessionBufferedChanSize),
cancelCh: make(chan struct{}),
cancelChOpen: true,

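The buffered channel sized by NodeMapSessionBufferedChanSize exists so that a node that is not yet ready to read cannot stall whoever is fanning out updates. A minimal, self-contained sketch of that non-blocking send pattern, independent of headscale's actual batcher and notifier types (the sendUpdate helper, buffer size and element type below are illustrative assumptions, not code from this diff):

package main

import "fmt"

// sendUpdate tries to hand an update to one session without blocking the
// sender. If the session's buffer is full, the update is reported as dropped
// instead of stalling the fan-out loop.
func sendUpdate[T any](ch chan T, update T) bool {
	select {
	case ch <- update:
		return true
	default:
		return false // receiver is not keeping up; do not block
	}
}

func main() {
	// A small buffer absorbs short bursts while the client is still connecting.
	ch := make(chan string, 8)
	for i := 0; i < 10; i++ {
		if !sendUpdate(ch, fmt.Sprintf("update-%d", i)) {
			fmt.Println("dropped update", i)
		}
	}
}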
@@ -95,15 +84,11 @@ func (h *Headscale) newMapSession(
|
||||
}
|
||||
|
||||
func (m *mapSession) isStreaming() bool {
|
||||
return m.req.Stream && !m.req.ReadOnly
|
||||
return m.req.Stream
|
||||
}
|
||||
|
||||
func (m *mapSession) isEndpointUpdate() bool {
|
||||
return !m.req.Stream && !m.req.ReadOnly && m.req.OmitPeers
|
||||
}
|
||||
|
||||
func (m *mapSession) isReadOnlyUpdate() bool {
|
||||
return !m.req.Stream && m.req.OmitPeers && m.req.ReadOnly
|
||||
return !m.req.Stream && m.req.OmitPeers
|
||||
}
|
||||
|
||||
func (m *mapSession) resetKeepAlive() {
|
||||
@@ -112,25 +97,22 @@ func (m *mapSession) resetKeepAlive() {
|
||||
|
||||
func (m *mapSession) beforeServeLongPoll() {
|
||||
if m.node.IsEphemeral() {
|
||||
m.h.ephemeralGC.Cancel(m.node.ID())
|
||||
m.h.ephemeralGC.Cancel(m.node.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mapSession) afterServeLongPoll() {
|
||||
if m.node.IsEphemeral() {
|
||||
m.h.ephemeralGC.Schedule(m.node.ID(), m.h.cfg.EphemeralNodeInactivityTimeout)
|
||||
m.h.ephemeralGC.Schedule(m.node.ID, m.h.cfg.EphemeralNodeInactivityTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
// serve handles non-streaming requests.
|
||||
func (m *mapSession) serve() {
|
||||
// TODO(kradalby): A set todos to harden:
|
||||
// - func to tell the stream to die, readonly -> false, !stream && omitpeers -> false, true
|
||||
|
||||
// This is the mechanism where the node gives us information about its
|
||||
// current configuration.
|
||||
//
|
||||
// If OmitPeers is true, Stream is false, and ReadOnly is false,
|
||||
// If OmitPeers is true and Stream is false
|
||||
// then the server will let clients update their endpoints without
|
||||
// breaking existing long-polling (Stream == true) connections.
|
||||
// In this case, the server can omit the entire response; the client
|
||||
@@ -138,26 +120,18 @@ func (m *mapSession) serve() {
|
||||
//
|
||||
// This is what Tailscale calls a Lite update, the client ignores
|
||||
// the response and just wants a 200.
|
||||
// !req.stream && !req.ReadOnly && req.OmitPeers
|
||||
//
|
||||
// TODO(kradalby): remove ReadOnly when we only support capVer 68+
|
||||
// !req.stream && req.OmitPeers
|
||||
if m.isEndpointUpdate() {
|
||||
m.handleEndpointUpdate()
|
||||
c, err := m.h.state.UpdateNodeFromMapRequest(m.node, m.req)
|
||||
if err != nil {
|
||||
httpError(m.w, err)
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
m.h.Change(c)
|
||||
|
||||
// ReadOnly is whether the client just wants to fetch the
|
||||
// MapResponse, without updating their Endpoints. The
|
||||
// Endpoints field will be ignored and LastSeen will not be
|
||||
// updated and peers will not be notified of changes.
|
||||
//
|
||||
// The intended use is for clients to discover the DERP map at
|
||||
// start-up before their first real endpoint update.
|
||||
if m.isReadOnlyUpdate() {
|
||||
m.handleReadOnlyRequest()
|
||||
|
||||
return
|
||||
m.w.WriteHeader(http.StatusOK)
|
||||
mapResponseEndpointUpdates.WithLabelValues("ok").Inc()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -175,23 +149,15 @@ func (m *mapSession) serveLongPoll() {
|
||||
close(m.cancelCh)
|
||||
m.cancelChMu.Unlock()
|
||||
|
||||
// only update node status if the node channel was removed.
|
||||
// in principal, it will be removed, but the client rapidly
|
||||
// reconnects, the channel might be of another connection.
|
||||
// In that case, it is not closed and the node is still online.
|
||||
if m.h.nodeNotifier.RemoveNode(m.node.ID(), m.ch) {
|
||||
// TODO(kradalby): This can likely be made more effective, but likely most
|
||||
// nodes has access to the same routes, so it might not be a big deal.
|
||||
change, err := m.h.state.Disconnect(m.node.ID())
|
||||
if err != nil {
|
||||
m.errf(err, "Failed to disconnect node %s", m.node.Hostname())
|
||||
}
|
||||
|
||||
if change {
|
||||
ctx := types.NotifyCtx(context.Background(), "poll-primary-change", m.node.Hostname())
|
||||
m.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
||||
}
|
||||
// TODO(kradalby): This can likely be made more effective, but likely most
|
||||
// nodes has access to the same routes, so it might not be a big deal.
|
||||
disconnectChange, err := m.h.state.Disconnect(m.node)
|
||||
if err != nil {
|
||||
m.errf(err, "Failed to disconnect node %s", m.node.Hostname)
|
||||
}
|
||||
m.h.Change(disconnectChange)
|
||||
|
||||
m.h.mapBatcher.RemoveNode(m.node.ID, m.ch, m.node.IsSubnetRouter())
|
||||
|
||||
m.afterServeLongPoll()
|
||||
m.infof("node has disconnected, mapSession: %p, chan: %p", m, m.ch)
|
||||
@@ -201,21 +167,30 @@ func (m *mapSession) serveLongPoll() {
|
||||
m.h.pollNetMapStreamWG.Add(1)
|
||||
defer m.h.pollNetMapStreamWG.Done()
|
||||
|
||||
m.h.state.Connect(m.node.ID())
|
||||
|
||||
// Upgrade the writer to a ResponseController
|
||||
rc := http.NewResponseController(m.w)
|
||||
|
||||
// Longpolling will break if there is a write timeout,
|
||||
// so it needs to be disabled.
|
||||
rc.SetWriteDeadline(time.Time{})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.WithValue(m.ctx, nodeNameContextKey, m.node.Hostname()))
|
||||
ctx, cancel := context.WithCancel(context.WithValue(m.ctx, nodeNameContextKey, m.node.Hostname))
|
||||
defer cancel()
|
||||
|
||||
m.keepAliveTicker = time.NewTicker(m.keepAlive)
|
||||
|
||||
m.h.nodeNotifier.AddNode(m.node.ID(), m.ch)
|
||||
// Add node to batcher BEFORE sending Connect change to prevent race condition
|
||||
// where the change is sent before the node is in the batcher's node map
|
||||
if err := m.h.mapBatcher.AddNode(m.node.ID, m.ch, m.node.IsSubnetRouter(), m.capVer); err != nil {
|
||||
m.errf(err, "failed to add node to batcher")
|
||||
// Send empty response to client to fail fast for invalid/non-existent nodes
|
||||
select {
|
||||
case m.ch <- &tailcfg.MapResponse{}:
|
||||
default:
|
||||
// Channel might be closed
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Now send the Connect change - the batcher handles NodeCameOnline internally
|
||||
// but we still need to update routes and other state-level changes
|
||||
connectChange := m.h.state.Connect(m.node)
|
||||
if !connectChange.Empty() && connectChange.Change != change.NodeCameOnline {
|
||||
m.h.Change(connectChange)
|
||||
}
|
||||
|
||||
m.infof("node has connected, mapSession: %p, chan: %p", m, m.ch)
|
||||
|
||||
@@ -236,290 +211,94 @@ func (m *mapSession) serveLongPoll() {
|
||||
|
||||
// Consume updates sent to node
|
||||
case update, ok := <-m.ch:
|
||||
m.tracef("received update from channel, ok: %t", ok)
|
||||
if !ok {
|
||||
m.tracef("update channel closed, streaming session is likely being replaced")
|
||||
return
|
||||
}
|
||||
|
||||
// If the node has been removed from headscale, close the stream
|
||||
if slices.Contains(update.Removed, m.node.ID()) {
|
||||
m.tracef("node removed, closing stream")
|
||||
if err := m.writeMap(update); err != nil {
|
||||
m.errf(err, "cannot write update to client")
|
||||
return
|
||||
}
|
||||
|
||||
m.tracef("received stream update: %s %s", update.Type.String(), update.Message)
|
||||
mapResponseUpdateReceived.WithLabelValues(update.Type.String()).Inc()
|
||||
|
||||
var data []byte
|
||||
var err error
|
||||
var lastMessage string
|
||||
|
||||
// Ensure the node view is updated, for example, there
|
||||
// might have been a hostinfo update in a sidechannel
|
||||
// which contains data needed to generate a map response.
|
||||
m.node, err = m.h.state.GetNodeViewByID(m.node.ID())
|
||||
if err != nil {
|
||||
m.errf(err, "Could not get machine from db")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
updateType := "full"
|
||||
switch update.Type {
|
||||
case types.StateFullUpdate:
|
||||
m.tracef("Sending Full MapResponse")
|
||||
data, err = m.mapper.FullMapResponse(m.req, m.node, fmt.Sprintf("from mapSession: %p, stream: %t", m, m.isStreaming()))
|
||||
case types.StatePeerChanged:
|
||||
changed := make(map[types.NodeID]bool, len(update.ChangeNodes))
|
||||
|
||||
for _, nodeID := range update.ChangeNodes {
|
||||
changed[nodeID] = true
|
||||
}
|
||||
|
||||
lastMessage = update.Message
|
||||
m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage))
|
||||
data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, lastMessage)
|
||||
updateType = "change"
|
||||
|
||||
case types.StatePeerChangedPatch:
|
||||
m.tracef(fmt.Sprintf("Sending Changed Patch MapResponse: %v", lastMessage))
|
||||
data, err = m.mapper.PeerChangedPatchResponse(m.req, m.node, update.ChangePatches)
|
||||
updateType = "patch"
|
||||
case types.StatePeerRemoved:
|
||||
changed := make(map[types.NodeID]bool, len(update.Removed))
|
||||
|
||||
for _, nodeID := range update.Removed {
|
||||
changed[nodeID] = false
|
||||
}
|
||||
m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage))
|
||||
data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, lastMessage)
|
||||
updateType = "remove"
|
||||
case types.StateSelfUpdate:
|
||||
lastMessage = update.Message
|
||||
m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage))
|
||||
// create the map so an empty (self) update is sent
|
||||
data, err = m.mapper.PeerChangedResponse(m.req, m.node, make(map[types.NodeID]bool), update.ChangePatches, lastMessage)
|
||||
updateType = "remove"
|
||||
case types.StateDERPUpdated:
|
||||
m.tracef("Sending DERPUpdate MapResponse")
|
||||
data, err = m.mapper.DERPMapResponse(m.req, m.node, m.h.state.DERPMap())
|
||||
updateType = "derp"
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
m.errf(err, "Could not get the create map update")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Only send update if there is change
|
||||
if data != nil {
|
||||
startWrite := time.Now()
|
||||
_, err = m.w.Write(data)
|
||||
if err != nil {
|
||||
mapResponseSent.WithLabelValues("error", updateType).Inc()
|
||||
m.errf(err, "could not write the map response(%s), for mapSession: %p", update.Type.String(), m)
|
||||
return
|
||||
}
|
||||
|
||||
err = rc.Flush()
|
||||
if err != nil {
|
||||
mapResponseSent.WithLabelValues("error", updateType).Inc()
|
||||
m.errf(err, "flushing the map response to client, for mapSession: %p", m)
|
||||
return
|
||||
}
|
||||
|
||||
log.Trace().Str("node", m.node.Hostname()).TimeDiff("timeSpent", time.Now(), startWrite).Str("mkey", m.node.MachineKey().String()).Msg("finished writing mapresp to node")
|
||||
|
||||
if debugHighCardinalityMetrics {
|
||||
mapResponseLastSentSeconds.WithLabelValues(updateType, m.node.ID().String()).Set(float64(time.Now().Unix()))
|
||||
}
|
||||
mapResponseSent.WithLabelValues("ok", updateType).Inc()
|
||||
m.tracef("update sent")
|
||||
m.resetKeepAlive()
|
||||
}
|
||||
m.tracef("update sent")
|
||||
m.resetKeepAlive()
|
||||
|
||||
case <-m.keepAliveTicker.C:
|
||||
data, err := m.mapper.KeepAliveResponse(m.req, m.node)
|
||||
if err != nil {
|
||||
m.errf(err, "Error generating the keep alive msg")
|
||||
mapResponseSent.WithLabelValues("error", "keepalive").Inc()
|
||||
return
|
||||
}
|
||||
_, err = m.w.Write(data)
|
||||
if err != nil {
|
||||
m.errf(err, "Cannot write keep alive message")
|
||||
mapResponseSent.WithLabelValues("error", "keepalive").Inc()
|
||||
return
|
||||
}
|
||||
err = rc.Flush()
|
||||
if err != nil {
|
||||
m.errf(err, "flushing keep alive to client, for mapSession: %p", m)
|
||||
mapResponseSent.WithLabelValues("error", "keepalive").Inc()
|
||||
if err := m.writeMap(&keepAlive); err != nil {
|
||||
m.errf(err, "cannot write keep alive")
|
||||
return
|
||||
}
|
||||
|
||||
if debugHighCardinalityMetrics {
|
||||
mapResponseLastSentSeconds.WithLabelValues("keepalive", m.node.ID().String()).Set(float64(time.Now().Unix()))
|
||||
mapResponseLastSentSeconds.WithLabelValues("keepalive", m.node.ID.String()).Set(float64(time.Now().Unix()))
|
||||
}
|
||||
mapResponseSent.WithLabelValues("ok", "keepalive").Inc()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mapSession) handleEndpointUpdate() {
|
||||
m.tracef("received endpoint update")
|
||||
|
||||
// Get fresh node state from database for accurate route calculations
|
||||
node, err := m.h.state.GetNodeByID(m.node.ID())
|
||||
// writeMap writes the map response to the client.
|
||||
// It handles compression if requested and any headers that need to be set.
|
||||
// It also handles flushing the response if the ResponseWriter
|
||||
// implements http.Flusher.
|
||||
func (m *mapSession) writeMap(msg *tailcfg.MapResponse) error {
|
||||
jsonBody, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
m.errf(err, "Failed to get fresh node from database for endpoint update")
|
||||
http.Error(m.w, "", http.StatusInternalServerError)
|
||||
mapResponseEndpointUpdates.WithLabelValues("error").Inc()
|
||||
return
|
||||
return fmt.Errorf("marshalling map response: %w", err)
|
||||
}
|
||||
|
||||
change := m.node.PeerChangeFromMapRequest(m.req)
|
||||
|
||||
online := m.h.nodeNotifier.IsLikelyConnected(m.node.ID())
|
||||
change.Online = &online
|
||||
|
||||
node.ApplyPeerChange(&change)
|
||||
|
||||
sendUpdate, routesChanged := hostInfoChanged(node.Hostinfo, m.req.Hostinfo)
|
||||
|
||||
// The node might not set NetInfo if it has not changed and if
|
||||
// the full HostInfo object is overwritten, the information is lost.
|
||||
// If there is no NetInfo, keep the previous one.
|
||||
// From 1.66 the client only sends it if changed:
|
||||
// https://github.com/tailscale/tailscale/commit/e1011f138737286ecf5123ff887a7a5800d129a2
|
||||
// TODO(kradalby): evaluate if we need better comparing of hostinfo
|
||||
// before we take the changes.
|
||||
if m.req.Hostinfo.NetInfo == nil && node.Hostinfo != nil {
|
||||
m.req.Hostinfo.NetInfo = node.Hostinfo.NetInfo
|
||||
}
|
||||
node.Hostinfo = m.req.Hostinfo
|
||||
|
||||
logTracePeerChange(node.Hostname, sendUpdate, &change)
|
||||
|
||||
// If there is no changes and nothing to save,
|
||||
// return early.
|
||||
if peerChangeEmpty(change) && !sendUpdate {
|
||||
mapResponseEndpointUpdates.WithLabelValues("noop").Inc()
|
||||
return
|
||||
if m.req.Compress == util.ZstdCompression {
|
||||
jsonBody = zstdframe.AppendEncode(nil, jsonBody, zstdframe.FastestCompression)
|
||||
}
|
||||
|
||||
// Auto approve any routes that have been defined in policy as
|
||||
// auto approved. Check if this actually changed the node.
|
||||
routesAutoApproved := m.h.state.AutoApproveRoutes(node)
|
||||
data := make([]byte, reservedResponseHeaderSize)
|
||||
binary.LittleEndian.PutUint32(data, uint32(len(jsonBody)))
|
||||
data = append(data, jsonBody...)
|
||||
|
||||
// Always update routes for connected nodes to handle reconnection scenarios
|
||||
// where routes need to be restored to the primary routes system
|
||||
routesToSet := node.SubnetRoutes()
|
||||
startWrite := time.Now()
|
||||
|
||||
if m.h.state.SetNodeRoutes(node.ID, routesToSet...) {
|
||||
ctx := types.NotifyCtx(m.ctx, "poll-primary-change", node.Hostname)
|
||||
m.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
||||
} else if routesChanged {
|
||||
// Only send peer changed notification if routes actually changed
|
||||
ctx := types.NotifyCtx(m.ctx, "cli-approveroutes", node.Hostname)
|
||||
m.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID)
|
||||
|
||||
// TODO(kradalby): I am not sure if we need this?
|
||||
// Send an update to the node itself with to ensure it
|
||||
// has an updated packetfilter allowing the new route
|
||||
// if it is defined in the ACL.
|
||||
ctx = types.NotifyCtx(m.ctx, "poll-nodeupdate-self-hostinfochange", node.Hostname)
|
||||
m.h.nodeNotifier.NotifyByNodeID(
|
||||
ctx,
|
||||
types.UpdateSelf(node.ID),
|
||||
node.ID)
|
||||
_, err = m.w.Write(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If routes were auto-approved, we need to save the node to persist the changes
|
||||
if routesAutoApproved {
|
||||
if _, _, err := m.h.state.SaveNode(node); err != nil {
|
||||
m.errf(err, "Failed to save auto-approved routes to node")
|
||||
http.Error(m.w, "", http.StatusInternalServerError)
|
||||
mapResponseEndpointUpdates.WithLabelValues("error").Inc()
|
||||
return
|
||||
if m.isStreaming() {
|
||||
if f, ok := m.w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
} else {
|
||||
m.errf(nil, "ResponseWriter does not implement http.Flusher, cannot flush")
|
||||
}
|
||||
}
|
||||
|
||||
// Check if there has been a change to Hostname and update them
|
||||
// in the database. Then send a Changed update
|
||||
// (containing the whole node object) to peers to inform about
|
||||
// the hostname change.
|
||||
node.ApplyHostnameFromHostInfo(m.req.Hostinfo)
|
||||
log.Trace().Str("node", m.node.Hostname).TimeDiff("timeSpent", time.Now(), startWrite).Str("mkey", m.node.MachineKey.String()).Msg("finished writing mapresp to node")
|
||||
|
||||
_, policyChanged, err := m.h.state.SaveNode(node)
|
||||
if err != nil {
|
||||
m.errf(err, "Failed to persist/update node in the database")
|
||||
http.Error(m.w, "", http.StatusInternalServerError)
|
||||
mapResponseEndpointUpdates.WithLabelValues("error").Inc()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Send policy update notifications if needed
|
||||
if policyChanged {
|
||||
ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-policy", node.Hostname)
|
||||
m.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
||||
}
|
||||
|
||||
ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-peers-patch", node.Hostname)
|
||||
m.h.nodeNotifier.NotifyWithIgnore(
|
||||
ctx,
|
||||
types.UpdatePeerChanged(node.ID),
|
||||
node.ID,
|
||||
)
|
||||
|
||||
m.w.WriteHeader(http.StatusOK)
|
||||
mapResponseEndpointUpdates.WithLabelValues("ok").Inc()
|
||||
return nil
|
||||
}
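For context on the framing above: writeMap prefixes every MapResponse with a little-endian length header (reservedResponseHeaderSize bytes in the code, assumed here to be 4) followed by the JSON body, which may first be zstd-compressed when the request asked for it. A hedged sketch of what a reader of that framing could look like; it is illustrative only and not headscale or Tailscale client code:

package mapstream

import (
	"encoding/binary"
	"io"
)

// readMapFrame reads one length-prefixed frame: a 4-byte little-endian
// payload length followed by the payload itself. The payload may still be
// zstd-compressed; decompression and JSON unmarshalling are up to the caller.
func readMapFrame(r io.Reader) ([]byte, error) {
	var sizeBuf [4]byte
	if _, err := io.ReadFull(r, sizeBuf[:]); err != nil {
		return nil, err
	}
	size := binary.LittleEndian.Uint32(sizeBuf[:])
	payload := make([]byte, size)
	if _, err := io.ReadFull(r, payload); err != nil {
		return nil, err
	}
	return payload, nil
}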
|
||||
|
||||
func (m *mapSession) handleReadOnlyRequest() {
|
||||
m.tracef("Client asked for a lite update, responding without peers")
|
||||
|
||||
mapResp, err := m.mapper.ReadOnlyMapResponse(m.req, m.node)
|
||||
if err != nil {
|
||||
m.errf(err, "Failed to create MapResponse")
|
||||
http.Error(m.w, "", http.StatusInternalServerError)
|
||||
mapResponseReadOnly.WithLabelValues("error").Inc()
|
||||
return
|
||||
}
|
||||
|
||||
m.w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
m.w.WriteHeader(http.StatusOK)
|
||||
_, err = m.w.Write(mapResp)
|
||||
if err != nil {
|
||||
m.errf(err, "Failed to write response")
|
||||
mapResponseReadOnly.WithLabelValues("error").Inc()
|
||||
return
|
||||
}
|
||||
|
||||
m.w.WriteHeader(http.StatusOK)
|
||||
mapResponseReadOnly.WithLabelValues("ok").Inc()
|
||||
var keepAlive = tailcfg.MapResponse{
|
||||
KeepAlive: true,
|
||||
}
|
||||
|
||||
func logTracePeerChange(hostname string, hostinfoChange bool, change *tailcfg.PeerChange) {
|
||||
trace := log.Trace().Uint64("node.id", uint64(change.NodeID)).Str("hostname", hostname)
|
||||
func logTracePeerChange(hostname string, hostinfoChange bool, peerChange *tailcfg.PeerChange) {
|
||||
trace := log.Trace().Uint64("node.id", uint64(peerChange.NodeID)).Str("hostname", hostname)
|
||||
|
||||
if change.Key != nil {
|
||||
trace = trace.Str("node_key", change.Key.ShortString())
|
||||
if peerChange.Key != nil {
|
||||
trace = trace.Str("node_key", peerChange.Key.ShortString())
|
||||
}
|
||||
|
||||
if change.DiscoKey != nil {
|
||||
trace = trace.Str("disco_key", change.DiscoKey.ShortString())
|
||||
if peerChange.DiscoKey != nil {
|
||||
trace = trace.Str("disco_key", peerChange.DiscoKey.ShortString())
|
||||
}
|
||||
|
||||
if change.Online != nil {
|
||||
trace = trace.Bool("online", *change.Online)
|
||||
if peerChange.Online != nil {
|
||||
trace = trace.Bool("online", *peerChange.Online)
|
||||
}
|
||||
|
||||
if change.Endpoints != nil {
|
||||
eps := make([]string, len(change.Endpoints))
|
||||
for idx, ep := range change.Endpoints {
|
||||
if peerChange.Endpoints != nil {
|
||||
eps := make([]string, len(peerChange.Endpoints))
|
||||
for idx, ep := range peerChange.Endpoints {
|
||||
eps[idx] = ep.String()
|
||||
}
|
||||
|
||||
@@ -530,21 +309,11 @@ func logTracePeerChange(hostname string, hostinfoChange bool, change *tailcfg.Pe
|
||||
trace = trace.Bool("hostinfo_changed", hostinfoChange)
|
||||
}
|
||||
|
||||
if change.DERPRegion != 0 {
|
||||
trace = trace.Int("derp_region", change.DERPRegion)
|
||||
if peerChange.DERPRegion != 0 {
|
||||
trace = trace.Int("derp_region", peerChange.DERPRegion)
|
||||
}
|
||||
|
||||
trace.Time("last_seen", *change.LastSeen).Msg("PeerChange received")
|
||||
}
|
||||
|
||||
func peerChangeEmpty(chng tailcfg.PeerChange) bool {
|
||||
return chng.Key == nil &&
|
||||
chng.DiscoKey == nil &&
|
||||
chng.Online == nil &&
|
||||
chng.Endpoints == nil &&
|
||||
chng.DERPRegion == 0 &&
|
||||
chng.LastSeen == nil &&
|
||||
chng.KeyExpiry == nil
|
||||
trace.Time("last_seen", *peerChange.LastSeen).Msg("PeerChange received")
|
||||
}
|
||||
|
||||
func logPollFunc(
|
||||
@@ -554,7 +323,6 @@ func logPollFunc(
|
||||
return func(msg string, a ...any) {
|
||||
log.Warn().
|
||||
Caller().
|
||||
Bool("readOnly", mapRequest.ReadOnly).
|
||||
Bool("omitPeers", mapRequest.OmitPeers).
|
||||
Bool("stream", mapRequest.Stream).
|
||||
Uint64("node.id", node.ID.Uint64()).
|
||||
@@ -564,7 +332,6 @@ func logPollFunc(
|
||||
func(msg string, a ...any) {
|
||||
log.Info().
|
||||
Caller().
|
||||
Bool("readOnly", mapRequest.ReadOnly).
|
||||
Bool("omitPeers", mapRequest.OmitPeers).
|
||||
Bool("stream", mapRequest.Stream).
|
||||
Uint64("node.id", node.ID.Uint64()).
|
||||
@@ -574,7 +341,6 @@ func logPollFunc(
|
||||
func(msg string, a ...any) {
|
||||
log.Trace().
|
||||
Caller().
|
||||
Bool("readOnly", mapRequest.ReadOnly).
|
||||
Bool("omitPeers", mapRequest.OmitPeers).
|
||||
Bool("stream", mapRequest.Stream).
|
||||
Uint64("node.id", node.ID.Uint64()).
|
||||
@@ -584,7 +350,6 @@ func logPollFunc(
|
||||
func(err error, msg string, a ...any) {
|
||||
log.Error().
|
||||
Caller().
|
||||
Bool("readOnly", mapRequest.ReadOnly).
|
||||
Bool("omitPeers", mapRequest.OmitPeers).
|
||||
Bool("stream", mapRequest.Stream).
|
||||
Uint64("node.id", node.ID.Uint64()).
|
||||
@@ -593,91 +358,3 @@ func logPollFunc(
|
||||
Msgf(msg, a...)
|
||||
}
|
||||
}
|
||||
|
||||
func logPollFuncView(
|
||||
mapRequest tailcfg.MapRequest,
|
||||
nodeView types.NodeView,
|
||||
) (func(string, ...any), func(string, ...any), func(string, ...any), func(error, string, ...any)) {
|
||||
return func(msg string, a ...any) {
|
||||
log.Warn().
|
||||
Caller().
|
||||
Bool("readOnly", mapRequest.ReadOnly).
|
||||
Bool("omitPeers", mapRequest.OmitPeers).
|
||||
Bool("stream", mapRequest.Stream).
|
||||
Uint64("node.id", nodeView.ID().Uint64()).
|
||||
Str("node", nodeView.Hostname()).
|
||||
Msgf(msg, a...)
|
||||
},
|
||||
func(msg string, a ...any) {
|
||||
log.Info().
|
||||
Caller().
|
||||
Bool("readOnly", mapRequest.ReadOnly).
|
||||
Bool("omitPeers", mapRequest.OmitPeers).
|
||||
Bool("stream", mapRequest.Stream).
|
||||
Uint64("node.id", nodeView.ID().Uint64()).
|
||||
Str("node", nodeView.Hostname()).
|
||||
Msgf(msg, a...)
|
||||
},
|
||||
func(msg string, a ...any) {
|
||||
log.Trace().
|
||||
Caller().
|
||||
Bool("readOnly", mapRequest.ReadOnly).
|
||||
Bool("omitPeers", mapRequest.OmitPeers).
|
||||
Bool("stream", mapRequest.Stream).
|
||||
Uint64("node.id", nodeView.ID().Uint64()).
|
||||
Str("node", nodeView.Hostname()).
|
||||
Msgf(msg, a...)
|
||||
},
|
||||
func(err error, msg string, a ...any) {
|
||||
log.Error().
|
||||
Caller().
|
||||
Bool("readOnly", mapRequest.ReadOnly).
|
||||
Bool("omitPeers", mapRequest.OmitPeers).
|
||||
Bool("stream", mapRequest.Stream).
|
||||
Uint64("node.id", nodeView.ID().Uint64()).
|
||||
Str("node", nodeView.Hostname()).
|
||||
Err(err).
|
||||
Msgf(msg, a...)
|
||||
}
|
||||
}
|
||||
|
||||
// hostInfoChanged reports if hostInfo has changed in two ways,
|
||||
// - first bool reports if an update needs to be sent to nodes
|
||||
// - second reports if there has been changes to routes
|
||||
// the caller can then use this info to save and update nodes
|
||||
// and routes as needed.
|
||||
func hostInfoChanged(old, new *tailcfg.Hostinfo) (bool, bool) {
|
||||
if old.Equal(new) {
|
||||
return false, false
|
||||
}
|
||||
|
||||
if old == nil && new != nil {
|
||||
return true, true
|
||||
}
|
||||
|
||||
// Routes
|
||||
oldRoutes := make([]netip.Prefix, 0)
|
||||
if old != nil {
|
||||
oldRoutes = old.RoutableIPs
|
||||
}
|
||||
newRoutes := new.RoutableIPs
|
||||
|
||||
tsaddr.SortPrefixes(oldRoutes)
|
||||
tsaddr.SortPrefixes(newRoutes)
|
||||
|
||||
if !xslices.Equal(oldRoutes, newRoutes) {
|
||||
return true, true
|
||||
}
|
||||
|
||||
// Services is mostly useful for discovery and not critical,
|
||||
// except for peerapi, which is how nodes talk to each other.
|
||||
// If peerapi was not part of the initial mapresponse, we
|
||||
// need to make sure its sent out later as it is needed for
|
||||
// Taildrop.
|
||||
// TODO(kradalby): Length comparison is a bit naive, replace.
|
||||
if len(old.Services) != len(new.Services) {
|
||||
return true, false
|
||||
}
|
||||
|
||||
return false, false
|
||||
}
|
||||
|
||||
@@ -38,7 +38,7 @@ func New() *PrimaryRoutes {
// updatePrimaryLocked recalculates the primary routes and updates the internal state.
// It returns true if the primary routes have changed.
// It is assumed that the caller holds the lock.
// The algorthm is as follows:
// The algorithm is as follows:
// 1. Reset the primaries map.
// 2. Iterate over the routes and count the number of times a prefix is advertised.
// 3. If a prefix is advertised by at least two nodes, it is a primary route.

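The three numbered steps above amount to counting, per prefix, how many nodes advertise it and only tracking a primary once at least two nodes do, that is, once there is an actual failover choice to make. A rough, self-contained sketch of that counting step under those assumptions; it is not the actual PrimaryRoutes implementation and the names below are hypothetical:

package routes

import "net/netip"

// countAdvertisers groups advertised prefixes by prefix and keeps only those
// announced by at least two nodes, mirroring steps 1-3 described above.
// Choosing which advertiser becomes the primary is intentionally left out.
func countAdvertisers(advertised map[uint64][]netip.Prefix) map[netip.Prefix][]uint64 {
	byPrefix := make(map[netip.Prefix][]uint64)
	for nodeID, prefixes := range advertised {
		for _, p := range prefixes {
			byPrefix[p] = append(byPrefix[p], nodeID)
		}
	}
	candidates := make(map[netip.Prefix][]uint64)
	for p, nodes := range byPrefix {
		if len(nodes) >= 2 { // step 3: needs at least two advertisers
			candidates[p] = nodes
		}
	}
	return candidates
}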
@@ -17,10 +17,13 @@ import (
|
||||
"github.com/juanfont/headscale/hscontrol/policy/matcher"
|
||||
"github.com/juanfont/headscale/hscontrol/routes"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/types/change"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/sasha-s/go-deadlock"
|
||||
xslices "golang.org/x/exp/slices"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/ptr"
|
||||
@@ -46,12 +49,6 @@ type State struct {
|
||||
// cfg holds the current Headscale configuration
|
||||
cfg *types.Config
|
||||
|
||||
// in-memory data, protected by mu
|
||||
// nodes contains the current set of registered nodes
|
||||
nodes types.Nodes
|
||||
// users contains the current set of users/namespaces
|
||||
users types.Users
|
||||
|
||||
// subsystem keeping state
|
||||
// db provides persistent storage and database operations
|
||||
db *hsdb.HSDatabase
|
||||
@@ -113,9 +110,6 @@ func NewState(cfg *types.Config) (*State, error) {
|
||||
return &State{
|
||||
cfg: cfg,
|
||||
|
||||
nodes: nodes,
|
||||
users: users,
|
||||
|
||||
db: db,
|
||||
ipAlloc: ipAlloc,
|
||||
// TODO(kradalby): Update DERPMap
|
||||
@@ -215,6 +209,7 @@ func (s *State) CreateUser(user types.User) (*types.User, bool, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
|
||||
if err := s.db.DB.Save(&user).Error; err != nil {
|
||||
return nil, false, fmt.Errorf("creating user: %w", err)
|
||||
}
|
||||
@@ -226,6 +221,18 @@ func (s *State) CreateUser(user types.User) (*types.User, bool, error) {
|
||||
return &user, false, fmt.Errorf("failed to update policy manager after user creation: %w", err)
|
||||
}
|
||||
|
||||
// Even if the policy manager doesn't detect a filter change, SSH policies
|
||||
// might now be resolvable when they weren't before. If there are existing
|
||||
// nodes, we should send a policy change to ensure they get updated SSH policies.
|
||||
if !policyChanged {
|
||||
nodes, err := s.ListNodes()
|
||||
if err == nil && len(nodes) > 0 {
|
||||
policyChanged = true
|
||||
}
|
||||
}
|
||||
|
||||
log.Info().Str("user", user.Name).Bool("policyChanged", policyChanged).Msg("User created, policy manager updated")
|
||||
|
||||
// TODO(kradalby): implement the user in-memory cache
|
||||
|
||||
return &user, policyChanged, nil
|
||||
@@ -329,7 +336,7 @@ func (s *State) CreateNode(node *types.Node) (*types.Node, bool, error) {
|
||||
}
|
||||
|
||||
// updateNodeTx performs a database transaction to update a node and refresh the policy manager.
|
||||
func (s *State) updateNodeTx(nodeID types.NodeID, updateFn func(tx *gorm.DB) error) (*types.Node, bool, error) {
|
||||
func (s *State) updateNodeTx(nodeID types.NodeID, updateFn func(tx *gorm.DB) error) (*types.Node, change.ChangeSet, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
@@ -350,72 +357,100 @@ func (s *State) updateNodeTx(nodeID types.NodeID, updateFn func(tx *gorm.DB) err
|
||||
return node, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
return nil, change.EmptySet, err
|
||||
}
|
||||
|
||||
// Check if policy manager needs updating
|
||||
policyChanged, err := s.updatePolicyManagerNodes()
|
||||
if err != nil {
|
||||
return node, false, fmt.Errorf("failed to update policy manager after node update: %w", err)
|
||||
return node, change.EmptySet, fmt.Errorf("failed to update policy manager after node update: %w", err)
|
||||
}
|
||||
|
||||
// TODO(kradalby): implement the node in-memory cache
|
||||
|
||||
return node, policyChanged, nil
|
||||
var c change.ChangeSet
|
||||
if policyChanged {
|
||||
c = change.PolicyChange()
|
||||
} else {
|
||||
// Basic node change without specific details since this is a generic update
|
||||
c = change.NodeAdded(node.ID)
|
||||
}
|
||||
|
||||
return node, c, nil
|
||||
}
|
||||
|
||||
// SaveNode persists an existing node to the database and updates the policy manager.
|
||||
func (s *State) SaveNode(node *types.Node) (*types.Node, bool, error) {
|
||||
func (s *State) SaveNode(node *types.Node) (*types.Node, change.ChangeSet, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if err := s.db.DB.Save(node).Error; err != nil {
|
||||
return nil, false, fmt.Errorf("saving node: %w", err)
|
||||
return nil, change.EmptySet, fmt.Errorf("saving node: %w", err)
|
||||
}
|
||||
|
||||
// Check if policy manager needs updating
|
||||
policyChanged, err := s.updatePolicyManagerNodes()
|
||||
if err != nil {
|
||||
return node, false, fmt.Errorf("failed to update policy manager after node save: %w", err)
|
||||
return node, change.EmptySet, fmt.Errorf("failed to update policy manager after node save: %w", err)
|
||||
}
|
||||
|
||||
// TODO(kradalby): implement the node in-memory cache
|
||||
|
||||
return node, policyChanged, nil
|
||||
if policyChanged {
|
||||
return node, change.PolicyChange(), nil
|
||||
}
|
||||
|
||||
return node, change.EmptySet, nil
|
||||
}
|
||||
|
||||
// DeleteNode permanently removes a node and cleans up associated resources.
|
||||
// Returns whether policies changed and any error. This operation is irreversible.
|
||||
func (s *State) DeleteNode(node *types.Node) (bool, error) {
|
||||
func (s *State) DeleteNode(node *types.Node) (change.ChangeSet, error) {
|
||||
err := s.db.DeleteNode(node)
|
||||
if err != nil {
|
||||
return false, err
|
||||
return change.EmptySet, err
|
||||
}
|
||||
|
||||
c := change.NodeRemoved(node.ID)
|
||||
|
||||
// Check if policy manager needs updating after node deletion
|
||||
policyChanged, err := s.updatePolicyManagerNodes()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to update policy manager after node deletion: %w", err)
|
||||
return change.EmptySet, fmt.Errorf("failed to update policy manager after node deletion: %w", err)
|
||||
}
|
||||
|
||||
return policyChanged, nil
|
||||
if policyChanged {
|
||||
c = change.PolicyChange()
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (s *State) Connect(id types.NodeID) {
|
||||
func (s *State) Connect(node *types.Node) change.ChangeSet {
|
||||
c := change.NodeOnline(node.ID)
|
||||
routeChange := s.primaryRoutes.SetRoutes(node.ID, node.SubnetRoutes()...)
|
||||
|
||||
if routeChange {
|
||||
c = change.NodeAdded(node.ID)
|
||||
}
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
func (s *State) Disconnect(id types.NodeID) (bool, error) {
|
||||
// TODO(kradalby): This node should update the in memory state
|
||||
_, polChanged, err := s.SetLastSeen(id, time.Now())
|
||||
func (s *State) Disconnect(node *types.Node) (change.ChangeSet, error) {
|
||||
c := change.NodeOffline(node.ID)
|
||||
|
||||
_, _, err := s.SetLastSeen(node.ID, time.Now())
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("disconnecting node: %w", err)
|
||||
return c, fmt.Errorf("disconnecting node: %w", err)
|
||||
}
|
||||
|
||||
changed := s.primaryRoutes.SetRoutes(id)
|
||||
if routeChange := s.primaryRoutes.SetRoutes(node.ID); routeChange {
|
||||
c = change.PolicyChange()
|
||||
}
|
||||
|
||||
// TODO(kradalby): the returned change should be more nuanced allowing us to
|
||||
// send more directed updates.
|
||||
return changed || polChanged, nil
|
||||
// TODO(kradalby): This node should update the in memory state
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// GetNodeByID retrieves a node by ID.
|
||||
@@ -475,45 +510,93 @@ func (s *State) ListEphemeralNodes() (types.Nodes, error) {
|
||||
}
|
||||
|
||||
// SetNodeExpiry updates the expiration time for a node.
|
||||
func (s *State) SetNodeExpiry(nodeID types.NodeID, expiry time.Time) (*types.Node, bool, error) {
|
||||
return s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
|
||||
func (s *State) SetNodeExpiry(nodeID types.NodeID, expiry time.Time) (*types.Node, change.ChangeSet, error) {
|
||||
n, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
|
||||
return hsdb.NodeSetExpiry(tx, nodeID, expiry)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, change.EmptySet, fmt.Errorf("setting node expiry: %w", err)
|
||||
}
|
||||
|
||||
if !c.IsFull() {
|
||||
c = change.KeyExpiry(nodeID)
|
||||
}
|
||||
|
||||
return n, c, nil
|
||||
}
|
||||
|
||||
// SetNodeTags assigns tags to a node for use in access control policies.
|
||||
func (s *State) SetNodeTags(nodeID types.NodeID, tags []string) (*types.Node, bool, error) {
|
||||
return s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
|
||||
func (s *State) SetNodeTags(nodeID types.NodeID, tags []string) (*types.Node, change.ChangeSet, error) {
|
||||
n, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
|
||||
return hsdb.SetTags(tx, nodeID, tags)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, change.EmptySet, fmt.Errorf("setting node tags: %w", err)
|
||||
}
|
||||
|
||||
if !c.IsFull() {
|
||||
c = change.NodeAdded(nodeID)
|
||||
}
|
||||
|
||||
return n, c, nil
|
||||
}
|
||||
|
||||
// SetApprovedRoutes sets the network routes that a node is approved to advertise.
|
||||
func (s *State) SetApprovedRoutes(nodeID types.NodeID, routes []netip.Prefix) (*types.Node, bool, error) {
|
||||
return s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
|
||||
func (s *State) SetApprovedRoutes(nodeID types.NodeID, routes []netip.Prefix) (*types.Node, change.ChangeSet, error) {
|
||||
n, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
|
||||
return hsdb.SetApprovedRoutes(tx, nodeID, routes)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, change.EmptySet, fmt.Errorf("setting approved routes: %w", err)
|
||||
}
|
||||
|
||||
// Update primary routes after changing approved routes
|
||||
routeChange := s.primaryRoutes.SetRoutes(nodeID, n.SubnetRoutes()...)
|
||||
|
||||
if routeChange || !c.IsFull() {
|
||||
c = change.PolicyChange()
|
||||
}
|
||||
|
||||
return n, c, nil
|
||||
}
|
||||
|
||||
// RenameNode changes the display name of a node.
|
||||
func (s *State) RenameNode(nodeID types.NodeID, newName string) (*types.Node, bool, error) {
|
||||
return s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
|
||||
func (s *State) RenameNode(nodeID types.NodeID, newName string) (*types.Node, change.ChangeSet, error) {
|
||||
n, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
|
||||
return hsdb.RenameNode(tx, nodeID, newName)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, change.EmptySet, fmt.Errorf("renaming node: %w", err)
|
||||
}
|
||||
|
||||
if !c.IsFull() {
|
||||
c = change.NodeAdded(nodeID)
|
||||
}
|
||||
|
||||
return n, c, nil
|
||||
}
|
||||
|
||||
// SetLastSeen updates when a node was last seen, used for connectivity monitoring.
|
||||
func (s *State) SetLastSeen(nodeID types.NodeID, lastSeen time.Time) (*types.Node, bool, error) {
|
||||
func (s *State) SetLastSeen(nodeID types.NodeID, lastSeen time.Time) (*types.Node, change.ChangeSet, error) {
|
||||
return s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
|
||||
return hsdb.SetLastSeen(tx, nodeID, lastSeen)
|
||||
})
|
||||
}
|
||||
|
||||
// AssignNodeToUser transfers a node to a different user.
|
||||
func (s *State) AssignNodeToUser(nodeID types.NodeID, userID types.UserID) (*types.Node, bool, error) {
|
||||
return s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
|
||||
func (s *State) AssignNodeToUser(nodeID types.NodeID, userID types.UserID) (*types.Node, change.ChangeSet, error) {
|
||||
n, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
|
||||
return hsdb.AssignNodeToUser(tx, nodeID, userID)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, change.EmptySet, fmt.Errorf("assigning node to user: %w", err)
|
||||
}
|
||||
|
||||
if !c.IsFull() {
|
||||
c = change.NodeAdded(nodeID)
|
||||
}
|
||||
|
||||
return n, c, nil
|
||||
}
|
||||
|
||||
// BackfillNodeIPs assigns IP addresses to nodes that don't have them.
|
||||
@@ -523,7 +606,7 @@ func (s *State) BackfillNodeIPs() ([]string, error) {
|
||||
|
||||
// ExpireExpiredNodes finds and processes expired nodes since the last check.
|
||||
// Returns next check time, state update with expired nodes, and whether any were found.
|
||||
func (s *State) ExpireExpiredNodes(lastCheck time.Time) (time.Time, types.StateUpdate, bool) {
|
||||
func (s *State) ExpireExpiredNodes(lastCheck time.Time) (time.Time, []change.ChangeSet, bool) {
|
||||
return hsdb.ExpireExpiredNodes(s.db.DB, lastCheck)
|
||||
}
|
||||
|
||||
@@ -568,8 +651,14 @@ func (s *State) SetPolicyInDB(data string) (*types.Policy, error) {
|
||||
}
|
||||
|
||||
// SetNodeRoutes sets the primary routes for a node.
|
||||
func (s *State) SetNodeRoutes(nodeID types.NodeID, routes ...netip.Prefix) bool {
|
||||
return s.primaryRoutes.SetRoutes(nodeID, routes...)
|
||||
func (s *State) SetNodeRoutes(nodeID types.NodeID, routes ...netip.Prefix) change.ChangeSet {
|
||||
if s.primaryRoutes.SetRoutes(nodeID, routes...) {
|
||||
// Route changes affect packet filters for all nodes, so trigger a policy change
|
||||
// to ensure filters are regenerated across the entire network
|
||||
return change.PolicyChange()
|
||||
}
|
||||
|
||||
return change.EmptySet
|
||||
}
|
||||
|
||||
// GetNodePrimaryRoutes returns the primary routes for a node.
|
||||
@@ -653,10 +742,10 @@ func (s *State) HandleNodeFromAuthPath(
|
||||
userID types.UserID,
|
||||
expiry *time.Time,
|
||||
registrationMethod string,
|
||||
) (*types.Node, bool, error) {
|
||||
) (*types.Node, change.ChangeSet, error) {
|
||||
ipv4, ipv6, err := s.ipAlloc.Next()
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
return nil, change.EmptySet, err
|
||||
}
|
||||
|
||||
return s.db.HandleNodeFromAuthPath(
|
||||
@@ -672,12 +761,15 @@ func (s *State) HandleNodeFromAuthPath(
|
||||
func (s *State) HandleNodeFromPreAuthKey(
|
||||
regReq tailcfg.RegisterRequest,
|
||||
machineKey key.MachinePublic,
|
||||
) (*types.Node, bool, error) {
|
||||
) (*types.Node, change.ChangeSet, bool, error) {
|
||||
pak, err := s.GetPreAuthKey(regReq.Auth.AuthKey)
|
||||
if err != nil {
|
||||
return nil, change.EmptySet, false, err
|
||||
}
|
||||
|
||||
err = pak.Validate()
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
return nil, change.EmptySet, false, err
|
||||
}
|
||||
|
||||
nodeToRegister := types.Node{
|
||||
@@ -698,22 +790,13 @@ func (s *State) HandleNodeFromPreAuthKey(
|
||||
AuthKeyID: &pak.ID,
|
||||
}
|
||||
|
||||
// For auth key registration, ensure we don't keep an expired node
|
||||
// This is especially important for re-registration after logout
|
||||
if !regReq.Expiry.IsZero() && regReq.Expiry.After(time.Now()) {
|
||||
if !regReq.Expiry.IsZero() {
|
||||
nodeToRegister.Expiry = ®Req.Expiry
|
||||
} else if !regReq.Expiry.IsZero() {
|
||||
// If client is sending an expired time (e.g., after logout),
|
||||
// don't set expiry so the node won't be considered expired
|
||||
log.Debug().
|
||||
Time("requested_expiry", regReq.Expiry).
|
||||
Str("node", regReq.Hostinfo.Hostname).
|
||||
Msg("Ignoring expired expiry time from auth key registration")
|
||||
}
|
||||
|
||||
ipv4, ipv6, err := s.ipAlloc.Next()
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("allocating IPs: %w", err)
|
||||
return nil, change.EmptySet, false, fmt.Errorf("allocating IPs: %w", err)
|
||||
}
|
||||
|
||||
node, err := hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) {
|
||||
@@ -735,18 +818,38 @@ func (s *State) HandleNodeFromPreAuthKey(
|
||||
return node, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("writing node to database: %w", err)
|
||||
return nil, change.EmptySet, false, fmt.Errorf("writing node to database: %w", err)
|
||||
}
|
||||
|
||||
// Check if this is a logout request for an ephemeral node
|
||||
if !regReq.Expiry.IsZero() && regReq.Expiry.Before(time.Now()) && pak.Ephemeral {
|
||||
// This is a logout request for an ephemeral node, delete it immediately
|
||||
c, err := s.DeleteNode(node)
|
||||
if err != nil {
|
||||
return nil, change.EmptySet, false, fmt.Errorf("deleting ephemeral node during logout: %w", err)
|
||||
}
|
||||
return nil, c, false, nil
|
||||
}
|
||||
|
||||
// Check if policy manager needs updating
|
||||
// This is necessary because we just created a new node.
|
||||
// We need to ensure that the policy manager is aware of this new node.
|
||||
policyChanged, err := s.updatePolicyManagerNodes()
|
||||
// Also update users to ensure all users are known when evaluating policies.
|
||||
usersChanged, err := s.updatePolicyManagerUsers()
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("failed to update policy manager after node registration: %w", err)
|
||||
return nil, change.EmptySet, false, fmt.Errorf("failed to update policy manager users after node registration: %w", err)
|
||||
}
|
||||
|
||||
return node, policyChanged, nil
|
||||
nodesChanged, err := s.updatePolicyManagerNodes()
|
||||
if err != nil {
|
||||
return nil, change.EmptySet, false, fmt.Errorf("failed to update policy manager nodes after node registration: %w", err)
|
||||
}
|
||||
|
||||
policyChanged := usersChanged || nodesChanged
|
||||
|
||||
c := change.NodeAdded(node.ID)
|
||||
|
||||
return node, c, policyChanged, nil
|
||||
}
|
||||
|
||||
// AllocateNextIPs allocates the next available IPv4 and IPv6 addresses.
|
||||
@@ -766,11 +869,15 @@ func (s *State) updatePolicyManagerUsers() (bool, error) {
|
||||
return false, fmt.Errorf("listing users for policy update: %w", err)
|
||||
}
|
||||
|
||||
log.Debug().Int("userCount", len(users)).Msg("Updating policy manager with users")
|
||||
|
||||
changed, err := s.polMan.SetUsers(users)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("updating policy manager users: %w", err)
|
||||
}
|
||||
|
||||
log.Debug().Bool("changed", changed).Msg("Policy manager users updated")
|
||||
|
||||
return changed, nil
|
||||
}
|
||||
|
||||
@@ -835,3 +942,125 @@ func (s *State) autoApproveNodes() error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO(kradalby): This should just take the node ID?
|
||||
func (s *State) UpdateNodeFromMapRequest(node *types.Node, req tailcfg.MapRequest) (change.ChangeSet, error) {
|
||||
// TODO(kradalby): This is essentially a patch update that could be sent directly to nodes,
|
||||
// which means we could shortcut the whole change thing if there are no other important updates.
|
||||
peerChange := node.PeerChangeFromMapRequest(req)
|
||||
|
||||
node.ApplyPeerChange(&peerChange)
|
||||
|
||||
sendUpdate, routesChanged := hostInfoChanged(node.Hostinfo, req.Hostinfo)
|
||||
|
||||
// The node might not set NetInfo if it has not changed and if
|
||||
// the full HostInfo object is overwritten, the information is lost.
|
||||
// If there is no NetInfo, keep the previous one.
|
||||
// From 1.66 the client only sends it if changed:
|
||||
// https://github.com/tailscale/tailscale/commit/e1011f138737286ecf5123ff887a7a5800d129a2
|
||||
// TODO(kradalby): evaluate if we need better comparing of hostinfo
|
||||
// before we take the changes.
|
||||
if req.Hostinfo.NetInfo == nil && node.Hostinfo != nil {
|
||||
req.Hostinfo.NetInfo = node.Hostinfo.NetInfo
|
||||
}
|
||||
node.Hostinfo = req.Hostinfo
|
||||
|
||||
// If there is no changes and nothing to save,
|
||||
// return early.
|
||||
if peerChangeEmpty(peerChange) && !sendUpdate {
|
||||
// mapResponseEndpointUpdates.WithLabelValues("noop").Inc()
|
||||
return change.EmptySet, nil
|
||||
}
|
||||
|
||||
c := change.EmptySet
|
||||
|
||||
// Check if the Hostinfo of the node has changed.
|
||||
// If it has changed, check if there has been a change to
|
||||
// the routable IPs of the host and update them in
|
||||
// the database. Then send a Changed update
|
||||
// (containing the whole node object) to peers to inform about
|
||||
// the route change.
|
||||
// If the hostinfo has changed, but not the routes, just update
|
||||
// hostinfo and let the function continue.
|
||||
if routesChanged {
|
||||
// Auto approve any routes that have been defined in policy as
|
||||
// auto approved. Check if this actually changed the node.
|
||||
_ = s.AutoApproveRoutes(node)
|
||||
|
||||
// Update the routes of the given node in the route manager to
|
||||
// see if an update needs to be sent.
|
||||
c = s.SetNodeRoutes(node.ID, node.SubnetRoutes()...)
|
||||
}
|
||||
|
||||
// Check if there has been a change to Hostname and update them
|
||||
// in the database. Then send a Changed update
|
||||
// (containing the whole node object) to peers to inform about
|
||||
// the hostname change.
|
||||
node.ApplyHostnameFromHostInfo(req.Hostinfo)
|
||||
|
||||
_, policyChange, err := s.SaveNode(node)
|
||||
if err != nil {
|
||||
return change.EmptySet, err
|
||||
}
|
||||
|
||||
if policyChange.IsFull() {
|
||||
c = policyChange
|
||||
}
|
||||
|
||||
if c.Empty() {
|
||||
c = change.NodeAdded(node.ID)
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// hostInfoChanged reports if hostInfo has changed in two ways,
|
||||
// - first bool reports if an update needs to be sent to nodes
|
||||
// - second reports if there has been changes to routes
|
||||
// the caller can then use this info to save and update nodes
|
||||
// and routes as needed.
|
||||
func hostInfoChanged(old, new *tailcfg.Hostinfo) (bool, bool) {
|
||||
if old.Equal(new) {
|
||||
return false, false
|
||||
}
|
||||
|
||||
if old == nil && new != nil {
|
||||
return true, true
|
||||
}
|
||||
|
||||
// Routes
|
||||
oldRoutes := make([]netip.Prefix, 0)
|
||||
if old != nil {
|
||||
oldRoutes = old.RoutableIPs
|
||||
}
|
||||
newRoutes := new.RoutableIPs
|
||||
|
||||
tsaddr.SortPrefixes(oldRoutes)
|
||||
tsaddr.SortPrefixes(newRoutes)
|
||||
|
||||
if !xslices.Equal(oldRoutes, newRoutes) {
|
||||
return true, true
|
||||
}
|
||||
|
||||
// Services is mostly useful for discovery and not critical,
|
||||
// except for peerapi, which is how nodes talk to each other.
|
||||
// If peerapi was not part of the initial mapresponse, we
|
||||
// need to make sure its sent out later as it is needed for
|
||||
// Taildrop.
|
||||
// TODO(kradalby): Length comparison is a bit naive, replace.
|
||||
if len(old.Services) != len(new.Services) {
|
||||
return true, false
|
||||
}
|
||||
|
||||
return false, false
|
||||
}
|
||||
|
||||
func peerChangeEmpty(peerChange tailcfg.PeerChange) bool {
|
||||
return peerChange.Key == nil &&
|
||||
peerChange.DiscoKey == nil &&
|
||||
peerChange.Online == nil &&
|
||||
peerChange.Endpoints == nil &&
|
||||
peerChange.DERPRegion == 0 &&
|
||||
peerChange.LastSeen == nil &&
|
||||
peerChange.KeyExpiry == nil
|
||||
}
|
||||
|
||||
183
hscontrol/types/change/change.go
Normal file
183
hscontrol/types/change/change.go
Normal file
@@ -0,0 +1,183 @@
//go:generate go tool stringer -type=Change
package change

import (
"errors"

"github.com/juanfont/headscale/hscontrol/types"
)

type (
NodeID = types.NodeID
UserID = types.UserID
)

type Change int

const (
ChangeUnknown Change = 0

// Deprecated: Use specific change instead
// Full is a legacy change to ensure places where we
// have not yet determined the specific update, can send.
Full Change = 9

// Server changes.
Policy Change = 11
DERP Change = 12
ExtraRecords Change = 13

// Node changes.
NodeCameOnline Change = 21
NodeWentOffline Change = 22
NodeRemove Change = 23
NodeKeyExpiry Change = 24
NodeNewOrUpdate Change = 25

// User changes.
UserNewOrUpdate Change = 51
UserRemove Change = 52
)

// AlsoSelf reports whether this change should also be sent to the node itself.
func (c Change) AlsoSelf() bool {
switch c {
case NodeRemove, NodeKeyExpiry, NodeNewOrUpdate:
return true
}
return false
}

type ChangeSet struct {
Change Change

// SelfUpdateOnly indicates that this change should only be sent
// to the node itself, and not to other nodes.
// This is used for changes that are not relevant to other nodes.
// NodeID must be set if this is true.
SelfUpdateOnly bool

// NodeID if set, is the ID of the node that is being changed.
// It must be set if this is a node change.
NodeID types.NodeID

// UserID if set, is the ID of the user that is being changed.
// It must be set if this is a user change.
UserID types.UserID

// IsSubnetRouter indicates whether the node is a subnet router.
IsSubnetRouter bool
}

func (c *ChangeSet) Validate() error {
if c.Change >= NodeCameOnline || c.Change <= NodeNewOrUpdate {
if c.NodeID == 0 {
return errors.New("ChangeSet.NodeID must be set for node updates")
}
}

if c.Change >= UserNewOrUpdate || c.Change <= UserRemove {
if c.UserID == 0 {
return errors.New("ChangeSet.UserID must be set for user updates")
}
}

return nil
}

// Empty reports whether the ChangeSet is empty, meaning it does not
// represent any change.
func (c ChangeSet) Empty() bool {
return c.Change == ChangeUnknown && c.NodeID == 0 && c.UserID == 0
}

// IsFull reports whether the ChangeSet represents a full update.
func (c ChangeSet) IsFull() bool {
return c.Change == Full || c.Change == Policy
}

func (c ChangeSet) AlsoSelf() bool {
// If NodeID is 0, it means this ChangeSet is not related to a specific node,
// so we consider it as a change that should be sent to all nodes.
if c.NodeID == 0 {
return true
}
return c.Change.AlsoSelf() || c.SelfUpdateOnly
}

var (
EmptySet = ChangeSet{Change: ChangeUnknown}
FullSet = ChangeSet{Change: Full}
DERPSet = ChangeSet{Change: DERP}
PolicySet = ChangeSet{Change: Policy}
ExtraRecordsSet = ChangeSet{Change: ExtraRecords}
)

func FullSelf(id types.NodeID) ChangeSet {
return ChangeSet{
Change: Full,
SelfUpdateOnly: true,
NodeID: id,
}
}

func NodeAdded(id types.NodeID) ChangeSet {
return ChangeSet{
Change: NodeNewOrUpdate,
NodeID: id,
}
}

func NodeRemoved(id types.NodeID) ChangeSet {
return ChangeSet{
Change: NodeRemove,
NodeID: id,
}
}

func NodeOnline(id types.NodeID) ChangeSet {
return ChangeSet{
Change: NodeCameOnline,
NodeID: id,
}
}

func NodeOffline(id types.NodeID) ChangeSet {
return ChangeSet{
Change: NodeWentOffline,
NodeID: id,
}
}

func KeyExpiry(id types.NodeID) ChangeSet {
return ChangeSet{
Change: NodeKeyExpiry,
NodeID: id,
}
}

func UserAdded(id types.UserID) ChangeSet {
return ChangeSet{
Change: UserNewOrUpdate,
UserID: id,
}
}

func UserRemoved(id types.UserID) ChangeSet {
return ChangeSet{
Change: UserRemove,
UserID: id,
}
}

func PolicyChange() ChangeSet {
return ChangeSet{
Change: Policy,
}
}

func DERPChange() ChangeSet {
return ChangeSet{
Change: DERP,
}
}
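To make the intent of these constructors and predicates concrete, here is a hedged usage sketch written as an external test; it is illustrative only and not part of this change set:

package change_test

import (
	"testing"

	"github.com/juanfont/headscale/hscontrol/types/change"
)

func TestChangeSetHelpers(t *testing.T) {
	c := change.NodeOnline(1)
	if c.Empty() {
		t.Fatal("NodeOnline should not produce an empty ChangeSet")
	}
	if c.IsFull() {
		t.Fatal("NodeOnline is a targeted update, not a full one")
	}

	// Policy changes are treated as full updates by IsFull.
	if !change.PolicyChange().IsFull() {
		t.Fatal("PolicyChange should be considered a full update")
	}

	// Node-scoped updates such as NodeNewOrUpdate are also sent to the node itself.
	if !change.NodeAdded(2).AlsoSelf() {
		t.Fatal("NodeAdded should also be delivered to the node itself")
	}
}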
57
hscontrol/types/change/change_string.go
Normal file
57
hscontrol/types/change/change_string.go
Normal file
@@ -0,0 +1,57 @@
|
||||
// Code generated by "stringer -type=Change"; DO NOT EDIT.

package change

import "strconv"

func _() {
    // An "invalid array index" compiler error signifies that the constant values have changed.
    // Re-run the stringer command to generate them again.
    var x [1]struct{}
    _ = x[ChangeUnknown-0]
    _ = x[Full-9]
    _ = x[Policy-11]
    _ = x[DERP-12]
    _ = x[ExtraRecords-13]
    _ = x[NodeCameOnline-21]
    _ = x[NodeWentOffline-22]
    _ = x[NodeRemove-23]
    _ = x[NodeKeyExpiry-24]
    _ = x[NodeNewOrUpdate-25]
    _ = x[UserNewOrUpdate-51]
    _ = x[UserRemove-52]
}

const (
    _Change_name_0 = "ChangeUnknown"
    _Change_name_1 = "Full"
    _Change_name_2 = "PolicyDERPExtraRecords"
    _Change_name_3 = "NodeCameOnlineNodeWentOfflineNodeRemoveNodeKeyExpiryNodeNewOrUpdate"
    _Change_name_4 = "UserNewOrUpdateUserRemove"
)

var (
    _Change_index_2 = [...]uint8{0, 6, 10, 22}
    _Change_index_3 = [...]uint8{0, 14, 29, 39, 52, 67}
    _Change_index_4 = [...]uint8{0, 15, 25}
)

func (i Change) String() string {
    switch {
    case i == 0:
        return _Change_name_0
    case i == 9:
        return _Change_name_1
    case 11 <= i && i <= 13:
        i -= 11
        return _Change_name_2[_Change_index_2[i]:_Change_index_2[i+1]]
    case 21 <= i && i <= 25:
        i -= 21
        return _Change_name_3[_Change_index_3[i]:_Change_index_3[i+1]]
    case 51 <= i && i <= 52:
        i -= 51
        return _Change_name_4[_Change_index_4[i]:_Change_index_4[i+1]]
    default:
        return "Change(" + strconv.FormatInt(int64(i), 10) + ")"
    }
}
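An illustrative aside, not part of the changed files: stringer concatenates the names of each contiguous run of constants into one string and records the boundaries in an index table, so String() is a pair of array lookups and a slice expression rather than a map lookup. The standalone program below reproduces the lookup for the 21..25 run using the exact name and index values from the generated file above.

    package main

    import "fmt"

    func main() {
        // Copied from _Change_name_3 / _Change_index_3 above.
        const names = "NodeCameOnlineNodeWentOfflineNodeRemoveNodeKeyExpiryNodeNewOrUpdate"
        index := [...]uint8{0, 14, 29, 39, 52, 67}

        // Change values 21..25 map to offsets 0..4 within this run.
        for v := 21; v <= 25; v++ {
            i := v - 21
            fmt.Printf("Change(%d) = %q\n", v, names[index[i]:index[i+1]])
        }
    }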
@@ -1,16 +1,16 @@
//go:generate go run tailscale.com/cmd/viewer --type=User,Node,PreAuthKey

//go:generate go tool viewer --type=User,Node,PreAuthKey
package types

//go:generate go run tailscale.com/cmd/viewer --type=User,Node,PreAuthKey

import (
    "context"
    "errors"
    "fmt"
    "runtime"
    "time"

    "github.com/juanfont/headscale/hscontrol/util"
    "tailscale.com/tailcfg"
    "tailscale.com/util/ctxkey"
)

const (
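A note on the //go:generate change above (illustrative, not part of the diff): switching from `go run tailscale.com/cmd/viewer` to `go tool viewer` suggests the generator is now tracked as a module tool rather than resolved ad hoc at generate time. Assuming Go 1.24 tool directives are used (the go.mod change is not shown here), the wiring would look roughly like this hypothetical excerpt:

    // go.mod (hypothetical excerpt, not part of this diff)
    module github.com/juanfont/headscale

    go 1.24

    tool tailscale.com/cmd/viewer

With such a directive in place (typically added via `go get -tool`), `go generate ./...` invokes the viewer version pinned in go.mod instead of whatever `go run` would resolve at that moment.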
@@ -150,18 +150,6 @@ func UpdateExpire(nodeID NodeID, expiry time.Time) StateUpdate {
    }
}

var (
    NotifyOriginKey   = ctxkey.New("notify.origin", "")
    NotifyHostnameKey = ctxkey.New("notify.hostname", "")
)

func NotifyCtx(ctx context.Context, origin, hostname string) context.Context {
    ctx2, _ := context.WithTimeout(ctx, 3*time.Second)
    ctx2 = NotifyOriginKey.WithValue(ctx2, origin)
    ctx2 = NotifyHostnameKey.WithValue(ctx2, hostname)
    return ctx2
}

const RegistrationIDLength = 24

type RegistrationID string
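An illustrative aside on the NotifyCtx helper removed in the hunk above (not part of the changed files): the pattern is a short-lived context carrying typed metadata. ctxkey.New and Key.WithValue appear in the hunk itself; reading a value back with Key.Value is an assumption about the tailscale.com/util/ctxkey API, and the sketch also returns the CancelFunc instead of discarding it, which differs from the original code.

    package main

    import (
        "context"
        "fmt"
        "time"

        "tailscale.com/util/ctxkey"
    )

    // Same shape as the removed helper: a 3-second deadline plus two typed keys.
    var (
        notifyOriginKey   = ctxkey.New("notify.origin", "")
        notifyHostnameKey = ctxkey.New("notify.hostname", "")
    )

    func notifyCtx(ctx context.Context, origin, hostname string) (context.Context, context.CancelFunc) {
        ctx2, cancel := context.WithTimeout(ctx, 3*time.Second)
        ctx2 = notifyOriginKey.WithValue(ctx2, origin)
        ctx2 = notifyHostnameKey.WithValue(ctx2, hostname)
        return ctx2, cancel
    }

    func main() {
        ctx, cancel := notifyCtx(context.Background(), "poll", "node-1")
        defer cancel()

        // Value is assumed to be the read-side accessor of ctxkey.Key.
        fmt.Println(notifyOriginKey.Value(ctx), notifyHostnameKey.Value(ctx))
    }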
@@ -199,3 +187,20 @@ type RegisterNode struct {
    Node       Node
    Registered chan *Node
}

// DefaultBatcherWorkers returns the default number of batcher workers.
// Default to 3/4 of CPU cores, minimum 1, no maximum.
func DefaultBatcherWorkers() int {
    return DefaultBatcherWorkersFor(runtime.NumCPU())
}

// DefaultBatcherWorkersFor returns the default number of batcher workers for a given CPU count.
// Default to 3/4 of CPU cores, minimum 1, no maximum.
func DefaultBatcherWorkersFor(cpuCount int) int {
    defaultWorkers := (cpuCount * 3) / 4
    if defaultWorkers < 1 {
        defaultWorkers = 1
    }

    return defaultWorkers
}
36 hscontrol/types/common_test.go Normal file
@@ -0,0 +1,36 @@
package types

import (
    "testing"
)

func TestDefaultBatcherWorkersFor(t *testing.T) {
    tests := []struct {
        cpuCount int
        expected int
    }{
        {1, 1},   // (1*3)/4 = 0, should be minimum 1
        {2, 1},   // (2*3)/4 = 1
        {4, 3},   // (4*3)/4 = 3
        {8, 6},   // (8*3)/4 = 6
        {12, 9},  // (12*3)/4 = 9
        {16, 12}, // (16*3)/4 = 12
        {20, 15}, // (20*3)/4 = 15
        {24, 18}, // (24*3)/4 = 18
    }

    for _, test := range tests {
        result := DefaultBatcherWorkersFor(test.cpuCount)
        if result != test.expected {
            t.Errorf("DefaultBatcherWorkersFor(%d) = %d, expected %d", test.cpuCount, result, test.expected)
        }
    }
}

func TestDefaultBatcherWorkers(t *testing.T) {
    // Just verify it returns a valid value (>= 1)
    result := DefaultBatcherWorkers()
    if result < 1 {
        t.Errorf("DefaultBatcherWorkers() = %d, expected value >= 1", result)
    }
}
@@ -234,6 +234,7 @@ type Tuning struct {
    NotifierSendTimeout            time.Duration
    BatchChangeDelay               time.Duration
    NodeMapSessionBufferedChanSize int
    BatcherWorkers                 int
}

func validatePKCEMethod(method string) error {
@@ -305,7 +306,7 @@ func LoadConfig(path string, isFile bool) error {
    viper.SetDefault("grpc_listen_addr", ":50443")
    viper.SetDefault("grpc_allow_insecure", false)

    viper.SetDefault("cli.timeout", "30s")
    viper.SetDefault("cli.timeout", "5s")
    viper.SetDefault("cli.insecure", false)

    viper.SetDefault("database.postgres.ssl", false)
@@ -991,6 +992,12 @@ func LoadServerConfig() (*Config, error) {
            NodeMapSessionBufferedChanSize: viper.GetInt(
                "tuning.node_mapsession_buffered_chan_size",
            ),
            BatcherWorkers: func() int {
                if workers := viper.GetInt("tuning.batcher_workers"); workers > 0 {
                    return workers
                }
                return DefaultBatcherWorkers()
            }(),
        },
    }, nil
}
@@ -431,6 +431,11 @@ func (node *Node) SubnetRoutes() []netip.Prefix {
    return routes
}

// IsSubnetRouter reports if the node has any subnet routes.
func (node *Node) IsSubnetRouter() bool {
    return len(node.SubnetRoutes()) > 0
}

func (node *Node) String() string {
    return node.Hostname
}
@@ -669,6 +674,13 @@ func (v NodeView) SubnetRoutes() []netip.Prefix {
    return v.ж.SubnetRoutes()
}

func (v NodeView) IsSubnetRouter() bool {
    if !v.Valid() {
        return false
    }
    return v.ж.IsSubnetRouter()
}

func (v NodeView) AppendToIPSet(build *netipx.IPSetBuilder) {
    if !v.Valid() {
        return
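An illustrative aside on the NodeView methods above (not part of the changed files): the view type wraps a pointer in an unexported field and every accessor guards on Valid() before delegating, so a zero or invalid view returns a safe default instead of dereferencing nil. A minimal standalone sketch of that guard pattern, using hypothetical names:

    package main

    import "fmt"

    type node struct{ routes []string }

    func (n *node) isSubnetRouter() bool { return len(n.routes) > 0 }

    // nodeView mimics the generated view type: a read-only wrapper whose zero
    // value is "invalid" rather than a nil pointer waiting to be dereferenced.
    type nodeView struct{ p *node }

    func (v nodeView) Valid() bool { return v.p != nil }

    func (v nodeView) IsSubnetRouter() bool {
        if !v.Valid() {
            return false // zero view: report a safe default instead of panicking
        }
        return v.p.isSubnetRouter()
    }

    func main() {
        var zero nodeView
        populated := nodeView{p: &node{routes: []string{"10.0.0.0/24"}}}
        fmt.Println(zero.IsSubnetRouter(), populated.IsSubnetRouter()) // false true
    }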
@@ -1,17 +1,16 @@
package types

import (
    "fmt"
    "time"

    v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
    "github.com/rs/zerolog/log"
    "google.golang.org/protobuf/types/known/timestamppb"
)

type PAKError string

func (e PAKError) Error() string { return string(e) }
func (e PAKError) Unwrap() error { return fmt.Errorf("preauth key error: %w", e) }

// PreAuthKey describes a pre-authorization key usable in a particular user.
type PreAuthKey struct {
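An illustrative aside on PAKError (not part of the changed files): because it is a plain string type implementing error, callers can detect it with errors.As even after it has been wrapped further up the stack. A self-contained sketch with a hypothetical caller:

    package main

    import (
        "errors"
        "fmt"
    )

    type PAKError string

    func (e PAKError) Error() string { return string(e) }

    // validate stands in for a caller that wraps the validation error.
    func validate() error {
        return fmt.Errorf("registering node: %w", PAKError("authkey expired"))
    }

    func main() {
        var pakErr PAKError
        if err := validate(); errors.As(err, &pakErr) {
            fmt.Println("pre-auth key problem:", pakErr)
        }
    }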
@@ -60,6 +59,21 @@ func (pak *PreAuthKey) Validate() error {
    if pak == nil {
        return PAKError("invalid authkey")
    }

    log.Debug().
        Str("key", pak.Key).
        Bool("hasExpiration", pak.Expiration != nil).
        Time("expiration", func() time.Time {
            if pak.Expiration != nil {
                return *pak.Expiration
            }
            return time.Time{}
        }()).
        Time("now", time.Now()).
        Bool("reusable", pak.Reusable).
        Bool("used", pak.Used).
        Msg("PreAuthKey.Validate: checking key")

    if pak.Expiration != nil && pak.Expiration.Before(time.Now()) {
        return PAKError("authkey expired")
    }
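An illustrative aside on the debug logging added above (not part of the changed files): the immediately-invoked closure is just a way to keep the nil check for an optional *time.Time inline with the zerolog call chain. Reduced to its essentials:

    package main

    import (
        "time"

        "github.com/rs/zerolog/log"
    )

    func main() {
        var expiration *time.Time // nil means the key never expires

        log.Debug().
            Bool("hasExpiration", expiration != nil).
            Time("expiration", func() time.Time {
                if expiration != nil {
                    return *expiration
                }
                return time.Time{} // zero time stands in for "no expiry"
            }()).
            Msg("checking key")
    }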
@@ -5,6 +5,8 @@ import (
    "testing"

    "github.com/stretchr/testify/assert"
    "tailscale.com/util/dnsname"
    "tailscale.com/util/must"
)

func TestCheckForFQDNRules(t *testing.T) {
@@ -102,59 +104,16 @@ func TestConvertWithFQDNRules(t *testing.T) {
func TestMagicDNSRootDomains100(t *testing.T) {
    domains := GenerateIPv4DNSRootDomain(netip.MustParsePrefix("100.64.0.0/10"))

    found := false
    for _, domain := range domains {
        if domain == "64.100.in-addr.arpa." {
            found = true

            break
        }
    }
    assert.True(t, found)

    found = false
    for _, domain := range domains {
        if domain == "100.100.in-addr.arpa." {
            found = true

            break
        }
    }
    assert.True(t, found)

    found = false
    for _, domain := range domains {
        if domain == "127.100.in-addr.arpa." {
            found = true

            break
        }
    }
    assert.True(t, found)
    assert.Contains(t, domains, must.Get(dnsname.ToFQDN("64.100.in-addr.arpa.")))
    assert.Contains(t, domains, must.Get(dnsname.ToFQDN("100.100.in-addr.arpa.")))
    assert.Contains(t, domains, must.Get(dnsname.ToFQDN("127.100.in-addr.arpa.")))
}

func TestMagicDNSRootDomains172(t *testing.T) {
    domains := GenerateIPv4DNSRootDomain(netip.MustParsePrefix("172.16.0.0/16"))

    found := false
    for _, domain := range domains {
        if domain == "0.16.172.in-addr.arpa." {
            found = true

            break
        }
    }
    assert.True(t, found)

    found = false
    for _, domain := range domains {
        if domain == "255.16.172.in-addr.arpa." {
            found = true

            break
        }
    }
    assert.True(t, found)
    assert.Contains(t, domains, must.Get(dnsname.ToFQDN("0.16.172.in-addr.arpa.")))
    assert.Contains(t, domains, must.Get(dnsname.ToFQDN("255.16.172.in-addr.arpa.")))
}

// Happens when netmask is a multiple of 4 bits (sounds likely).
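An illustrative aside on the test simplification above (not part of the changed files): testify's assert.Contains works directly on slices, so each find-loop plus assert.True collapses to a single assertion whose failure output also lists the slice contents. A minimal sketch, placed in its own _test.go file:

    package example

    import (
        "testing"

        "github.com/stretchr/testify/assert"
    )

    // Equivalent assertions: the loop-and-flag version versus assert.Contains.
    func TestContainsEquivalence(t *testing.T) {
        domains := []string{"64.100.in-addr.arpa.", "100.100.in-addr.arpa."}

        found := false
        for _, d := range domains {
            if d == "64.100.in-addr.arpa." {
                found = true
                break
            }
        }
        assert.True(t, found)

        // Same check, one line, and a more informative failure message.
        assert.Contains(t, domains, "64.100.in-addr.arpa.")
    }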
@@ -143,7 +143,7 @@ func ParseTraceroute(output string) (Traceroute, error) {

    // Parse latencies
    for j := 5; j <= 7; j++ {
        if matches[j] != "" {
        if j < len(matches) && matches[j] != "" {
            ms, err := strconv.ParseFloat(matches[j], 64)
            if err != nil {
                return Traceroute{}, fmt.Errorf("parsing latency: %w", err)
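An illustrative aside on the bounds check added above (not part of the changed files, and the regexp below is a stand-in rather than the one ParseTraceroute uses): FindStringSubmatch returns nil when nothing matches, so indexing matches[j] without first checking j < len(matches) can panic; with the guard, a non-matching line is simply skipped.

    package main

    import (
        "fmt"
        "regexp"
    )

    func main() {
        re := regexp.MustCompile(`(\d+\.\d+) ms(?: (\d+\.\d+) ms)?(?: (\d+\.\d+) ms)?`)

        for _, line := range []string{"12.3 ms 14.1 ms", "* * *"} {
            matches := re.FindStringSubmatch(line) // nil for the "* * *" line

            for j := 1; j <= 3; j++ {
                if j < len(matches) && matches[j] != "" {
                    fmt.Println("latency:", matches[j])
                }
            }
        }
    }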
@@ -88,7 +88,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
        var err error
        listNodes, err = headscale.ListNodes()
        assert.NoError(ct, err)
        assert.Equal(ct, nodeCountBeforeLogout, len(listNodes), "Node count should match before logout count")
        assert.Len(ct, listNodes, nodeCountBeforeLogout, "Node count should match before logout count")
    }, 20*time.Second, 1*time.Second)

    for _, node := range listNodes {
@@ -123,7 +123,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
        var err error
        listNodes, err = headscale.ListNodes()
        assert.NoError(ct, err)
        assert.Equal(ct, nodeCountBeforeLogout, len(listNodes), "Node count should match after HTTPS reconnection")
        assert.Len(ct, listNodes, nodeCountBeforeLogout, "Node count should match after HTTPS reconnection")
    }, 30*time.Second, 2*time.Second)

    for _, node := range listNodes {
@@ -161,7 +161,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
    }

    listNodes, err = headscale.ListNodes()
    require.Equal(t, nodeCountBeforeLogout, len(listNodes))
    require.Len(t, listNodes, nodeCountBeforeLogout)
    for _, node := range listNodes {
        assertLastSeenSet(t, node)
    }
@@ -355,7 +355,7 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) {
        "--user",
        strconv.FormatUint(userMap[userName].GetId(), 10),
        "expire",
        key.Key,
        key.GetKey(),
    })
    assertNoErr(t, err)
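An illustrative aside on the assert.Equal-to-assert.Len swaps above (not part of the changed files): both check the same condition, but assert.Len reports the slice's actual contents and length on failure, whereas assert.Equal on len(...) only reports two integers. A minimal sketch, placed in its own _test.go file:

    package example

    import (
        "testing"

        "github.com/stretchr/testify/assert"
    )

    func TestLenVersusEqual(t *testing.T) {
        listNodes := []string{"node-1", "node-2"}
        want := 2

        // On failure this only prints the two numbers being compared.
        assert.Equal(t, want, len(listNodes))

        // On failure this also prints the slice itself, which is far easier to debug.
        assert.Len(t, listNodes, want)
    }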
@@ -4,7 +4,6 @@ import (
    "cmp"
    "encoding/json"
    "fmt"
    "slices"
    "strconv"
    "strings"
    "testing"
@@ -19,6 +18,7 @@ import (
    "github.com/juanfont/headscale/integration/tsic"
    "github.com/stretchr/testify/assert"
    "github.com/stretchr/testify/require"
    "golang.org/x/exp/slices"
    "tailscale.com/tailcfg"
)

@@ -95,7 +95,7 @@ func TestUserCommand(t *testing.T) {
            "users",
            "rename",
            "--output=json",
            fmt.Sprintf("--user=%d", listUsers[1].GetId()),
            fmt.Sprintf("--identifier=%d", listUsers[1].GetId()),
            "--new-name=newname",
        },
    )
@@ -161,7 +161,7 @@ func TestUserCommand(t *testing.T) {
            "list",
            "--output",
            "json",
            "--user=1",
            "--identifier=1",
        },
        &listByID,
    )
@@ -187,7 +187,7 @@ func TestUserCommand(t *testing.T) {
            "destroy",
            "--force",
            // Delete "user1"
            "--user=1",
            "--identifier=1",
        },
    )
    assert.NoError(t, err)
@@ -354,10 +354,7 @@ func TestPreAuthKeyCommand(t *testing.T) {
            continue
        }

        // Sort tags for consistent comparison
        tags := listedPreAuthKeys[index].GetAclTags()
        slices.Sort(tags)
        assert.Equal(t, []string{"tag:test1", "tag:test2"}, tags)
        assert.Equal(t, []string{"tag:test1", "tag:test2"}, listedPreAuthKeys[index].GetAclTags())
    }

    // Test key expiry
@@ -872,7 +869,7 @@ func TestNodeTagCommand(t *testing.T) {
            "headscale",
            "nodes",
            "tag",
            "--node", "1",
            "-i", "1",
            "-t", "tag:test",
            "--output", "json",
        },
@@ -887,7 +884,7 @@ func TestNodeTagCommand(t *testing.T) {
            "headscale",
            "nodes",
            "tag",
            "--node", "2",
            "-i", "2",
            "-t", "wrong-tag",
            "--output", "json",
        },
@@ -1262,7 +1259,7 @@ func TestNodeCommand(t *testing.T) {
            "headscale",
            "nodes",
            "delete",
            "--node",
            "--identifier",
            // Delete the last added machine
            "4",
            "--output",
@@ -1388,7 +1385,7 @@ func TestNodeExpireCommand(t *testing.T) {
            "headscale",
            "nodes",
            "expire",
            "--node",
            "--identifier",
            strconv.FormatUint(listAll[idx].GetId(), 10),
        },
    )
@@ -1514,7 +1511,7 @@ func TestNodeRenameCommand(t *testing.T) {
            "headscale",
            "nodes",
            "rename",
            "--node",
            "--identifier",
            strconv.FormatUint(listAll[idx].GetId(), 10),
            fmt.Sprintf("newnode-%d", idx+1),
        },
@@ -1552,7 +1549,7 @@ func TestNodeRenameCommand(t *testing.T) {
            "headscale",
            "nodes",
            "rename",
            "--node",
            "--identifier",
            strconv.FormatUint(listAll[4].GetId(), 10),
            strings.Repeat("t", 64),
        },
@@ -1652,7 +1649,7 @@ func TestNodeMoveCommand(t *testing.T) {
            "headscale",
            "nodes",
            "move",
            "--node",
            "--identifier",
            strconv.FormatUint(node.GetId(), 10),
            "--user",
            strconv.FormatUint(userMap["new-user"].GetId(), 10),
@@ -1690,7 +1687,7 @@ func TestNodeMoveCommand(t *testing.T) {
            "headscale",
            "nodes",
            "move",
            "--node",
            "--identifier",
            nodeID,
            "--user",
            "999",
@@ -1711,7 +1708,7 @@ func TestNodeMoveCommand(t *testing.T) {
            "headscale",
            "nodes",
            "move",
            "--node",
            "--identifier",
            nodeID,
            "--user",
            strconv.FormatUint(userMap["old-user"].GetId(), 10),
@@ -1730,7 +1727,7 @@ func TestNodeMoveCommand(t *testing.T) {
            "headscale",
            "nodes",
            "move",
            "--node",
            "--identifier",
            nodeID,
            "--user",
            strconv.FormatUint(userMap["old-user"].GetId(), 10),
@@ -1,423 +0,0 @@
package integration

import (
    "encoding/json"
    "fmt"
    "testing"

    v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
    "github.com/juanfont/headscale/integration/hsic"
    "github.com/juanfont/headscale/integration/tsic"
    "github.com/stretchr/testify/assert"
)

func TestDebugCommand(t *testing.T) {
    IntegrationSkip(t)

    spec := ScenarioSpec{
        Users: []string{"debug-user"},
    }

    scenario, err := NewScenario(spec)
    assertNoErr(t, err)
    defer scenario.ShutdownAssertNoPanics(t)

    err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clidebug"))
    assertNoErr(t, err)

    headscale, err := scenario.Headscale()
    assertNoErr(t, err)

    t.Run("test_debug_help", func(t *testing.T) {
        // Test debug command help
        result, err := headscale.Execute(
            []string{
                "headscale",
                "debug",
                "--help",
            },
        )
        assertNoErr(t, err)

        // Help text should contain expected information
        assert.Contains(t, result, "debug", "help should mention debug command")
        assert.Contains(t, result, "debugging and testing", "help should contain command description")
        assert.Contains(t, result, "create-node", "help should mention create-node subcommand")
    })

    t.Run("test_debug_create_node_help", func(t *testing.T) {
        // Test debug create-node command help
        result, err := headscale.Execute(
            []string{
                "headscale",
                "debug",
                "create-node",
                "--help",
            },
        )
        assertNoErr(t, err)

        // Help text should contain expected information
        assert.Contains(t, result, "create-node", "help should mention create-node command")
        assert.Contains(t, result, "name", "help should mention name flag")
        assert.Contains(t, result, "user", "help should mention user flag")
        assert.Contains(t, result, "key", "help should mention key flag")
        assert.Contains(t, result, "route", "help should mention route flag")
    })
}

func TestDebugCreateNodeCommand(t *testing.T) {
    IntegrationSkip(t)

    spec := ScenarioSpec{
        Users: []string{"debug-create-user"},
    }

    scenario, err := NewScenario(spec)
    assertNoErr(t, err)
    defer scenario.ShutdownAssertNoPanics(t)

    err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clidebugcreate"))
    assertNoErr(t, err)

    headscale, err := scenario.Headscale()
    assertNoErr(t, err)

    // Create a user first
    user := spec.Users[0]
    _, err = headscale.Execute(
        []string{
            "headscale",
            "users",
            "create",
            user,
        },
    )
    assertNoErr(t, err)

    t.Run("test_debug_create_node_basic", func(t *testing.T) {
        // Test basic debug create-node functionality
        nodeName := "debug-test-node"
        // Generate a mock registration key (64 hex chars with nodekey prefix)
        registrationKey := "nodekey:1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef"

        result, err := headscale.Execute(
            []string{
                "headscale",
                "debug",
                "create-node",
                "--name", nodeName,
                "--user", user,
                "--key", registrationKey,
            },
        )
        assertNoErr(t, err)

        // Should output node creation confirmation
        assert.Contains(t, result, "Node created", "should confirm node creation")
        assert.Contains(t, result, nodeName, "should mention the created node name")
    })

    t.Run("test_debug_create_node_with_routes", func(t *testing.T) {
        // Test debug create-node with advertised routes
        nodeName := "debug-route-node"
        registrationKey := "nodekey:abcdef1234567890abcdef1234567890abcdef1234567890abcdef1234567890"

        result, err := headscale.Execute(
            []string{
                "headscale",
                "debug",
                "create-node",
                "--name", nodeName,
                "--user", user,
                "--key", registrationKey,
                "--route", "10.0.0.0/24",
                "--route", "192.168.1.0/24",
            },
        )
        assertNoErr(t, err)

        // Should output node creation confirmation
        assert.Contains(t, result, "Node created", "should confirm node creation")
        assert.Contains(t, result, nodeName, "should mention the created node name")
    })

    t.Run("test_debug_create_node_json_output", func(t *testing.T) {
        // Test debug create-node with JSON output
        nodeName := "debug-json-node"
        registrationKey := "nodekey:fedcba0987654321fedcba0987654321fedcba0987654321fedcba0987654321"

        result, err := headscale.Execute(
            []string{
                "headscale",
                "debug",
                "create-node",
                "--name", nodeName,
                "--user", user,
                "--key", registrationKey,
                "--output", "json",
            },
        )
        assertNoErr(t, err)

        // Should produce valid JSON output
        var node v1.Node
        err = json.Unmarshal([]byte(result), &node)
        assert.NoError(t, err, "debug create-node should produce valid JSON output")
        assert.Equal(t, nodeName, node.GetName(), "created node should have correct name")
    })
}

func TestDebugCreateNodeCommandValidation(t *testing.T) {
    IntegrationSkip(t)

    spec := ScenarioSpec{
        Users: []string{"debug-validation-user"},
    }

    scenario, err := NewScenario(spec)
    assertNoErr(t, err)
    defer scenario.ShutdownAssertNoPanics(t)

    err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clidebugvalidation"))
    assertNoErr(t, err)

    headscale, err := scenario.Headscale()
    assertNoErr(t, err)

    // Create a user first
    user := spec.Users[0]
    _, err = headscale.Execute(
        []string{
            "headscale",
            "users",
            "create",
            user,
        },
    )
    assertNoErr(t, err)

    t.Run("test_debug_create_node_missing_name", func(t *testing.T) {
        // Test debug create-node with missing name flag
        registrationKey := "nodekey:1111111111111111111111111111111111111111111111111111111111111111"

        _, err := headscale.Execute(
            []string{
                "headscale",
                "debug",
                "create-node",
                "--user", user,
                "--key", registrationKey,
            },
        )
        // Should fail for missing required name flag
        assert.Error(t, err, "should fail for missing name flag")
    })

    t.Run("test_debug_create_node_missing_user", func(t *testing.T) {
        // Test debug create-node with missing user flag
        registrationKey := "nodekey:2222222222222222222222222222222222222222222222222222222222222222"

        _, err := headscale.Execute(
            []string{
                "headscale",
                "debug",
                "create-node",
                "--name", "test-node",
                "--key", registrationKey,
            },
        )
        // Should fail for missing required user flag
        assert.Error(t, err, "should fail for missing user flag")
    })

    t.Run("test_debug_create_node_missing_key", func(t *testing.T) {
        // Test debug create-node with missing key flag
        _, err := headscale.Execute(
            []string{
                "headscale",
                "debug",
                "create-node",
                "--name", "test-node",
                "--user", user,
            },
        )
        // Should fail for missing required key flag
        assert.Error(t, err, "should fail for missing key flag")
    })

    t.Run("test_debug_create_node_invalid_key", func(t *testing.T) {
        // Test debug create-node with invalid registration key format
        _, err := headscale.Execute(
            []string{
                "headscale",
                "debug",
                "create-node",
                "--name", "test-node",
                "--user", user,
                "--key", "invalid-key-format",
            },
        )
        // Should fail for invalid key format
        assert.Error(t, err, "should fail for invalid key format")
    })

    t.Run("test_debug_create_node_nonexistent_user", func(t *testing.T) {
        // Test debug create-node with non-existent user
        registrationKey := "nodekey:3333333333333333333333333333333333333333333333333333333333333333"

        _, err := headscale.Execute(
            []string{
                "headscale",
                "debug",
                "create-node",
                "--name", "test-node",
                "--user", "nonexistent-user",
                "--key", registrationKey,
            },
        )
        // Should fail for non-existent user
        assert.Error(t, err, "should fail for non-existent user")
    })

    t.Run("test_debug_create_node_duplicate_name", func(t *testing.T) {
        // Test debug create-node with duplicate node name
        nodeName := "duplicate-node"
        registrationKey1 := "nodekey:4444444444444444444444444444444444444444444444444444444444444444"
        registrationKey2 := "nodekey:5555555555555555555555555555555555555555555555555555555555555555"

        // Create first node
        _, err := headscale.Execute(
            []string{
                "headscale",
                "debug",
                "create-node",
                "--name", nodeName,
                "--user", user,
                "--key", registrationKey1,
            },
        )
        assertNoErr(t, err)

        // Try to create second node with same name
        _, err = headscale.Execute(
            []string{
                "headscale",
                "debug",
                "create-node",
                "--name", nodeName,
                "--user", user,
                "--key", registrationKey2,
            },
        )
        // Should fail for duplicate node name
        assert.Error(t, err, "should fail for duplicate node name")
    })
}

func TestDebugCreateNodeCommandEdgeCases(t *testing.T) {
    IntegrationSkip(t)

    spec := ScenarioSpec{
        Users: []string{"debug-edge-user"},
    }

    scenario, err := NewScenario(spec)
    assertNoErr(t, err)
    defer scenario.ShutdownAssertNoPanics(t)

    err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clidebugedge"))
    assertNoErr(t, err)

    headscale, err := scenario.Headscale()
    assertNoErr(t, err)

    // Create a user first
    user := spec.Users[0]
    _, err = headscale.Execute(
        []string{
            "headscale",
            "users",
            "create",
            user,
        },
    )
    assertNoErr(t, err)

    t.Run("test_debug_create_node_invalid_route", func(t *testing.T) {
        // Test debug create-node with invalid route format
        nodeName := "invalid-route-node"
        registrationKey := "nodekey:6666666666666666666666666666666666666666666666666666666666666666"

        _, err := headscale.Execute(
            []string{
                "headscale",
                "debug",
                "create-node",
                "--name", nodeName,
                "--user", user,
                "--key", registrationKey,
                "--route", "invalid-cidr",
            },
        )
        // Should handle invalid route format gracefully
        assert.Error(t, err, "should fail for invalid route format")
    })

    t.Run("test_debug_create_node_empty_route", func(t *testing.T) {
        // Test debug create-node with empty route
        nodeName := "empty-route-node"
        registrationKey := "nodekey:7777777777777777777777777777777777777777777777777777777777777777"

        result, err := headscale.Execute(
            []string{
                "headscale",
                "debug",
                "create-node",
                "--name", nodeName,
                "--user", user,
                "--key", registrationKey,
                "--route", "",
            },
        )
        // Should handle empty route (either succeed or fail gracefully)
        if err == nil {
            assert.Contains(t, result, "Node created", "should confirm node creation if empty route is allowed")
        } else {
            assert.Error(t, err, "should fail gracefully for empty route")
        }
    })

    t.Run("test_debug_create_node_very_long_name", func(t *testing.T) {
        // Test debug create-node with very long node name
        longName := fmt.Sprintf("very-long-node-name-%s", "x")
        for i := 0; i < 10; i++ {
            longName += "-very-long-segment"
        }
        registrationKey := "nodekey:8888888888888888888888888888888888888888888888888888888888888888"

        _, _ = headscale.Execute(
            []string{
                "headscale",
                "debug",
                "create-node",
                "--name", longName,
                "--user", user,
                "--key", registrationKey,
            },
        )
        // Should handle very long names (either succeed or fail gracefully)
        assert.NotPanics(t, func() {
            headscale.Execute(
                []string{
                    "headscale",
                    "debug",
                    "create-node",
                    "--name", longName,
                    "--user", user,
                    "--key", registrationKey,
                },
            )
        }, "should handle very long node names gracefully")
    })
}
Some files were not shown because too many files have changed in this diff.