Compare commits

..

1 Commits

Author SHA1 Message Date
github-actions[bot]
f2a1a6f5ad flake.lock: Update
Flake lock file updates:

• Updated input 'nixpkgs':
    'github:NixOS/nixpkgs/ac055f3' (2026-02-15)
  → 'github:NixOS/nixpkgs/d1c15b7' (2026-02-16)
2026-02-22 00:31:06 +00:00
55 changed files with 665 additions and 4087 deletions

View File

@@ -254,12 +254,6 @@ jobs:
- TestSSHIsBlockedInACL
- TestSSHUserOnlyIsolation
- TestSSHAutogroupSelf
- TestSSHOneUserToOneCheckModeCLI
- TestSSHOneUserToOneCheckModeOIDC
- TestSSHCheckModeUnapprovedTimeout
- TestSSHCheckModeCheckPeriodCLI
- TestSSHCheckModeAutoApprove
- TestSSHCheckModeNegativeCLI
- TestTagsAuthKeyWithTagRequestDifferentTag
- TestTagsAuthKeyWithTagNoAdvertiseFlag
- TestTagsAuthKeyWithTagCannotAddViaCLI

View File

@@ -48,7 +48,7 @@ repos:
# golangci-lint for Go code quality
- id: golangci-lint
name: golangci-lint
entry: nix develop --command -- golangci-lint run --new-from-rev=HEAD~1 --timeout=5m --fix
entry: golangci-lint run --new-from-rev=HEAD~1 --timeout=5m --fix
language: system
types: [go]
pass_filenames: false

View File

@@ -11,20 +11,6 @@ to understand how the packet filter should be generated. We discovered a few dif
overall our implementation was very close.
[#3036](https://github.com/juanfont/headscale/pull/3036)
### SSH check action
SSH rules with `"action": "check"` are now supported. When a client initiates a SSH connection to a node
with a `check` action policy, the user is prompted to authenticate via OIDC or CLI approval before access
is granted.
A new `headscale auth` CLI command group supports the approval flow:
- `headscale auth approve --auth-id <id>` approves a pending authentication request (SSH check or web auth)
- `headscale auth reject --auth-id <id>` rejects a pending authentication request
- `headscale auth register --auth-id <id> --user <user>` registers a node (replaces deprecated `headscale nodes register`)
[#1850](https://github.com/juanfont/headscale/pull/1850)
### BREAKING
- **ACL Policy**: Wildcard (`*`) in ACL sources and destinations now resolves to Tailscale's CGNAT range (`100.64.0.0/10`) and ULA range (`fd7a:115c:a1e0::/48`) instead of all IPs (`0.0.0.0/0` and `::/0`) [#3036](https://github.com/juanfont/headscale/pull/3036)
@@ -40,8 +26,6 @@ A new `headscale auth` CLI command group supports the approval flow:
- **ACL Policy**: The `proto:icmp` protocol name now only includes ICMPv4 (protocol 1), matching Tailscale behavior [#3036](https://github.com/juanfont/headscale/pull/3036)
- Previously, `proto:icmp` included both ICMPv4 and ICMPv6
- Use `proto:ipv6-icmp` or protocol number `58` explicitly for ICMPv6
- **CLI**: `headscale nodes register` is deprecated in favour of `headscale auth register --auth-id <id> --user <user>` [#1850](https://github.com/juanfont/headscale/pull/1850)
- The old command continues to work but will be removed in a future release
### Changes
@@ -51,11 +35,6 @@ A new `headscale auth` CLI command group supports the approval flow:
- **ACL Policy**: Merge filter rules with identical SrcIPs and IPProto matching Tailscale behavior - multiple ACL rules with the same source now produce a single FilterRule with combined DstPorts [#3036](https://github.com/juanfont/headscale/pull/3036)
- Remove deprecated `--namespace` flag from `nodes list`, `nodes register`, and `debug create-node` commands (use `--user` instead) [#3093](https://github.com/juanfont/headscale/pull/3093)
- Remove deprecated `namespace`/`ns` command aliases for `users` and `machine`/`machines` aliases for `nodes` [#3093](https://github.com/juanfont/headscale/pull/3093)
- Add SSH `check` action support with OIDC and CLI-based approval flows [#1850](https://github.com/juanfont/headscale/pull/1850)
- Add `headscale auth register`, `headscale auth approve`, and `headscale auth reject` CLI commands [#1850](https://github.com/juanfont/headscale/pull/1850)
- Deprecate `headscale nodes register --key` in favour of `headscale auth register --auth-id` [#1850](https://github.com/juanfont/headscale/pull/1850)
- Generalise auth templates into reusable `AuthSuccess` and `AuthWeb` components [#1850](https://github.com/juanfont/headscale/pull/1850)
- Unify auth pipeline with `AuthVerdict` type, supporting registration, reauthentication, and SSH checks [#1850](https://github.com/juanfont/headscale/pull/1850)
## 0.28.0 (2026-02-04)

View File

@@ -1,93 +0,0 @@
package cli
import (
"context"
"fmt"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/spf13/cobra"
)
func init() {
rootCmd.AddCommand(authCmd)
authRegisterCmd.Flags().StringP("user", "u", "", "User")
authRegisterCmd.Flags().String("auth-id", "", "Auth ID")
mustMarkRequired(authRegisterCmd, "user", "auth-id")
authCmd.AddCommand(authRegisterCmd)
authApproveCmd.Flags().String("auth-id", "", "Auth ID")
mustMarkRequired(authApproveCmd, "auth-id")
authCmd.AddCommand(authApproveCmd)
authRejectCmd.Flags().String("auth-id", "", "Auth ID")
mustMarkRequired(authRejectCmd, "auth-id")
authCmd.AddCommand(authRejectCmd)
}
var authCmd = &cobra.Command{
Use: "auth",
Short: "Manage node authentication and approval",
}
var authRegisterCmd = &cobra.Command{
Use: "register",
Short: "Register a node to your network",
RunE: grpcRunE(func(ctx context.Context, client v1.HeadscaleServiceClient, cmd *cobra.Command, args []string) error {
user, _ := cmd.Flags().GetString("user")
authID, _ := cmd.Flags().GetString("auth-id")
request := &v1.AuthRegisterRequest{
AuthId: authID,
User: user,
}
response, err := client.AuthRegister(ctx, request)
if err != nil {
return fmt.Errorf("registering node: %w", err)
}
return printOutput(
cmd,
response.GetNode(),
fmt.Sprintf("Node %s registered", response.GetNode().GetGivenName()))
}),
}
var authApproveCmd = &cobra.Command{
Use: "approve",
Short: "Approve a pending authentication request",
RunE: grpcRunE(func(ctx context.Context, client v1.HeadscaleServiceClient, cmd *cobra.Command, args []string) error {
authID, _ := cmd.Flags().GetString("auth-id")
request := &v1.AuthApproveRequest{
AuthId: authID,
}
response, err := client.AuthApprove(ctx, request)
if err != nil {
return fmt.Errorf("approving auth request: %w", err)
}
return printOutput(cmd, response, "Auth request approved")
}),
}
var authRejectCmd = &cobra.Command{
Use: "reject",
Short: "Reject a pending authentication request",
RunE: grpcRunE(func(ctx context.Context, client v1.HeadscaleServiceClient, cmd *cobra.Command, args []string) error {
authID, _ := cmd.Flags().GetString("auth-id")
request := &v1.AuthRejectRequest{
AuthId: authID,
}
response, err := client.AuthReject(ctx, request)
if err != nil {
return fmt.Errorf("rejecting auth request: %w", err)
}
return printOutput(cmd, response, "Auth request rejected")
}),
}

View File

@@ -37,7 +37,7 @@ var createNodeCmd = &cobra.Command{
name, _ := cmd.Flags().GetString("name")
registrationID, _ := cmd.Flags().GetString("key")
_, err := types.AuthIDFromString(registrationID)
_, err := types.RegistrationIDFromString(registrationID)
if err != nil {
return fmt.Errorf("parsing machine key: %w", err)
}

View File

@@ -64,9 +64,8 @@ var nodeCmd = &cobra.Command{
}
var registerNodeCmd = &cobra.Command{
Use: "register",
Short: "Registers a node to your network",
Deprecated: "use 'headscale auth register --auth-id <id> --user <user>' instead",
Use: "register",
Short: "Registers a node to your network",
RunE: grpcRunE(func(ctx context.Context, client v1.HeadscaleServiceClient, cmd *cobra.Command, args []string) error {
user, _ := cmd.Flags().GetString("user")
registrationID, _ := cmd.Flags().GetString("key")

6
flake.lock generated
View File

@@ -20,11 +20,11 @@
},
"nixpkgs": {
"locked": {
"lastModified": 1771177547,
"narHash": "sha256-trTtk3WTOHz7hSw89xIIvahkgoFJYQ0G43IlqprFoMA=",
"lastModified": 1771207753,
"narHash": "sha256-b9uG8yN50DRQ6A7JdZBfzq718ryYrlmGgqkRm9OOwCE=",
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "ac055f38c798b0d87695240c7b761b82fc7e5bc2",
"rev": "d1c15b7d5806069da59e819999d70e1cec0760bf",
"type": "github"
},
"original": {

View File

@@ -27,7 +27,7 @@
let
pkgs = nixpkgs.legacyPackages.${prev.stdenv.hostPlatform.system};
buildGo = pkgs.buildGo126Module;
vendorHash = "sha256-oUN53ELb3+xn4yA7lEfXyT2c7NxbQC6RtbkGVq6+RLU=";
vendorHash = "sha256-9BvphYDAxzwooyVokI3l+q1wRuRsWn/qM+NpWUgqJH0=";
in
{
headscale = buildGo {
@@ -135,6 +135,11 @@
};
};
# The package uses buildGo125Module, not the convention.
# goreleaser = prev.goreleaser.override {
# buildGoModule = buildGo;
# };
gotestsum = prev.gotestsum.override {
buildGoModule = buildGo;
};
@@ -147,9 +152,9 @@
buildGoModule = buildGo;
};
gopls = prev.gopls.override {
buildGoLatestModule = buildGo;
};
# gopls = prev.gopls.override {
# buildGoModule = buildGo;
# };
};
}
// flake-utils.lib.eachDefaultSystem

View File

@@ -1,351 +0,0 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.36.11
// protoc (unknown)
// source: headscale/v1/auth.proto
package v1
import (
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
unsafe "unsafe"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
type AuthRegisterRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
User string `protobuf:"bytes,1,opt,name=user,proto3" json:"user,omitempty"`
AuthId string `protobuf:"bytes,2,opt,name=auth_id,json=authId,proto3" json:"auth_id,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *AuthRegisterRequest) Reset() {
*x = AuthRegisterRequest{}
mi := &file_headscale_v1_auth_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *AuthRegisterRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*AuthRegisterRequest) ProtoMessage() {}
func (x *AuthRegisterRequest) ProtoReflect() protoreflect.Message {
mi := &file_headscale_v1_auth_proto_msgTypes[0]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use AuthRegisterRequest.ProtoReflect.Descriptor instead.
func (*AuthRegisterRequest) Descriptor() ([]byte, []int) {
return file_headscale_v1_auth_proto_rawDescGZIP(), []int{0}
}
func (x *AuthRegisterRequest) GetUser() string {
if x != nil {
return x.User
}
return ""
}
func (x *AuthRegisterRequest) GetAuthId() string {
if x != nil {
return x.AuthId
}
return ""
}
type AuthRegisterResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
Node *Node `protobuf:"bytes,1,opt,name=node,proto3" json:"node,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *AuthRegisterResponse) Reset() {
*x = AuthRegisterResponse{}
mi := &file_headscale_v1_auth_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *AuthRegisterResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*AuthRegisterResponse) ProtoMessage() {}
func (x *AuthRegisterResponse) ProtoReflect() protoreflect.Message {
mi := &file_headscale_v1_auth_proto_msgTypes[1]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use AuthRegisterResponse.ProtoReflect.Descriptor instead.
func (*AuthRegisterResponse) Descriptor() ([]byte, []int) {
return file_headscale_v1_auth_proto_rawDescGZIP(), []int{1}
}
func (x *AuthRegisterResponse) GetNode() *Node {
if x != nil {
return x.Node
}
return nil
}
type AuthApproveRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
AuthId string `protobuf:"bytes,1,opt,name=auth_id,json=authId,proto3" json:"auth_id,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *AuthApproveRequest) Reset() {
*x = AuthApproveRequest{}
mi := &file_headscale_v1_auth_proto_msgTypes[2]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *AuthApproveRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*AuthApproveRequest) ProtoMessage() {}
func (x *AuthApproveRequest) ProtoReflect() protoreflect.Message {
mi := &file_headscale_v1_auth_proto_msgTypes[2]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use AuthApproveRequest.ProtoReflect.Descriptor instead.
func (*AuthApproveRequest) Descriptor() ([]byte, []int) {
return file_headscale_v1_auth_proto_rawDescGZIP(), []int{2}
}
func (x *AuthApproveRequest) GetAuthId() string {
if x != nil {
return x.AuthId
}
return ""
}
type AuthApproveResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *AuthApproveResponse) Reset() {
*x = AuthApproveResponse{}
mi := &file_headscale_v1_auth_proto_msgTypes[3]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *AuthApproveResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*AuthApproveResponse) ProtoMessage() {}
func (x *AuthApproveResponse) ProtoReflect() protoreflect.Message {
mi := &file_headscale_v1_auth_proto_msgTypes[3]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use AuthApproveResponse.ProtoReflect.Descriptor instead.
func (*AuthApproveResponse) Descriptor() ([]byte, []int) {
return file_headscale_v1_auth_proto_rawDescGZIP(), []int{3}
}
type AuthRejectRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
AuthId string `protobuf:"bytes,1,opt,name=auth_id,json=authId,proto3" json:"auth_id,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *AuthRejectRequest) Reset() {
*x = AuthRejectRequest{}
mi := &file_headscale_v1_auth_proto_msgTypes[4]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *AuthRejectRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*AuthRejectRequest) ProtoMessage() {}
func (x *AuthRejectRequest) ProtoReflect() protoreflect.Message {
mi := &file_headscale_v1_auth_proto_msgTypes[4]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use AuthRejectRequest.ProtoReflect.Descriptor instead.
func (*AuthRejectRequest) Descriptor() ([]byte, []int) {
return file_headscale_v1_auth_proto_rawDescGZIP(), []int{4}
}
func (x *AuthRejectRequest) GetAuthId() string {
if x != nil {
return x.AuthId
}
return ""
}
type AuthRejectResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *AuthRejectResponse) Reset() {
*x = AuthRejectResponse{}
mi := &file_headscale_v1_auth_proto_msgTypes[5]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *AuthRejectResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*AuthRejectResponse) ProtoMessage() {}
func (x *AuthRejectResponse) ProtoReflect() protoreflect.Message {
mi := &file_headscale_v1_auth_proto_msgTypes[5]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use AuthRejectResponse.ProtoReflect.Descriptor instead.
func (*AuthRejectResponse) Descriptor() ([]byte, []int) {
return file_headscale_v1_auth_proto_rawDescGZIP(), []int{5}
}
var File_headscale_v1_auth_proto protoreflect.FileDescriptor
const file_headscale_v1_auth_proto_rawDesc = "" +
"\n" +
"\x17headscale/v1/auth.proto\x12\fheadscale.v1\x1a\x17headscale/v1/node.proto\"B\n" +
"\x13AuthRegisterRequest\x12\x12\n" +
"\x04user\x18\x01 \x01(\tR\x04user\x12\x17\n" +
"\aauth_id\x18\x02 \x01(\tR\x06authId\">\n" +
"\x14AuthRegisterResponse\x12&\n" +
"\x04node\x18\x01 \x01(\v2\x12.headscale.v1.NodeR\x04node\"-\n" +
"\x12AuthApproveRequest\x12\x17\n" +
"\aauth_id\x18\x01 \x01(\tR\x06authId\"\x15\n" +
"\x13AuthApproveResponse\",\n" +
"\x11AuthRejectRequest\x12\x17\n" +
"\aauth_id\x18\x01 \x01(\tR\x06authId\"\x14\n" +
"\x12AuthRejectResponseB)Z'github.com/juanfont/headscale/gen/go/v1b\x06proto3"
var (
file_headscale_v1_auth_proto_rawDescOnce sync.Once
file_headscale_v1_auth_proto_rawDescData []byte
)
func file_headscale_v1_auth_proto_rawDescGZIP() []byte {
file_headscale_v1_auth_proto_rawDescOnce.Do(func() {
file_headscale_v1_auth_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_headscale_v1_auth_proto_rawDesc), len(file_headscale_v1_auth_proto_rawDesc)))
})
return file_headscale_v1_auth_proto_rawDescData
}
var file_headscale_v1_auth_proto_msgTypes = make([]protoimpl.MessageInfo, 6)
var file_headscale_v1_auth_proto_goTypes = []any{
(*AuthRegisterRequest)(nil), // 0: headscale.v1.AuthRegisterRequest
(*AuthRegisterResponse)(nil), // 1: headscale.v1.AuthRegisterResponse
(*AuthApproveRequest)(nil), // 2: headscale.v1.AuthApproveRequest
(*AuthApproveResponse)(nil), // 3: headscale.v1.AuthApproveResponse
(*AuthRejectRequest)(nil), // 4: headscale.v1.AuthRejectRequest
(*AuthRejectResponse)(nil), // 5: headscale.v1.AuthRejectResponse
(*Node)(nil), // 6: headscale.v1.Node
}
var file_headscale_v1_auth_proto_depIdxs = []int32{
6, // 0: headscale.v1.AuthRegisterResponse.node:type_name -> headscale.v1.Node
1, // [1:1] is the sub-list for method output_type
1, // [1:1] is the sub-list for method input_type
1, // [1:1] is the sub-list for extension type_name
1, // [1:1] is the sub-list for extension extendee
0, // [0:1] is the sub-list for field type_name
}
func init() { file_headscale_v1_auth_proto_init() }
func file_headscale_v1_auth_proto_init() {
if File_headscale_v1_auth_proto != nil {
return
}
file_headscale_v1_node_proto_init()
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: unsafe.Slice(unsafe.StringData(file_headscale_v1_auth_proto_rawDesc), len(file_headscale_v1_auth_proto_rawDesc)),
NumEnums: 0,
NumMessages: 6,
NumExtensions: 0,
NumServices: 0,
},
GoTypes: file_headscale_v1_auth_proto_goTypes,
DependencyIndexes: file_headscale_v1_auth_proto_depIdxs,
MessageInfos: file_headscale_v1_auth_proto_msgTypes,
}.Build()
File_headscale_v1_auth_proto = out.File
file_headscale_v1_auth_proto_goTypes = nil
file_headscale_v1_auth_proto_depIdxs = nil
}

View File

@@ -106,10 +106,10 @@ var File_headscale_v1_headscale_proto protoreflect.FileDescriptor
const file_headscale_v1_headscale_proto_rawDesc = "" +
"\n" +
"\x1cheadscale/v1/headscale.proto\x12\fheadscale.v1\x1a\x1cgoogle/api/annotations.proto\x1a\x17headscale/v1/user.proto\x1a\x1dheadscale/v1/preauthkey.proto\x1a\x17headscale/v1/node.proto\x1a\x19headscale/v1/apikey.proto\x1a\x17headscale/v1/auth.proto\x1a\x19headscale/v1/policy.proto\"\x0f\n" +
"\x1cheadscale/v1/headscale.proto\x12\fheadscale.v1\x1a\x1cgoogle/api/annotations.proto\x1a\x17headscale/v1/user.proto\x1a\x1dheadscale/v1/preauthkey.proto\x1a\x17headscale/v1/node.proto\x1a\x19headscale/v1/apikey.proto\x1a\x19headscale/v1/policy.proto\"\x0f\n" +
"\rHealthRequest\"E\n" +
"\x0eHealthResponse\x123\n" +
"\x15database_connectivity\x18\x01 \x01(\bR\x14databaseConnectivity2\xeb\x19\n" +
"\x15database_connectivity\x18\x01 \x01(\bR\x14databaseConnectivity2\x8c\x17\n" +
"\x10HeadscaleService\x12h\n" +
"\n" +
"CreateUser\x12\x1f.headscale.v1.CreateUserRequest\x1a .headscale.v1.CreateUserResponse\"\x17\x82\xd3\xe4\x93\x02\x11:\x01*\"\f/api/v1/user\x12\x80\x01\n" +
@@ -134,11 +134,7 @@ const file_headscale_v1_headscale_proto_rawDesc = "" +
"\n" +
"RenameNode\x12\x1f.headscale.v1.RenameNodeRequest\x1a .headscale.v1.RenameNodeResponse\"0\x82\xd3\xe4\x93\x02*\"(/api/v1/node/{node_id}/rename/{new_name}\x12b\n" +
"\tListNodes\x12\x1e.headscale.v1.ListNodesRequest\x1a\x1f.headscale.v1.ListNodesResponse\"\x14\x82\xd3\xe4\x93\x02\x0e\x12\f/api/v1/node\x12\x80\x01\n" +
"\x0fBackfillNodeIPs\x12$.headscale.v1.BackfillNodeIPsRequest\x1a%.headscale.v1.BackfillNodeIPsResponse\" \x82\xd3\xe4\x93\x02\x1a\"\x18/api/v1/node/backfillips\x12w\n" +
"\fAuthRegister\x12!.headscale.v1.AuthRegisterRequest\x1a\".headscale.v1.AuthRegisterResponse\" \x82\xd3\xe4\x93\x02\x1a:\x01*\"\x15/api/v1/auth/register\x12s\n" +
"\vAuthApprove\x12 .headscale.v1.AuthApproveRequest\x1a!.headscale.v1.AuthApproveResponse\"\x1f\x82\xd3\xe4\x93\x02\x19:\x01*\"\x14/api/v1/auth/approve\x12o\n" +
"\n" +
"AuthReject\x12\x1f.headscale.v1.AuthRejectRequest\x1a .headscale.v1.AuthRejectResponse\"\x1e\x82\xd3\xe4\x93\x02\x18:\x01*\"\x13/api/v1/auth/reject\x12p\n" +
"\x0fBackfillNodeIPs\x12$.headscale.v1.BackfillNodeIPsRequest\x1a%.headscale.v1.BackfillNodeIPsResponse\" \x82\xd3\xe4\x93\x02\x1a\"\x18/api/v1/node/backfillips\x12p\n" +
"\fCreateApiKey\x12!.headscale.v1.CreateApiKeyRequest\x1a\".headscale.v1.CreateApiKeyResponse\"\x19\x82\xd3\xe4\x93\x02\x13:\x01*\"\x0e/api/v1/apikey\x12w\n" +
"\fExpireApiKey\x12!.headscale.v1.ExpireApiKeyRequest\x1a\".headscale.v1.ExpireApiKeyResponse\" \x82\xd3\xe4\x93\x02\x1a:\x01*\"\x15/api/v1/apikey/expire\x12j\n" +
"\vListApiKeys\x12 .headscale.v1.ListApiKeysRequest\x1a!.headscale.v1.ListApiKeysResponse\"\x16\x82\xd3\xe4\x93\x02\x10\x12\x0e/api/v1/apikey\x12v\n" +
@@ -181,42 +177,36 @@ var file_headscale_v1_headscale_proto_goTypes = []any{
(*RenameNodeRequest)(nil), // 17: headscale.v1.RenameNodeRequest
(*ListNodesRequest)(nil), // 18: headscale.v1.ListNodesRequest
(*BackfillNodeIPsRequest)(nil), // 19: headscale.v1.BackfillNodeIPsRequest
(*AuthRegisterRequest)(nil), // 20: headscale.v1.AuthRegisterRequest
(*AuthApproveRequest)(nil), // 21: headscale.v1.AuthApproveRequest
(*AuthRejectRequest)(nil), // 22: headscale.v1.AuthRejectRequest
(*CreateApiKeyRequest)(nil), // 23: headscale.v1.CreateApiKeyRequest
(*ExpireApiKeyRequest)(nil), // 24: headscale.v1.ExpireApiKeyRequest
(*ListApiKeysRequest)(nil), // 25: headscale.v1.ListApiKeysRequest
(*DeleteApiKeyRequest)(nil), // 26: headscale.v1.DeleteApiKeyRequest
(*GetPolicyRequest)(nil), // 27: headscale.v1.GetPolicyRequest
(*SetPolicyRequest)(nil), // 28: headscale.v1.SetPolicyRequest
(*CreateUserResponse)(nil), // 29: headscale.v1.CreateUserResponse
(*RenameUserResponse)(nil), // 30: headscale.v1.RenameUserResponse
(*DeleteUserResponse)(nil), // 31: headscale.v1.DeleteUserResponse
(*ListUsersResponse)(nil), // 32: headscale.v1.ListUsersResponse
(*CreatePreAuthKeyResponse)(nil), // 33: headscale.v1.CreatePreAuthKeyResponse
(*ExpirePreAuthKeyResponse)(nil), // 34: headscale.v1.ExpirePreAuthKeyResponse
(*DeletePreAuthKeyResponse)(nil), // 35: headscale.v1.DeletePreAuthKeyResponse
(*ListPreAuthKeysResponse)(nil), // 36: headscale.v1.ListPreAuthKeysResponse
(*DebugCreateNodeResponse)(nil), // 37: headscale.v1.DebugCreateNodeResponse
(*GetNodeResponse)(nil), // 38: headscale.v1.GetNodeResponse
(*SetTagsResponse)(nil), // 39: headscale.v1.SetTagsResponse
(*SetApprovedRoutesResponse)(nil), // 40: headscale.v1.SetApprovedRoutesResponse
(*RegisterNodeResponse)(nil), // 41: headscale.v1.RegisterNodeResponse
(*DeleteNodeResponse)(nil), // 42: headscale.v1.DeleteNodeResponse
(*ExpireNodeResponse)(nil), // 43: headscale.v1.ExpireNodeResponse
(*RenameNodeResponse)(nil), // 44: headscale.v1.RenameNodeResponse
(*ListNodesResponse)(nil), // 45: headscale.v1.ListNodesResponse
(*BackfillNodeIPsResponse)(nil), // 46: headscale.v1.BackfillNodeIPsResponse
(*AuthRegisterResponse)(nil), // 47: headscale.v1.AuthRegisterResponse
(*AuthApproveResponse)(nil), // 48: headscale.v1.AuthApproveResponse
(*AuthRejectResponse)(nil), // 49: headscale.v1.AuthRejectResponse
(*CreateApiKeyResponse)(nil), // 50: headscale.v1.CreateApiKeyResponse
(*ExpireApiKeyResponse)(nil), // 51: headscale.v1.ExpireApiKeyResponse
(*ListApiKeysResponse)(nil), // 52: headscale.v1.ListApiKeysResponse
(*DeleteApiKeyResponse)(nil), // 53: headscale.v1.DeleteApiKeyResponse
(*GetPolicyResponse)(nil), // 54: headscale.v1.GetPolicyResponse
(*SetPolicyResponse)(nil), // 55: headscale.v1.SetPolicyResponse
(*CreateApiKeyRequest)(nil), // 20: headscale.v1.CreateApiKeyRequest
(*ExpireApiKeyRequest)(nil), // 21: headscale.v1.ExpireApiKeyRequest
(*ListApiKeysRequest)(nil), // 22: headscale.v1.ListApiKeysRequest
(*DeleteApiKeyRequest)(nil), // 23: headscale.v1.DeleteApiKeyRequest
(*GetPolicyRequest)(nil), // 24: headscale.v1.GetPolicyRequest
(*SetPolicyRequest)(nil), // 25: headscale.v1.SetPolicyRequest
(*CreateUserResponse)(nil), // 26: headscale.v1.CreateUserResponse
(*RenameUserResponse)(nil), // 27: headscale.v1.RenameUserResponse
(*DeleteUserResponse)(nil), // 28: headscale.v1.DeleteUserResponse
(*ListUsersResponse)(nil), // 29: headscale.v1.ListUsersResponse
(*CreatePreAuthKeyResponse)(nil), // 30: headscale.v1.CreatePreAuthKeyResponse
(*ExpirePreAuthKeyResponse)(nil), // 31: headscale.v1.ExpirePreAuthKeyResponse
(*DeletePreAuthKeyResponse)(nil), // 32: headscale.v1.DeletePreAuthKeyResponse
(*ListPreAuthKeysResponse)(nil), // 33: headscale.v1.ListPreAuthKeysResponse
(*DebugCreateNodeResponse)(nil), // 34: headscale.v1.DebugCreateNodeResponse
(*GetNodeResponse)(nil), // 35: headscale.v1.GetNodeResponse
(*SetTagsResponse)(nil), // 36: headscale.v1.SetTagsResponse
(*SetApprovedRoutesResponse)(nil), // 37: headscale.v1.SetApprovedRoutesResponse
(*RegisterNodeResponse)(nil), // 38: headscale.v1.RegisterNodeResponse
(*DeleteNodeResponse)(nil), // 39: headscale.v1.DeleteNodeResponse
(*ExpireNodeResponse)(nil), // 40: headscale.v1.ExpireNodeResponse
(*RenameNodeResponse)(nil), // 41: headscale.v1.RenameNodeResponse
(*ListNodesResponse)(nil), // 42: headscale.v1.ListNodesResponse
(*BackfillNodeIPsResponse)(nil), // 43: headscale.v1.BackfillNodeIPsResponse
(*CreateApiKeyResponse)(nil), // 44: headscale.v1.CreateApiKeyResponse
(*ExpireApiKeyResponse)(nil), // 45: headscale.v1.ExpireApiKeyResponse
(*ListApiKeysResponse)(nil), // 46: headscale.v1.ListApiKeysResponse
(*DeleteApiKeyResponse)(nil), // 47: headscale.v1.DeleteApiKeyResponse
(*GetPolicyResponse)(nil), // 48: headscale.v1.GetPolicyResponse
(*SetPolicyResponse)(nil), // 49: headscale.v1.SetPolicyResponse
}
var file_headscale_v1_headscale_proto_depIdxs = []int32{
2, // 0: headscale.v1.HeadscaleService.CreateUser:input_type -> headscale.v1.CreateUserRequest
@@ -237,46 +227,40 @@ var file_headscale_v1_headscale_proto_depIdxs = []int32{
17, // 15: headscale.v1.HeadscaleService.RenameNode:input_type -> headscale.v1.RenameNodeRequest
18, // 16: headscale.v1.HeadscaleService.ListNodes:input_type -> headscale.v1.ListNodesRequest
19, // 17: headscale.v1.HeadscaleService.BackfillNodeIPs:input_type -> headscale.v1.BackfillNodeIPsRequest
20, // 18: headscale.v1.HeadscaleService.AuthRegister:input_type -> headscale.v1.AuthRegisterRequest
21, // 19: headscale.v1.HeadscaleService.AuthApprove:input_type -> headscale.v1.AuthApproveRequest
22, // 20: headscale.v1.HeadscaleService.AuthReject:input_type -> headscale.v1.AuthRejectRequest
23, // 21: headscale.v1.HeadscaleService.CreateApiKey:input_type -> headscale.v1.CreateApiKeyRequest
24, // 22: headscale.v1.HeadscaleService.ExpireApiKey:input_type -> headscale.v1.ExpireApiKeyRequest
25, // 23: headscale.v1.HeadscaleService.ListApiKeys:input_type -> headscale.v1.ListApiKeysRequest
26, // 24: headscale.v1.HeadscaleService.DeleteApiKey:input_type -> headscale.v1.DeleteApiKeyRequest
27, // 25: headscale.v1.HeadscaleService.GetPolicy:input_type -> headscale.v1.GetPolicyRequest
28, // 26: headscale.v1.HeadscaleService.SetPolicy:input_type -> headscale.v1.SetPolicyRequest
0, // 27: headscale.v1.HeadscaleService.Health:input_type -> headscale.v1.HealthRequest
29, // 28: headscale.v1.HeadscaleService.CreateUser:output_type -> headscale.v1.CreateUserResponse
30, // 29: headscale.v1.HeadscaleService.RenameUser:output_type -> headscale.v1.RenameUserResponse
31, // 30: headscale.v1.HeadscaleService.DeleteUser:output_type -> headscale.v1.DeleteUserResponse
32, // 31: headscale.v1.HeadscaleService.ListUsers:output_type -> headscale.v1.ListUsersResponse
33, // 32: headscale.v1.HeadscaleService.CreatePreAuthKey:output_type -> headscale.v1.CreatePreAuthKeyResponse
34, // 33: headscale.v1.HeadscaleService.ExpirePreAuthKey:output_type -> headscale.v1.ExpirePreAuthKeyResponse
35, // 34: headscale.v1.HeadscaleService.DeletePreAuthKey:output_type -> headscale.v1.DeletePreAuthKeyResponse
36, // 35: headscale.v1.HeadscaleService.ListPreAuthKeys:output_type -> headscale.v1.ListPreAuthKeysResponse
37, // 36: headscale.v1.HeadscaleService.DebugCreateNode:output_type -> headscale.v1.DebugCreateNodeResponse
38, // 37: headscale.v1.HeadscaleService.GetNode:output_type -> headscale.v1.GetNodeResponse
39, // 38: headscale.v1.HeadscaleService.SetTags:output_type -> headscale.v1.SetTagsResponse
40, // 39: headscale.v1.HeadscaleService.SetApprovedRoutes:output_type -> headscale.v1.SetApprovedRoutesResponse
41, // 40: headscale.v1.HeadscaleService.RegisterNode:output_type -> headscale.v1.RegisterNodeResponse
42, // 41: headscale.v1.HeadscaleService.DeleteNode:output_type -> headscale.v1.DeleteNodeResponse
43, // 42: headscale.v1.HeadscaleService.ExpireNode:output_type -> headscale.v1.ExpireNodeResponse
44, // 43: headscale.v1.HeadscaleService.RenameNode:output_type -> headscale.v1.RenameNodeResponse
45, // 44: headscale.v1.HeadscaleService.ListNodes:output_type -> headscale.v1.ListNodesResponse
46, // 45: headscale.v1.HeadscaleService.BackfillNodeIPs:output_type -> headscale.v1.BackfillNodeIPsResponse
47, // 46: headscale.v1.HeadscaleService.AuthRegister:output_type -> headscale.v1.AuthRegisterResponse
48, // 47: headscale.v1.HeadscaleService.AuthApprove:output_type -> headscale.v1.AuthApproveResponse
49, // 48: headscale.v1.HeadscaleService.AuthReject:output_type -> headscale.v1.AuthRejectResponse
50, // 49: headscale.v1.HeadscaleService.CreateApiKey:output_type -> headscale.v1.CreateApiKeyResponse
51, // 50: headscale.v1.HeadscaleService.ExpireApiKey:output_type -> headscale.v1.ExpireApiKeyResponse
52, // 51: headscale.v1.HeadscaleService.ListApiKeys:output_type -> headscale.v1.ListApiKeysResponse
53, // 52: headscale.v1.HeadscaleService.DeleteApiKey:output_type -> headscale.v1.DeleteApiKeyResponse
54, // 53: headscale.v1.HeadscaleService.GetPolicy:output_type -> headscale.v1.GetPolicyResponse
55, // 54: headscale.v1.HeadscaleService.SetPolicy:output_type -> headscale.v1.SetPolicyResponse
1, // 55: headscale.v1.HeadscaleService.Health:output_type -> headscale.v1.HealthResponse
28, // [28:56] is the sub-list for method output_type
0, // [0:28] is the sub-list for method input_type
20, // 18: headscale.v1.HeadscaleService.CreateApiKey:input_type -> headscale.v1.CreateApiKeyRequest
21, // 19: headscale.v1.HeadscaleService.ExpireApiKey:input_type -> headscale.v1.ExpireApiKeyRequest
22, // 20: headscale.v1.HeadscaleService.ListApiKeys:input_type -> headscale.v1.ListApiKeysRequest
23, // 21: headscale.v1.HeadscaleService.DeleteApiKey:input_type -> headscale.v1.DeleteApiKeyRequest
24, // 22: headscale.v1.HeadscaleService.GetPolicy:input_type -> headscale.v1.GetPolicyRequest
25, // 23: headscale.v1.HeadscaleService.SetPolicy:input_type -> headscale.v1.SetPolicyRequest
0, // 24: headscale.v1.HeadscaleService.Health:input_type -> headscale.v1.HealthRequest
26, // 25: headscale.v1.HeadscaleService.CreateUser:output_type -> headscale.v1.CreateUserResponse
27, // 26: headscale.v1.HeadscaleService.RenameUser:output_type -> headscale.v1.RenameUserResponse
28, // 27: headscale.v1.HeadscaleService.DeleteUser:output_type -> headscale.v1.DeleteUserResponse
29, // 28: headscale.v1.HeadscaleService.ListUsers:output_type -> headscale.v1.ListUsersResponse
30, // 29: headscale.v1.HeadscaleService.CreatePreAuthKey:output_type -> headscale.v1.CreatePreAuthKeyResponse
31, // 30: headscale.v1.HeadscaleService.ExpirePreAuthKey:output_type -> headscale.v1.ExpirePreAuthKeyResponse
32, // 31: headscale.v1.HeadscaleService.DeletePreAuthKey:output_type -> headscale.v1.DeletePreAuthKeyResponse
33, // 32: headscale.v1.HeadscaleService.ListPreAuthKeys:output_type -> headscale.v1.ListPreAuthKeysResponse
34, // 33: headscale.v1.HeadscaleService.DebugCreateNode:output_type -> headscale.v1.DebugCreateNodeResponse
35, // 34: headscale.v1.HeadscaleService.GetNode:output_type -> headscale.v1.GetNodeResponse
36, // 35: headscale.v1.HeadscaleService.SetTags:output_type -> headscale.v1.SetTagsResponse
37, // 36: headscale.v1.HeadscaleService.SetApprovedRoutes:output_type -> headscale.v1.SetApprovedRoutesResponse
38, // 37: headscale.v1.HeadscaleService.RegisterNode:output_type -> headscale.v1.RegisterNodeResponse
39, // 38: headscale.v1.HeadscaleService.DeleteNode:output_type -> headscale.v1.DeleteNodeResponse
40, // 39: headscale.v1.HeadscaleService.ExpireNode:output_type -> headscale.v1.ExpireNodeResponse
41, // 40: headscale.v1.HeadscaleService.RenameNode:output_type -> headscale.v1.RenameNodeResponse
42, // 41: headscale.v1.HeadscaleService.ListNodes:output_type -> headscale.v1.ListNodesResponse
43, // 42: headscale.v1.HeadscaleService.BackfillNodeIPs:output_type -> headscale.v1.BackfillNodeIPsResponse
44, // 43: headscale.v1.HeadscaleService.CreateApiKey:output_type -> headscale.v1.CreateApiKeyResponse
45, // 44: headscale.v1.HeadscaleService.ExpireApiKey:output_type -> headscale.v1.ExpireApiKeyResponse
46, // 45: headscale.v1.HeadscaleService.ListApiKeys:output_type -> headscale.v1.ListApiKeysResponse
47, // 46: headscale.v1.HeadscaleService.DeleteApiKey:output_type -> headscale.v1.DeleteApiKeyResponse
48, // 47: headscale.v1.HeadscaleService.GetPolicy:output_type -> headscale.v1.GetPolicyResponse
49, // 48: headscale.v1.HeadscaleService.SetPolicy:output_type -> headscale.v1.SetPolicyResponse
1, // 49: headscale.v1.HeadscaleService.Health:output_type -> headscale.v1.HealthResponse
25, // [25:50] is the sub-list for method output_type
0, // [0:25] is the sub-list for method input_type
0, // [0:0] is the sub-list for extension type_name
0, // [0:0] is the sub-list for extension extendee
0, // [0:0] is the sub-list for field type_name
@@ -291,7 +275,6 @@ func file_headscale_v1_headscale_proto_init() {
file_headscale_v1_preauthkey_proto_init()
file_headscale_v1_node_proto_init()
file_headscale_v1_apikey_proto_init()
file_headscale_v1_auth_proto_init()
file_headscale_v1_policy_proto_init()
type x struct{}
out := protoimpl.TypeBuilder{

View File

@@ -709,87 +709,6 @@ func local_request_HeadscaleService_BackfillNodeIPs_0(ctx context.Context, marsh
return msg, metadata, err
}
func request_HeadscaleService_AuthRegister_0(ctx context.Context, marshaler runtime.Marshaler, client HeadscaleServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) {
var (
protoReq AuthRegisterRequest
metadata runtime.ServerMetadata
)
if err := marshaler.NewDecoder(req.Body).Decode(&protoReq); err != nil && !errors.Is(err, io.EOF) {
return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err)
}
if req.Body != nil {
_, _ = io.Copy(io.Discard, req.Body)
}
msg, err := client.AuthRegister(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD))
return msg, metadata, err
}
func local_request_HeadscaleService_AuthRegister_0(ctx context.Context, marshaler runtime.Marshaler, server HeadscaleServiceServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) {
var (
protoReq AuthRegisterRequest
metadata runtime.ServerMetadata
)
if err := marshaler.NewDecoder(req.Body).Decode(&protoReq); err != nil && !errors.Is(err, io.EOF) {
return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err)
}
msg, err := server.AuthRegister(ctx, &protoReq)
return msg, metadata, err
}
func request_HeadscaleService_AuthApprove_0(ctx context.Context, marshaler runtime.Marshaler, client HeadscaleServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) {
var (
protoReq AuthApproveRequest
metadata runtime.ServerMetadata
)
if err := marshaler.NewDecoder(req.Body).Decode(&protoReq); err != nil && !errors.Is(err, io.EOF) {
return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err)
}
if req.Body != nil {
_, _ = io.Copy(io.Discard, req.Body)
}
msg, err := client.AuthApprove(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD))
return msg, metadata, err
}
func local_request_HeadscaleService_AuthApprove_0(ctx context.Context, marshaler runtime.Marshaler, server HeadscaleServiceServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) {
var (
protoReq AuthApproveRequest
metadata runtime.ServerMetadata
)
if err := marshaler.NewDecoder(req.Body).Decode(&protoReq); err != nil && !errors.Is(err, io.EOF) {
return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err)
}
msg, err := server.AuthApprove(ctx, &protoReq)
return msg, metadata, err
}
func request_HeadscaleService_AuthReject_0(ctx context.Context, marshaler runtime.Marshaler, client HeadscaleServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) {
var (
protoReq AuthRejectRequest
metadata runtime.ServerMetadata
)
if err := marshaler.NewDecoder(req.Body).Decode(&protoReq); err != nil && !errors.Is(err, io.EOF) {
return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err)
}
if req.Body != nil {
_, _ = io.Copy(io.Discard, req.Body)
}
msg, err := client.AuthReject(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD))
return msg, metadata, err
}
func local_request_HeadscaleService_AuthReject_0(ctx context.Context, marshaler runtime.Marshaler, server HeadscaleServiceServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) {
var (
protoReq AuthRejectRequest
metadata runtime.ServerMetadata
)
if err := marshaler.NewDecoder(req.Body).Decode(&protoReq); err != nil && !errors.Is(err, io.EOF) {
return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err)
}
msg, err := server.AuthReject(ctx, &protoReq)
return msg, metadata, err
}
func request_HeadscaleService_CreateApiKey_0(ctx context.Context, marshaler runtime.Marshaler, client HeadscaleServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) {
var (
protoReq CreateApiKeyRequest
@@ -1353,66 +1272,6 @@ func RegisterHeadscaleServiceHandlerServer(ctx context.Context, mux *runtime.Ser
}
forward_HeadscaleService_BackfillNodeIPs_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...)
})
mux.Handle(http.MethodPost, pattern_HeadscaleService_AuthRegister_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) {
ctx, cancel := context.WithCancel(req.Context())
defer cancel()
var stream runtime.ServerTransportStream
ctx = grpc.NewContextWithServerTransportStream(ctx, &stream)
inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req)
annotatedContext, err := runtime.AnnotateIncomingContext(ctx, mux, req, "/headscale.v1.HeadscaleService/AuthRegister", runtime.WithHTTPPathPattern("/api/v1/auth/register"))
if err != nil {
runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err)
return
}
resp, md, err := local_request_HeadscaleService_AuthRegister_0(annotatedContext, inboundMarshaler, server, req, pathParams)
md.HeaderMD, md.TrailerMD = metadata.Join(md.HeaderMD, stream.Header()), metadata.Join(md.TrailerMD, stream.Trailer())
annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md)
if err != nil {
runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err)
return
}
forward_HeadscaleService_AuthRegister_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...)
})
mux.Handle(http.MethodPost, pattern_HeadscaleService_AuthApprove_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) {
ctx, cancel := context.WithCancel(req.Context())
defer cancel()
var stream runtime.ServerTransportStream
ctx = grpc.NewContextWithServerTransportStream(ctx, &stream)
inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req)
annotatedContext, err := runtime.AnnotateIncomingContext(ctx, mux, req, "/headscale.v1.HeadscaleService/AuthApprove", runtime.WithHTTPPathPattern("/api/v1/auth/approve"))
if err != nil {
runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err)
return
}
resp, md, err := local_request_HeadscaleService_AuthApprove_0(annotatedContext, inboundMarshaler, server, req, pathParams)
md.HeaderMD, md.TrailerMD = metadata.Join(md.HeaderMD, stream.Header()), metadata.Join(md.TrailerMD, stream.Trailer())
annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md)
if err != nil {
runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err)
return
}
forward_HeadscaleService_AuthApprove_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...)
})
mux.Handle(http.MethodPost, pattern_HeadscaleService_AuthReject_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) {
ctx, cancel := context.WithCancel(req.Context())
defer cancel()
var stream runtime.ServerTransportStream
ctx = grpc.NewContextWithServerTransportStream(ctx, &stream)
inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req)
annotatedContext, err := runtime.AnnotateIncomingContext(ctx, mux, req, "/headscale.v1.HeadscaleService/AuthReject", runtime.WithHTTPPathPattern("/api/v1/auth/reject"))
if err != nil {
runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err)
return
}
resp, md, err := local_request_HeadscaleService_AuthReject_0(annotatedContext, inboundMarshaler, server, req, pathParams)
md.HeaderMD, md.TrailerMD = metadata.Join(md.HeaderMD, stream.Header()), metadata.Join(md.TrailerMD, stream.Trailer())
annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md)
if err != nil {
runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err)
return
}
forward_HeadscaleService_AuthReject_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...)
})
mux.Handle(http.MethodPost, pattern_HeadscaleService_CreateApiKey_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) {
ctx, cancel := context.WithCancel(req.Context())
defer cancel()
@@ -1899,57 +1758,6 @@ func RegisterHeadscaleServiceHandlerClient(ctx context.Context, mux *runtime.Ser
}
forward_HeadscaleService_BackfillNodeIPs_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...)
})
mux.Handle(http.MethodPost, pattern_HeadscaleService_AuthRegister_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) {
ctx, cancel := context.WithCancel(req.Context())
defer cancel()
inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req)
annotatedContext, err := runtime.AnnotateContext(ctx, mux, req, "/headscale.v1.HeadscaleService/AuthRegister", runtime.WithHTTPPathPattern("/api/v1/auth/register"))
if err != nil {
runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err)
return
}
resp, md, err := request_HeadscaleService_AuthRegister_0(annotatedContext, inboundMarshaler, client, req, pathParams)
annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md)
if err != nil {
runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err)
return
}
forward_HeadscaleService_AuthRegister_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...)
})
mux.Handle(http.MethodPost, pattern_HeadscaleService_AuthApprove_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) {
ctx, cancel := context.WithCancel(req.Context())
defer cancel()
inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req)
annotatedContext, err := runtime.AnnotateContext(ctx, mux, req, "/headscale.v1.HeadscaleService/AuthApprove", runtime.WithHTTPPathPattern("/api/v1/auth/approve"))
if err != nil {
runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err)
return
}
resp, md, err := request_HeadscaleService_AuthApprove_0(annotatedContext, inboundMarshaler, client, req, pathParams)
annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md)
if err != nil {
runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err)
return
}
forward_HeadscaleService_AuthApprove_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...)
})
mux.Handle(http.MethodPost, pattern_HeadscaleService_AuthReject_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) {
ctx, cancel := context.WithCancel(req.Context())
defer cancel()
inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req)
annotatedContext, err := runtime.AnnotateContext(ctx, mux, req, "/headscale.v1.HeadscaleService/AuthReject", runtime.WithHTTPPathPattern("/api/v1/auth/reject"))
if err != nil {
runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err)
return
}
resp, md, err := request_HeadscaleService_AuthReject_0(annotatedContext, inboundMarshaler, client, req, pathParams)
annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md)
if err != nil {
runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err)
return
}
forward_HeadscaleService_AuthReject_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...)
})
mux.Handle(http.MethodPost, pattern_HeadscaleService_CreateApiKey_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) {
ctx, cancel := context.WithCancel(req.Context())
defer cancel()
@@ -2091,9 +1899,6 @@ var (
pattern_HeadscaleService_RenameNode_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 1, 0, 4, 1, 5, 3, 2, 4, 1, 0, 4, 1, 5, 5}, []string{"api", "v1", "node", "node_id", "rename", "new_name"}, ""))
pattern_HeadscaleService_ListNodes_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"api", "v1", "node"}, ""))
pattern_HeadscaleService_BackfillNodeIPs_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3}, []string{"api", "v1", "node", "backfillips"}, ""))
pattern_HeadscaleService_AuthRegister_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3}, []string{"api", "v1", "auth", "register"}, ""))
pattern_HeadscaleService_AuthApprove_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3}, []string{"api", "v1", "auth", "approve"}, ""))
pattern_HeadscaleService_AuthReject_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3}, []string{"api", "v1", "auth", "reject"}, ""))
pattern_HeadscaleService_CreateApiKey_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"api", "v1", "apikey"}, ""))
pattern_HeadscaleService_ExpireApiKey_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3}, []string{"api", "v1", "apikey", "expire"}, ""))
pattern_HeadscaleService_ListApiKeys_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"api", "v1", "apikey"}, ""))
@@ -2122,9 +1927,6 @@ var (
forward_HeadscaleService_RenameNode_0 = runtime.ForwardResponseMessage
forward_HeadscaleService_ListNodes_0 = runtime.ForwardResponseMessage
forward_HeadscaleService_BackfillNodeIPs_0 = runtime.ForwardResponseMessage
forward_HeadscaleService_AuthRegister_0 = runtime.ForwardResponseMessage
forward_HeadscaleService_AuthApprove_0 = runtime.ForwardResponseMessage
forward_HeadscaleService_AuthReject_0 = runtime.ForwardResponseMessage
forward_HeadscaleService_CreateApiKey_0 = runtime.ForwardResponseMessage
forward_HeadscaleService_ExpireApiKey_0 = runtime.ForwardResponseMessage
forward_HeadscaleService_ListApiKeys_0 = runtime.ForwardResponseMessage

View File

@@ -37,9 +37,6 @@ const (
HeadscaleService_RenameNode_FullMethodName = "/headscale.v1.HeadscaleService/RenameNode"
HeadscaleService_ListNodes_FullMethodName = "/headscale.v1.HeadscaleService/ListNodes"
HeadscaleService_BackfillNodeIPs_FullMethodName = "/headscale.v1.HeadscaleService/BackfillNodeIPs"
HeadscaleService_AuthRegister_FullMethodName = "/headscale.v1.HeadscaleService/AuthRegister"
HeadscaleService_AuthApprove_FullMethodName = "/headscale.v1.HeadscaleService/AuthApprove"
HeadscaleService_AuthReject_FullMethodName = "/headscale.v1.HeadscaleService/AuthReject"
HeadscaleService_CreateApiKey_FullMethodName = "/headscale.v1.HeadscaleService/CreateApiKey"
HeadscaleService_ExpireApiKey_FullMethodName = "/headscale.v1.HeadscaleService/ExpireApiKey"
HeadscaleService_ListApiKeys_FullMethodName = "/headscale.v1.HeadscaleService/ListApiKeys"
@@ -74,10 +71,6 @@ type HeadscaleServiceClient interface {
RenameNode(ctx context.Context, in *RenameNodeRequest, opts ...grpc.CallOption) (*RenameNodeResponse, error)
ListNodes(ctx context.Context, in *ListNodesRequest, opts ...grpc.CallOption) (*ListNodesResponse, error)
BackfillNodeIPs(ctx context.Context, in *BackfillNodeIPsRequest, opts ...grpc.CallOption) (*BackfillNodeIPsResponse, error)
// --- Auth start ---
AuthRegister(ctx context.Context, in *AuthRegisterRequest, opts ...grpc.CallOption) (*AuthRegisterResponse, error)
AuthApprove(ctx context.Context, in *AuthApproveRequest, opts ...grpc.CallOption) (*AuthApproveResponse, error)
AuthReject(ctx context.Context, in *AuthRejectRequest, opts ...grpc.CallOption) (*AuthRejectResponse, error)
// --- ApiKeys start ---
CreateApiKey(ctx context.Context, in *CreateApiKeyRequest, opts ...grpc.CallOption) (*CreateApiKeyResponse, error)
ExpireApiKey(ctx context.Context, in *ExpireApiKeyRequest, opts ...grpc.CallOption) (*ExpireApiKeyResponse, error)
@@ -278,36 +271,6 @@ func (c *headscaleServiceClient) BackfillNodeIPs(ctx context.Context, in *Backfi
return out, nil
}
func (c *headscaleServiceClient) AuthRegister(ctx context.Context, in *AuthRegisterRequest, opts ...grpc.CallOption) (*AuthRegisterResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(AuthRegisterResponse)
err := c.cc.Invoke(ctx, HeadscaleService_AuthRegister_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *headscaleServiceClient) AuthApprove(ctx context.Context, in *AuthApproveRequest, opts ...grpc.CallOption) (*AuthApproveResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(AuthApproveResponse)
err := c.cc.Invoke(ctx, HeadscaleService_AuthApprove_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *headscaleServiceClient) AuthReject(ctx context.Context, in *AuthRejectRequest, opts ...grpc.CallOption) (*AuthRejectResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(AuthRejectResponse)
err := c.cc.Invoke(ctx, HeadscaleService_AuthReject_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *headscaleServiceClient) CreateApiKey(ctx context.Context, in *CreateApiKeyRequest, opts ...grpc.CallOption) (*CreateApiKeyResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(CreateApiKeyResponse)
@@ -403,10 +366,6 @@ type HeadscaleServiceServer interface {
RenameNode(context.Context, *RenameNodeRequest) (*RenameNodeResponse, error)
ListNodes(context.Context, *ListNodesRequest) (*ListNodesResponse, error)
BackfillNodeIPs(context.Context, *BackfillNodeIPsRequest) (*BackfillNodeIPsResponse, error)
// --- Auth start ---
AuthRegister(context.Context, *AuthRegisterRequest) (*AuthRegisterResponse, error)
AuthApprove(context.Context, *AuthApproveRequest) (*AuthApproveResponse, error)
AuthReject(context.Context, *AuthRejectRequest) (*AuthRejectResponse, error)
// --- ApiKeys start ---
CreateApiKey(context.Context, *CreateApiKeyRequest) (*CreateApiKeyResponse, error)
ExpireApiKey(context.Context, *ExpireApiKeyRequest) (*ExpireApiKeyResponse, error)
@@ -481,15 +440,6 @@ func (UnimplementedHeadscaleServiceServer) ListNodes(context.Context, *ListNodes
func (UnimplementedHeadscaleServiceServer) BackfillNodeIPs(context.Context, *BackfillNodeIPsRequest) (*BackfillNodeIPsResponse, error) {
return nil, status.Error(codes.Unimplemented, "method BackfillNodeIPs not implemented")
}
func (UnimplementedHeadscaleServiceServer) AuthRegister(context.Context, *AuthRegisterRequest) (*AuthRegisterResponse, error) {
return nil, status.Error(codes.Unimplemented, "method AuthRegister not implemented")
}
func (UnimplementedHeadscaleServiceServer) AuthApprove(context.Context, *AuthApproveRequest) (*AuthApproveResponse, error) {
return nil, status.Error(codes.Unimplemented, "method AuthApprove not implemented")
}
func (UnimplementedHeadscaleServiceServer) AuthReject(context.Context, *AuthRejectRequest) (*AuthRejectResponse, error) {
return nil, status.Error(codes.Unimplemented, "method AuthReject not implemented")
}
func (UnimplementedHeadscaleServiceServer) CreateApiKey(context.Context, *CreateApiKeyRequest) (*CreateApiKeyResponse, error) {
return nil, status.Error(codes.Unimplemented, "method CreateApiKey not implemented")
}
@@ -856,60 +806,6 @@ func _HeadscaleService_BackfillNodeIPs_Handler(srv interface{}, ctx context.Cont
return interceptor(ctx, in, info, handler)
}
func _HeadscaleService_AuthRegister_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(AuthRegisterRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(HeadscaleServiceServer).AuthRegister(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: HeadscaleService_AuthRegister_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(HeadscaleServiceServer).AuthRegister(ctx, req.(*AuthRegisterRequest))
}
return interceptor(ctx, in, info, handler)
}
func _HeadscaleService_AuthApprove_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(AuthApproveRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(HeadscaleServiceServer).AuthApprove(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: HeadscaleService_AuthApprove_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(HeadscaleServiceServer).AuthApprove(ctx, req.(*AuthApproveRequest))
}
return interceptor(ctx, in, info, handler)
}
func _HeadscaleService_AuthReject_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(AuthRejectRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(HeadscaleServiceServer).AuthReject(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: HeadscaleService_AuthReject_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(HeadscaleServiceServer).AuthReject(ctx, req.(*AuthRejectRequest))
}
return interceptor(ctx, in, info, handler)
}
func _HeadscaleService_CreateApiKey_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(CreateApiKeyRequest)
if err := dec(in); err != nil {
@@ -1115,18 +1011,6 @@ var HeadscaleService_ServiceDesc = grpc.ServiceDesc{
MethodName: "BackfillNodeIPs",
Handler: _HeadscaleService_BackfillNodeIPs_Handler,
},
{
MethodName: "AuthRegister",
Handler: _HeadscaleService_AuthRegister_Handler,
},
{
MethodName: "AuthApprove",
Handler: _HeadscaleService_AuthApprove_Handler,
},
{
MethodName: "AuthReject",
Handler: _HeadscaleService_AuthReject_Handler,
},
{
MethodName: "CreateApiKey",
Handler: _HeadscaleService_CreateApiKey_Handler,

View File

@@ -1,44 +0,0 @@
{
"swagger": "2.0",
"info": {
"title": "headscale/v1/auth.proto",
"version": "version not set"
},
"consumes": [
"application/json"
],
"produces": [
"application/json"
],
"paths": {},
"definitions": {
"protobufAny": {
"type": "object",
"properties": {
"@type": {
"type": "string"
}
},
"additionalProperties": {}
},
"rpcStatus": {
"type": "object",
"properties": {
"code": {
"type": "integer",
"format": "int32"
},
"message": {
"type": "string"
},
"details": {
"type": "array",
"items": {
"type": "object",
"$ref": "#/definitions/protobufAny"
}
}
}
}
}
}

View File

@@ -138,103 +138,6 @@
]
}
},
"/api/v1/auth/approve": {
"post": {
"operationId": "HeadscaleService_AuthApprove",
"responses": {
"200": {
"description": "A successful response.",
"schema": {
"$ref": "#/definitions/v1AuthApproveResponse"
}
},
"default": {
"description": "An unexpected error response.",
"schema": {
"$ref": "#/definitions/rpcStatus"
}
}
},
"parameters": [
{
"name": "body",
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/v1AuthApproveRequest"
}
}
],
"tags": [
"HeadscaleService"
]
}
},
"/api/v1/auth/register": {
"post": {
"summary": "--- Auth start ---",
"operationId": "HeadscaleService_AuthRegister",
"responses": {
"200": {
"description": "A successful response.",
"schema": {
"$ref": "#/definitions/v1AuthRegisterResponse"
}
},
"default": {
"description": "An unexpected error response.",
"schema": {
"$ref": "#/definitions/rpcStatus"
}
}
},
"parameters": [
{
"name": "body",
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/v1AuthRegisterRequest"
}
}
],
"tags": [
"HeadscaleService"
]
}
},
"/api/v1/auth/reject": {
"post": {
"operationId": "HeadscaleService_AuthReject",
"responses": {
"200": {
"description": "A successful response.",
"schema": {
"$ref": "#/definitions/v1AuthRejectResponse"
}
},
"default": {
"description": "An unexpected error response.",
"schema": {
"$ref": "#/definitions/rpcStatus"
}
}
},
"parameters": [
{
"name": "body",
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/v1AuthRejectRequest"
}
}
],
"tags": [
"HeadscaleService"
]
}
},
"/api/v1/debug/node": {
"post": {
"summary": "--- Node start ---",
@@ -992,47 +895,6 @@
}
}
},
"v1AuthApproveRequest": {
"type": "object",
"properties": {
"authId": {
"type": "string"
}
}
},
"v1AuthApproveResponse": {
"type": "object"
},
"v1AuthRegisterRequest": {
"type": "object",
"properties": {
"user": {
"type": "string"
},
"authId": {
"type": "string"
}
}
},
"v1AuthRegisterResponse": {
"type": "object",
"properties": {
"node": {
"$ref": "#/definitions/v1Node"
}
}
},
"v1AuthRejectRequest": {
"type": "object",
"properties": {
"authId": {
"type": "string"
}
}
},
"v1AuthRejectResponse": {
"type": "object"
},
"v1BackfillNodeIPsResponse": {
"type": "object",
"properties": {

2
go.mod
View File

@@ -14,8 +14,6 @@ require (
github.com/docker/docker v28.5.2+incompatible
github.com/fsnotify/fsnotify v1.9.0
github.com/glebarez/sqlite v1.11.0
github.com/go-chi/chi/v5 v5.2.5
github.com/go-chi/metrics v0.1.1
github.com/go-gormigrate/gormigrate/v2 v2.1.5
github.com/go-json-experiment/json v0.0.0-20251027170946-4849db3c2f7e
github.com/gofrs/uuid/v5 v5.4.0

4
go.sum
View File

@@ -181,10 +181,6 @@ github.com/glebarez/go-sqlite v1.22.0 h1:uAcMJhaA6r3LHMTFgP0SifzgXg46yJkgxqyuyec
github.com/glebarez/go-sqlite v1.22.0/go.mod h1:PlBIdHe0+aUEFn+r2/uthrWq4FxbzugL0L8Li6yQJbc=
github.com/glebarez/sqlite v1.11.0 h1:wSG0irqzP6VurnMEpFGer5Li19RpIRi2qvQz++w0GMw=
github.com/glebarez/sqlite v1.11.0/go.mod h1:h8/o8j5wiAsqSPoWELDUdJXhjAhsVliSn7bWZjOhrgQ=
github.com/go-chi/chi/v5 v5.2.5 h1:Eg4myHZBjyvJmAFjFvWgrqDTXFyOzjj7YIm3L3mu6Ug=
github.com/go-chi/chi/v5 v5.2.5/go.mod h1:X7Gx4mteadT3eDOMTsXzmI4/rwUpOwBHLpAfupzFJP0=
github.com/go-chi/metrics v0.1.1 h1:CXhbnkAVVjb0k73EBRQ6Z2YdWFnbXZgNtg1Mboguibk=
github.com/go-chi/metrics v0.1.1/go.mod h1:mcGTM1pPalP7WCtb+akNYFO/lwNwBBLCuedepqjoPn4=
github.com/go-gormigrate/gormigrate/v2 v2.1.5 h1:1OyorA5LtdQw12cyJDEHuTrEV3GiXiIhS4/QTTa/SM8=
github.com/go-gormigrate/gormigrate/v2 v2.1.5/go.mod h1:mj9ekk/7CPF3VjopaFvWKN2v7fN3D9d3eEOAXRhi/+M=
github.com/go-jose/go-jose/v3 v3.0.4 h1:Wp5HA7bLQcKnf6YYao/4kpRpVMp/yf6+pJKV8WFSaNY=

View File

@@ -20,9 +20,7 @@ import (
"github.com/cenkalti/backoff/v5"
"github.com/davecgh/go-spew/spew"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"github.com/go-chi/metrics"
"github.com/gorilla/mux"
grpcRuntime "github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
"github.com/juanfont/headscale"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
@@ -459,58 +457,50 @@ func (h *Headscale) ensureUnixSocketIsAbsent() error {
return os.Remove(h.cfg.UnixSocket)
}
func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *chi.Mux {
r := chi.NewRouter()
r.Use(metrics.Collector(metrics.CollectorOpts{
Host: false,
Proto: true,
Skip: func(r *http.Request) bool {
return r.Method != http.MethodOptions
},
}))
r.Use(middleware.RequestID)
r.Use(middleware.RealIP)
r.Use(middleware.RequestLogger(&zerologRequestLogger{}))
r.Use(middleware.Recoverer)
func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {
router := mux.NewRouter()
router.Use(prometheusMiddleware)
r.Post(ts2021UpgradePath, h.NoiseUpgradeHandler)
router.HandleFunc(ts2021UpgradePath, h.NoiseUpgradeHandler).
Methods(http.MethodPost, http.MethodGet)
r.Get("/robots.txt", h.RobotsHandler)
r.Get("/health", h.HealthHandler)
r.Get("/version", h.VersionHandler)
r.Get("/key", h.KeyHandler)
r.Get("/register/{auth_id}", h.authProvider.RegisterHandler)
r.Get("/auth/{auth_id}", h.authProvider.AuthHandler)
router.HandleFunc("/robots.txt", h.RobotsHandler).Methods(http.MethodGet)
router.HandleFunc("/health", h.HealthHandler).Methods(http.MethodGet)
router.HandleFunc("/version", h.VersionHandler).Methods(http.MethodGet)
router.HandleFunc("/key", h.KeyHandler).Methods(http.MethodGet)
router.HandleFunc("/register/{registration_id}", h.authProvider.RegisterHandler).
Methods(http.MethodGet)
if provider, ok := h.authProvider.(*AuthProviderOIDC); ok {
r.Get("/oidc/callback", provider.OIDCCallbackHandler)
router.HandleFunc("/oidc/callback", provider.OIDCCallbackHandler).Methods(http.MethodGet)
}
r.Get("/apple", h.AppleConfigMessage)
r.Get("/apple/{platform}", h.ApplePlatformConfig)
r.Get("/windows", h.WindowsConfigMessage)
router.HandleFunc("/apple", h.AppleConfigMessage).Methods(http.MethodGet)
router.HandleFunc("/apple/{platform}", h.ApplePlatformConfig).
Methods(http.MethodGet)
router.HandleFunc("/windows", h.WindowsConfigMessage).Methods(http.MethodGet)
// TODO(kristoffer): move swagger into a package
r.Get("/swagger", headscale.SwaggerUI)
r.Get("/swagger/v1/openapiv2.json", headscale.SwaggerAPIv1)
router.HandleFunc("/swagger", headscale.SwaggerUI).Methods(http.MethodGet)
router.HandleFunc("/swagger/v1/openapiv2.json", headscale.SwaggerAPIv1).
Methods(http.MethodGet)
r.Post("/verify", h.VerifyHandler)
router.HandleFunc("/verify", h.VerifyHandler).Methods(http.MethodPost)
if h.cfg.DERP.ServerEnabled {
r.HandleFunc("/derp", h.DERPServer.DERPHandler)
r.HandleFunc("/derp/probe", derpServer.DERPProbeHandler)
r.HandleFunc("/derp/latency-check", derpServer.DERPProbeHandler)
r.HandleFunc("/bootstrap-dns", derpServer.DERPBootstrapDNSHandler(h.state.DERPMap()))
router.HandleFunc("/derp", h.DERPServer.DERPHandler)
router.HandleFunc("/derp/probe", derpServer.DERPProbeHandler)
router.HandleFunc("/derp/latency-check", derpServer.DERPProbeHandler)
router.HandleFunc("/bootstrap-dns", derpServer.DERPBootstrapDNSHandler(h.state.DERPMap()))
}
r.Route("/api", func(r chi.Router) {
r.Use(h.httpAuthenticationMiddleware)
r.HandleFunc("/v1/*", grpcMux.ServeHTTP)
})
r.Get("/favicon.ico", FaviconHandler)
r.Get("/", BlankHandler)
apiRouter := router.PathPrefix("/api").Subrouter()
apiRouter.Use(h.httpAuthenticationMiddleware)
apiRouter.PathPrefix("/v1/").HandlerFunc(grpcMux.ServeHTTP)
router.HandleFunc("/favicon.ico", FaviconHandler)
router.PathPrefix("/").HandlerFunc(BlankHandler)
return r
return router
}
// Serve launches the HTTP and gRPC server service Headscale and the API.
@@ -1093,52 +1083,3 @@ func (l *acmeLogger) RoundTrip(req *http.Request) (*http.Response, error) {
return resp, nil
}
// zerologRequestLogger implements chi's middleware.LogFormatter
// to route HTTP request logs through zerolog.
type zerologRequestLogger struct{}
func (z *zerologRequestLogger) NewLogEntry(
r *http.Request,
) middleware.LogEntry {
return &zerologLogEntry{
method: r.Method,
path: r.URL.Path,
proto: r.Proto,
remote: r.RemoteAddr,
}
}
type zerologLogEntry struct {
method string
path string
proto string
remote string
}
func (e *zerologLogEntry) Write(
status, bytes int,
header http.Header,
elapsed time.Duration,
extra any,
) {
log.Info().
Str("method", e.method).
Str("path", e.path).
Str("proto", e.proto).
Str("remote", e.remote).
Int("status", status).
Int("bytes", bytes).
Dur("elapsed", elapsed).
Msg("http request")
}
func (e *zerologLogEntry) Panic(
v any,
stack []byte,
) {
log.Error().
Interface("panic", v).
Bytes("stack", stack).
Msg("http handler panic")
}

View File

@@ -20,9 +20,7 @@ import (
type AuthProvider interface {
RegisterHandler(w http.ResponseWriter, r *http.Request)
AuthHandler(w http.ResponseWriter, r *http.Request)
RegisterURL(authID types.AuthID) string
AuthURL(authID types.AuthID) string
AuthURL(regID types.RegistrationID) string
}
func (h *Headscale) handleRegister(
@@ -265,24 +263,22 @@ func (h *Headscale) waitForFollowup(
return nil, NewHTTPError(http.StatusUnauthorized, "invalid followup URL", err)
}
followupReg, err := types.AuthIDFromString(strings.ReplaceAll(fu.Path, "/register/", ""))
followupReg, err := types.RegistrationIDFromString(strings.ReplaceAll(fu.Path, "/register/", ""))
if err != nil {
return nil, NewHTTPError(http.StatusUnauthorized, "invalid registration ID", err)
}
if reg, ok := h.state.GetAuthCacheEntry(followupReg); ok {
if reg, ok := h.state.GetRegistrationCacheEntry(followupReg); ok {
select {
case <-ctx.Done():
return nil, NewHTTPError(http.StatusUnauthorized, "registration timed out", err)
case verdict := <-reg.WaitForAuth():
if verdict.Accept() {
if !verdict.Node.Valid() {
// registration is expired in the cache, instruct the client to try a new registration
return h.reqToNewRegisterResponse(req, machineKey)
}
return nodeToRegisterResponse(verdict.Node), nil
case node := <-reg.Registered:
if node == nil {
// registration is expired in the cache, instruct the client to try a new registration
return h.reqToNewRegisterResponse(req, machineKey)
}
return nodeToRegisterResponse(node.View()), nil
}
}
@@ -297,14 +293,14 @@ func (h *Headscale) reqToNewRegisterResponse(
req tailcfg.RegisterRequest,
machineKey key.MachinePublic,
) (*tailcfg.RegisterResponse, error) {
newAuthID, err := types.NewAuthID()
newRegID, err := types.NewRegistrationID()
if err != nil {
return nil, NewHTTPError(http.StatusInternalServerError, "failed to generate registration ID", err)
}
// Ensure we have a valid hostname
hostname := util.EnsureHostname(
req.Hostinfo.View(),
req.Hostinfo,
machineKey.String(),
req.NodeKey.String(),
)
@@ -313,25 +309,25 @@ func (h *Headscale) reqToNewRegisterResponse(
hostinfo := cmp.Or(req.Hostinfo, &tailcfg.Hostinfo{})
hostinfo.Hostname = hostname
nodeToRegister := types.Node{
Hostname: hostname,
MachineKey: machineKey,
NodeKey: req.NodeKey,
Hostinfo: hostinfo,
LastSeen: new(time.Now()),
}
nodeToRegister := types.NewRegisterNode(
types.Node{
Hostname: hostname,
MachineKey: machineKey,
NodeKey: req.NodeKey,
Hostinfo: hostinfo,
LastSeen: new(time.Now()),
},
)
if !req.Expiry.IsZero() {
nodeToRegister.Expiry = &req.Expiry
nodeToRegister.Node.Expiry = &req.Expiry
}
authRegReq := types.NewRegisterAuthRequest(nodeToRegister)
log.Info().Msgf("new followup node registration using key: %s", newAuthID)
h.state.SetAuthCacheEntry(newAuthID, authRegReq)
log.Info().Msgf("new followup node registration using key: %s", newRegID)
h.state.SetRegistrationCacheEntry(newRegID, nodeToRegister)
return &tailcfg.RegisterResponse{
AuthURL: h.authProvider.RegisterURL(newAuthID),
AuthURL: h.authProvider.AuthURL(newRegID),
}, nil
}
@@ -382,6 +378,13 @@ func (h *Headscale) handleRegisterWithAuthKey(
// Send both changes. Empty changes are ignored by Change().
h.Change(changed, routesChange)
// TODO(kradalby): I think this is covered above, but we need to validate that.
// // If policy changed due to node registration, send a separate policy change
// if policyChanged {
// policyChange := change.PolicyChange()
// h.Change(policyChange)
// }
resp := &tailcfg.RegisterResponse{
MachineAuthorized: true,
NodeKeyExpired: node.IsExpired(),
@@ -403,14 +406,14 @@ func (h *Headscale) handleRegisterInteractive(
req tailcfg.RegisterRequest,
machineKey key.MachinePublic,
) (*tailcfg.RegisterResponse, error) {
authID, err := types.NewAuthID()
registrationId, err := types.NewRegistrationID()
if err != nil {
return nil, fmt.Errorf("generating registration ID: %w", err)
}
// Ensure we have a valid hostname
hostname := util.EnsureHostname(
req.Hostinfo.View(),
req.Hostinfo,
machineKey.String(),
req.NodeKey.String(),
)
@@ -433,28 +436,28 @@ func (h *Headscale) handleRegisterInteractive(
hostinfo.Hostname = hostname
nodeToRegister := types.Node{
Hostname: hostname,
MachineKey: machineKey,
NodeKey: req.NodeKey,
Hostinfo: hostinfo,
LastSeen: new(time.Now()),
}
if !req.Expiry.IsZero() {
nodeToRegister.Expiry = &req.Expiry
}
authRegReq := types.NewRegisterAuthRequest(nodeToRegister)
h.state.SetAuthCacheEntry(
authID,
authRegReq,
nodeToRegister := types.NewRegisterNode(
types.Node{
Hostname: hostname,
MachineKey: machineKey,
NodeKey: req.NodeKey,
Hostinfo: hostinfo,
LastSeen: new(time.Now()),
},
)
log.Info().Msgf("starting node registration using key: %s", authID)
if !req.Expiry.IsZero() {
nodeToRegister.Node.Expiry = &req.Expiry
}
h.state.SetRegistrationCacheEntry(
registrationId,
nodeToRegister,
)
log.Info().Msgf("starting node registration using key: %s", registrationId)
return &tailcfg.RegisterResponse{
AuthURL: h.authProvider.RegisterURL(authID),
AuthURL: h.authProvider.AuthURL(registrationId),
}, nil
}

View File

@@ -651,8 +651,8 @@ func TestExpiryDuringPersonalToTaggedConversion(t *testing.T) {
// Step 1: Create user-owned node WITH expiry set
clientExpiry := time.Now().Add(24 * time.Hour)
registrationID1 := types.MustAuthID()
regEntry1 := types.NewRegisterAuthRequest(types.Node{
registrationID1 := types.MustRegistrationID()
regEntry1 := types.NewRegisterNode(types.Node{
MachineKey: machineKey.Public(),
NodeKey: nodeKey1.Public(),
Hostname: "personal-to-tagged",
@@ -662,7 +662,7 @@ func TestExpiryDuringPersonalToTaggedConversion(t *testing.T) {
},
Expiry: &clientExpiry,
})
app.state.SetAuthCacheEntry(registrationID1, regEntry1)
app.state.SetRegistrationCacheEntry(registrationID1, regEntry1)
node, _, err := app.state.HandleNodeFromAuthPath(
registrationID1, types.UserID(user.ID), nil, "webauth",
@@ -673,8 +673,8 @@ func TestExpiryDuringPersonalToTaggedConversion(t *testing.T) {
// Step 2: Re-auth with tags (Personal → Tagged conversion)
nodeKey2 := key.NewNode()
registrationID2 := types.MustAuthID()
regEntry2 := types.NewRegisterAuthRequest(types.Node{
registrationID2 := types.MustRegistrationID()
regEntry2 := types.NewRegisterNode(types.Node{
MachineKey: machineKey.Public(),
NodeKey: nodeKey2.Public(),
Hostname: "personal-to-tagged",
@@ -684,7 +684,7 @@ func TestExpiryDuringPersonalToTaggedConversion(t *testing.T) {
},
Expiry: &clientExpiry, // Client still sends expiry
})
app.state.SetAuthCacheEntry(registrationID2, regEntry2)
app.state.SetRegistrationCacheEntry(registrationID2, regEntry2)
nodeAfter, _, err := app.state.HandleNodeFromAuthPath(
registrationID2, types.UserID(user.ID), nil, "webauth",
@@ -723,8 +723,8 @@ func TestExpiryDuringTaggedToPersonalConversion(t *testing.T) {
nodeKey1 := key.NewNode()
// Step 1: Create tagged node (expiry should be nil)
registrationID1 := types.MustAuthID()
regEntry1 := types.NewRegisterAuthRequest(types.Node{
registrationID1 := types.MustRegistrationID()
regEntry1 := types.NewRegisterNode(types.Node{
MachineKey: machineKey.Public(),
NodeKey: nodeKey1.Public(),
Hostname: "tagged-to-personal",
@@ -733,7 +733,7 @@ func TestExpiryDuringTaggedToPersonalConversion(t *testing.T) {
RequestTags: []string{"tag:server"}, // Tagged node
},
})
app.state.SetAuthCacheEntry(registrationID1, regEntry1)
app.state.SetRegistrationCacheEntry(registrationID1, regEntry1)
node, _, err := app.state.HandleNodeFromAuthPath(
registrationID1, types.UserID(user.ID), nil, "webauth",
@@ -745,8 +745,8 @@ func TestExpiryDuringTaggedToPersonalConversion(t *testing.T) {
// Step 2: Re-auth with empty tags (Tagged → Personal conversion)
nodeKey2 := key.NewNode()
clientExpiry := time.Now().Add(48 * time.Hour)
registrationID2 := types.MustAuthID()
regEntry2 := types.NewRegisterAuthRequest(types.Node{
registrationID2 := types.MustRegistrationID()
regEntry2 := types.NewRegisterNode(types.Node{
MachineKey: machineKey.Public(),
NodeKey: nodeKey2.Public(),
Hostname: "tagged-to-personal",
@@ -756,7 +756,7 @@ func TestExpiryDuringTaggedToPersonalConversion(t *testing.T) {
},
Expiry: &clientExpiry, // Client requests expiry
})
app.state.SetAuthCacheEntry(registrationID2, regEntry2)
app.state.SetRegistrationCacheEntry(registrationID2, regEntry2)
nodeAfter, _, err := app.state.HandleNodeFromAuthPath(
registrationID2, types.UserID(user.ID), nil, "webauth",

View File

@@ -676,23 +676,28 @@ func TestAuthenticationFlows(t *testing.T) {
{
name: "followup_registration_success",
setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper //nolint:thelper
regID, err := types.NewAuthID()
regID, err := types.NewRegistrationID()
if err != nil {
return "", err
}
nodeToRegister := types.NewRegisterAuthRequest(types.Node{
Hostname: "followup-success-node",
})
app.state.SetAuthCacheEntry(regID, nodeToRegister)
registered := make(chan *types.Node, 1)
nodeToRegister := types.RegisterNode{
Node: types.Node{
Hostname: "followup-success-node",
},
Registered: registered,
}
app.state.SetRegistrationCacheEntry(regID, nodeToRegister)
// Simulate successful registration
// handleRegister will receive the value when it starts waiting
// Simulate successful registration - send to buffered channel
// The channel is buffered (size 1), so this can complete immediately
// and handleRegister will receive the value when it starts waiting
go func() {
user := app.state.CreateUserForTest("followup-user")
node := app.state.CreateNodeForTest(user, "followup-success-node")
nodeToRegister.FinishAuth(types.AuthVerdict{Node: node.View()})
registered <- node
}()
return fmt.Sprintf("http://localhost:8080/register/%s", regID), nil
@@ -718,16 +723,20 @@ func TestAuthenticationFlows(t *testing.T) {
{
name: "followup_registration_timeout",
setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper //nolint:thelper
regID, err := types.NewAuthID()
regID, err := types.NewRegistrationID()
if err != nil {
return "", err
}
nodeToRegister := types.NewRegisterAuthRequest(types.Node{
Hostname: "followup-timeout-node",
})
app.state.SetAuthCacheEntry(regID, nodeToRegister)
// Don't call FinishRegistration - will timeout
registered := make(chan *types.Node, 1)
nodeToRegister := types.RegisterNode{
Node: types.Node{
Hostname: "followup-timeout-node",
},
Registered: registered,
}
app.state.SetRegistrationCacheEntry(regID, nodeToRegister)
// Don't send anything on channel - will timeout
return fmt.Sprintf("http://localhost:8080/register/%s", regID), nil
},
@@ -1336,19 +1345,24 @@ func TestAuthenticationFlows(t *testing.T) {
{
name: "followup_registration_node_nil_response",
setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper
regID, err := types.NewAuthID()
regID, err := types.NewRegistrationID()
if err != nil {
return "", err
}
nodeToRegister := types.NewRegisterAuthRequest(types.Node{
Hostname: "nil-response-node",
})
app.state.SetAuthCacheEntry(regID, nodeToRegister)
registered := make(chan *types.Node, 1)
nodeToRegister := types.RegisterNode{
Node: types.Node{
Hostname: "nil-response-node",
},
Registered: registered,
}
app.state.SetRegistrationCacheEntry(regID, nodeToRegister)
// Simulate registration that returns empty NodeView (cache expired during auth)
// Simulate registration that returns nil (cache expired during auth)
// The channel is buffered (size 1), so this can complete immediately
go func() {
nodeToRegister.FinishAuth(types.AuthVerdict{Node: types.NodeView{}}) // Empty view indicates cache expiry
registered <- nil // Nil indicates cache expiry
}()
return fmt.Sprintf("http://localhost:8080/register/%s", regID), nil
@@ -1801,7 +1815,7 @@ func TestAuthenticationFlows(t *testing.T) {
setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper
// Generate a registration ID that doesn't exist in cache
// This simulates an expired/missing cache entry
regID, err := types.NewAuthID()
regID, err := types.NewRegistrationID()
if err != nil {
return "", err
}
@@ -1833,11 +1847,11 @@ func TestAuthenticationFlows(t *testing.T) {
// Extract and validate the new registration ID exists in cache
newRegIDStr := strings.TrimPrefix(authURL.Path, "/register/")
newRegID, err := types.AuthIDFromString(newRegIDStr)
newRegID, err := types.RegistrationIDFromString(newRegIDStr)
assert.NoError(t, err, "should be able to parse new registration ID") //nolint:testifylint // inside closure
// Verify new registration entry exists in cache
_, found := app.state.GetAuthCacheEntry(newRegID)
_, found := app.state.GetRegistrationCacheEntry(newRegID)
assert.True(t, found, "new registration should exist in cache")
},
},
@@ -2286,7 +2300,7 @@ func TestAuthenticationFlows(t *testing.T) {
require.NoError(t, err)
// Verify cache entry exists
cacheEntry, found := app.state.GetAuthCacheEntry(registrationID)
cacheEntry, found := app.state.GetRegistrationCacheEntry(registrationID)
assert.True(t, found, "registration cache entry should exist initially")
assert.NotNil(t, cacheEntry)
@@ -2301,7 +2315,7 @@ func TestAuthenticationFlows(t *testing.T) {
assert.Error(t, err, "should fail with invalid user ID") //nolint:testifylint // inside closure, uses assert pattern
// Cache entry should still exist after auth error (for retry scenarios)
_, stillFound := app.state.GetAuthCacheEntry(registrationID)
_, stillFound := app.state.GetRegistrationCacheEntry(registrationID)
assert.True(t, stillFound, "registration cache entry should still exist after auth error for potential retry")
},
},
@@ -2361,8 +2375,8 @@ func TestAuthenticationFlows(t *testing.T) {
assert.NotEqual(t, regID1, regID2, "different registration attempts should have different IDs")
// Both cache entries should exist simultaneously
_, found1 := app.state.GetAuthCacheEntry(regID1)
_, found2 := app.state.GetAuthCacheEntry(regID2)
_, found1 := app.state.GetRegistrationCacheEntry(regID1)
_, found2 := app.state.GetRegistrationCacheEntry(regID2)
assert.True(t, found1, "first registration cache entry should exist")
assert.True(t, found2, "second registration cache entry should exist")
@@ -2413,8 +2427,8 @@ func TestAuthenticationFlows(t *testing.T) {
require.NoError(t, err)
// Verify both exist
_, found1 := app.state.GetAuthCacheEntry(regID1)
_, found2 := app.state.GetAuthCacheEntry(regID2)
_, found1 := app.state.GetRegistrationCacheEntry(regID1)
_, found2 := app.state.GetRegistrationCacheEntry(regID2)
assert.True(t, found1, "first cache entry should exist")
assert.True(t, found2, "second cache entry should exist")
@@ -2476,7 +2490,7 @@ func TestAuthenticationFlows(t *testing.T) {
}
// First registration should still be in cache (not completed)
_, stillFound := app.state.GetAuthCacheEntry(regID1)
_, stillFound := app.state.GetRegistrationCacheEntry(regID1)
assert.True(t, stillFound, "first registration should still be pending")
},
},
@@ -2587,7 +2601,7 @@ func runInteractiveWorkflowTest(t *testing.T, tt struct {
var (
initialResp *tailcfg.RegisterResponse
authURL string
registrationID types.AuthID
registrationID types.RegistrationID
finalResp *tailcfg.RegisterResponse
err error
)
@@ -2615,10 +2629,10 @@ func runInteractiveWorkflowTest(t *testing.T, tt struct {
if step.expectCacheEntry {
// Verify registration cache entry was created
cacheEntry, found := app.state.GetAuthCacheEntry(registrationID)
cacheEntry, found := app.state.GetRegistrationCacheEntry(registrationID)
require.True(t, found, "registration cache entry should exist")
require.NotNil(t, cacheEntry, "cache entry should not be nil")
require.Equal(t, req.NodeKey, cacheEntry.Node().NodeKey(), "cache entry should have correct node key")
require.Equal(t, req.NodeKey, cacheEntry.Node.NodeKey, "cache entry should have correct node key")
}
case stepTypeAuthCompletion:
@@ -2678,7 +2692,7 @@ func runInteractiveWorkflowTest(t *testing.T, tt struct {
// Check cache cleanup expectation for this step
if step.expectCacheEntry == false && registrationID != "" {
// Verify cache entry was cleaned up
_, found := app.state.GetAuthCacheEntry(registrationID)
_, found := app.state.GetRegistrationCacheEntry(registrationID)
require.False(t, found, "registration cache entry should be cleaned up after step: %s", step.stepType)
}
}
@@ -2700,7 +2714,7 @@ func runInteractiveWorkflowTest(t *testing.T, tt struct {
}
// extractRegistrationIDFromAuthURL extracts the registration ID from an AuthURL.
func extractRegistrationIDFromAuthURL(authURL string) (types.AuthID, error) {
func extractRegistrationIDFromAuthURL(authURL string) (types.RegistrationID, error) {
// AuthURL format: "http://localhost/register/abc123"
const registerPrefix = "/register/"
@@ -2711,7 +2725,7 @@ func extractRegistrationIDFromAuthURL(authURL string) (types.AuthID, error) {
idStr := authURL[idx+len(registerPrefix):]
return types.AuthIDFromString(idStr)
return types.RegistrationIDFromString(idStr)
}
// validateCompleteRegistrationResponse performs comprehensive validation of a registration response.
@@ -2948,7 +2962,7 @@ func TestPreAuthKeyLogoutAndReloginDifferentUser(t *testing.T) {
// Scenario:
// 1. Node registers with user1 via pre-auth key
// 2. Node logs out (expires)
// 3. Admin runs: headscale auth register --auth-id <id> --user user2
// 3. Admin runs: headscale nodes register --user user2 --key <key>
//
// Expected behavior:
// - User1's original node should STILL EXIST (expired)
@@ -3027,7 +3041,7 @@ func TestWebFlowReauthDifferentUser(t *testing.T) {
require.NotEmpty(t, regID, "Should have valid registration ID")
// Step 4: Admin completes authentication via CLI
// This simulates: headscale auth register --auth-id <id> --user user2
// This simulates: headscale nodes register --user user2 --key <key>
node, _, err := app.state.HandleNodeFromAuthPath(
regID,
types.UserID(user2.ID), // Register to user2, not user1!
@@ -3569,8 +3583,8 @@ func TestWebAuthRejectsUnauthorizedRequestTags(t *testing.T) {
nodeKey := key.NewNode()
// Simulate a registration cache entry (as would be created during web auth)
registrationID := types.MustAuthID()
regEntry := types.NewRegisterAuthRequest(types.Node{
registrationID := types.MustRegistrationID()
regEntry := types.NewRegisterNode(types.Node{
MachineKey: machineKey.Public(),
NodeKey: nodeKey.Public(),
Hostname: "webauth-tags-node",
@@ -3579,7 +3593,7 @@ func TestWebAuthRejectsUnauthorizedRequestTags(t *testing.T) {
RequestTags: []string{"tag:unauthorized"}, // This tag is not in policy
},
})
app.state.SetAuthCacheEntry(registrationID, regEntry)
app.state.SetRegistrationCacheEntry(registrationID, regEntry)
// Complete the web auth - should fail because tag is unauthorized
_, _, err := app.state.HandleNodeFromAuthPath(
@@ -3632,8 +3646,8 @@ func TestWebAuthReauthWithEmptyTagsRemovesAllTags(t *testing.T) {
nodeKey1 := key.NewNode()
// Step 1: Initial registration with tags
registrationID1 := types.MustAuthID()
regEntry1 := types.NewRegisterAuthRequest(types.Node{
registrationID1 := types.MustRegistrationID()
regEntry1 := types.NewRegisterNode(types.Node{
MachineKey: machineKey.Public(),
NodeKey: nodeKey1.Public(),
Hostname: "reauth-untag-node",
@@ -3642,7 +3656,7 @@ func TestWebAuthReauthWithEmptyTagsRemovesAllTags(t *testing.T) {
RequestTags: []string{"tag:valid-owned", "tag:second"},
},
})
app.state.SetAuthCacheEntry(registrationID1, regEntry1)
app.state.SetRegistrationCacheEntry(registrationID1, regEntry1)
// Complete initial registration with tags
node, _, err := app.state.HandleNodeFromAuthPath(
@@ -3659,8 +3673,8 @@ func TestWebAuthReauthWithEmptyTagsRemovesAllTags(t *testing.T) {
// Step 2: Reauth with EMPTY tags to untag
nodeKey2 := key.NewNode() // New node key for reauth
registrationID2 := types.MustAuthID()
regEntry2 := types.NewRegisterAuthRequest(types.Node{
registrationID2 := types.MustRegistrationID()
regEntry2 := types.NewRegisterNode(types.Node{
MachineKey: machineKey.Public(), // Same machine key
NodeKey: nodeKey2.Public(), // Different node key (rotation)
Hostname: "reauth-untag-node",
@@ -3669,7 +3683,7 @@ func TestWebAuthReauthWithEmptyTagsRemovesAllTags(t *testing.T) {
RequestTags: []string{}, // EMPTY - should untag
},
})
app.state.SetAuthCacheEntry(registrationID2, regEntry2)
app.state.SetRegistrationCacheEntry(registrationID2, regEntry2)
// Complete reauth with empty tags
nodeAfterReauth, _, err := app.state.HandleNodeFromAuthPath(
@@ -3745,8 +3759,8 @@ func TestAuthKeyTaggedToUserOwnedViaReauth(t *testing.T) {
// Step 2: Reauth via web auth with EMPTY tags to transition to user-owned
nodeKey2 := key.NewNode() // New node key for reauth
registrationID := types.MustAuthID()
regEntry := types.NewRegisterAuthRequest(types.Node{
registrationID := types.MustRegistrationID()
regEntry := types.NewRegisterNode(types.Node{
MachineKey: machineKey.Public(), // Same machine key
NodeKey: nodeKey2.Public(), // Different node key (rotation)
Hostname: "authkey-tagged-node",
@@ -3755,7 +3769,7 @@ func TestAuthKeyTaggedToUserOwnedViaReauth(t *testing.T) {
RequestTags: []string{}, // EMPTY - should untag
},
})
app.state.SetAuthCacheEntry(registrationID, regEntry)
app.state.SetRegistrationCacheEntry(registrationID, regEntry)
// Complete reauth with empty tags
nodeAfterReauth, _, err := app.state.HandleNodeFromAuthPath(
@@ -3942,10 +3956,10 @@ func TestTaggedNodeWithoutUserToDifferentUser(t *testing.T) {
require.NotNil(t, alice, "Alice user should be created")
// Step 4: Re-register the node to alice via HandleNodeFromAuthPath
// This is what happens when running: headscale auth register --auth-id <id> --user alice
// This is what happens when running: headscale nodes register --user alice --key ...
nodeKey2 := key.NewNode()
registrationID := types.MustAuthID()
regEntry := types.NewRegisterAuthRequest(types.Node{
registrationID := types.MustRegistrationID()
regEntry := types.NewRegisterNode(types.Node{
MachineKey: machineKey.Public(), // Same machine key as the tagged node
NodeKey: nodeKey2.Public(),
Hostname: "tagged-orphan-node",
@@ -3954,7 +3968,7 @@ func TestTaggedNodeWithoutUserToDifferentUser(t *testing.T) {
RequestTags: []string{}, // Empty - transition to user-owned
},
})
app.state.SetAuthCacheEntry(registrationID, regEntry)
app.state.SetRegistrationCacheEntry(registrationID, regEntry)
// This should NOT panic - before the fix, this would panic with:
// panic: runtime error: invalid memory address or nil pointer dereference

View File

@@ -47,7 +47,7 @@ const (
type HSDatabase struct {
DB *gorm.DB
cfg *types.Config
regCache *zcache.Cache[types.AuthID, types.AuthRequest]
regCache *zcache.Cache[types.RegistrationID, types.RegisterNode]
}
// NewHeadscaleDatabase creates a new database connection and runs migrations.
@@ -56,7 +56,7 @@ type HSDatabase struct {
//nolint:gocyclo // complex database initialization with many migrations
func NewHeadscaleDatabase(
cfg *types.Config,
regCache *zcache.Cache[types.AuthID, types.AuthRequest],
regCache *zcache.Cache[types.RegistrationID, types.RegisterNode],
) (*HSDatabase, error) {
dbConn, err := openDB(cfg.Database)
if err != nil {

View File

@@ -162,8 +162,8 @@ func TestSQLiteMigrationAndDataValidation(t *testing.T) {
}
}
func emptyCache() *zcache.Cache[types.AuthID, types.AuthRequest] {
return zcache.New[types.AuthID, types.AuthRequest](time.Minute, time.Hour)
func emptyCache() *zcache.Cache[types.RegistrationID, types.RegisterNode] {
return zcache.New[types.RegistrationID, types.RegisterNode](time.Minute, time.Hour)
}
func createSQLiteFromSQLFile(sqlFilePath, dbPath string) error {

View File

@@ -247,7 +247,7 @@ func (api headscaleV1APIServer) RegisterNode(
Str(zf.RegistrationKey, registrationKey).
Msg("registering node")
registrationId, err := types.AuthIDFromString(request.GetKey())
registrationId, err := types.RegistrationIDFromString(request.GetKey())
if err != nil {
return nil, err
}
@@ -808,32 +808,33 @@ func (api headscaleV1APIServer) DebugCreateNode(
Hostname: request.GetName(),
}
registrationId, err := types.AuthIDFromString(request.GetKey())
registrationId, err := types.RegistrationIDFromString(request.GetKey())
if err != nil {
return nil, err
}
newNode := types.Node{
NodeKey: key.NewNode().Public(),
MachineKey: key.NewMachine().Public(),
Hostname: request.GetName(),
User: user,
newNode := types.NewRegisterNode(
types.Node{
NodeKey: key.NewNode().Public(),
MachineKey: key.NewMachine().Public(),
Hostname: request.GetName(),
User: user,
Expiry: &time.Time{},
LastSeen: &time.Time{},
Expiry: &time.Time{},
LastSeen: &time.Time{},
Hostinfo: &hostinfo,
}
Hostinfo: &hostinfo,
},
)
log.Debug().
Caller().
Str("registration_id", registrationId.String()).
Msg("adding debug machine via CLI, appending to registration cache")
authRegReq := types.NewRegisterAuthRequest(newNode)
api.h.state.SetAuthCacheEntry(registrationId, authRegReq)
api.h.state.SetRegistrationCacheEntry(registrationId, newNode)
return &v1.DebugCreateNodeResponse{Node: newNode.Proto()}, nil
return &v1.DebugCreateNodeResponse{Node: newNode.Node.Proto()}, nil
}
func (api headscaleV1APIServer) Health(
@@ -856,59 +857,4 @@ func (api headscaleV1APIServer) Health(
return response, healthErr
}
func (api headscaleV1APIServer) AuthRegister(
ctx context.Context,
request *v1.AuthRegisterRequest,
) (*v1.AuthRegisterResponse, error) {
resp, err := api.RegisterNode(ctx, &v1.RegisterNodeRequest{
Key: request.GetAuthId(),
User: request.GetUser(),
})
if err != nil {
return nil, err
}
return &v1.AuthRegisterResponse{Node: resp.GetNode()}, nil
}
func (api headscaleV1APIServer) AuthApprove(
ctx context.Context,
request *v1.AuthApproveRequest,
) (*v1.AuthApproveResponse, error) {
authID, err := types.AuthIDFromString(request.GetAuthId())
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid auth_id: %v", err)
}
authReq, ok := api.h.state.GetAuthCacheEntry(authID)
if !ok {
return nil, status.Errorf(codes.NotFound, "no pending auth session for auth_id %s", authID)
}
authReq.FinishAuth(types.AuthVerdict{})
return &v1.AuthApproveResponse{}, nil
}
func (api headscaleV1APIServer) AuthReject(
ctx context.Context,
request *v1.AuthRejectRequest,
) (*v1.AuthRejectResponse, error) {
authID, err := types.AuthIDFromString(request.GetAuthId())
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid auth_id: %v", err)
}
authReq, ok := api.h.state.GetAuthCacheEntry(authID)
if !ok {
return nil, status.Errorf(codes.NotFound, "no pending auth session for auth_id %s", authID)
}
authReq.FinishAuth(types.AuthVerdict{
Err: fmt.Errorf("auth request rejected"),
})
return &v1.AuthRejectResponse{}, nil
}
func (api headscaleV1APIServer) mustEmbedUnimplementedHeadscaleServiceServer() {}

View File

@@ -11,6 +11,7 @@ import (
"strings"
"time"
"github.com/gorilla/mux"
"github.com/juanfont/headscale/hscontrol/assets"
"github.com/juanfont/headscale/hscontrol/templates"
"github.com/juanfont/headscale/hscontrol/types"
@@ -244,58 +245,11 @@ func NewAuthProviderWeb(serverURL string) *AuthProviderWeb {
}
}
func (a *AuthProviderWeb) RegisterURL(authID types.AuthID) string {
func (a *AuthProviderWeb) AuthURL(registrationId types.RegistrationID) string {
return fmt.Sprintf(
"%s/register/%s",
strings.TrimSuffix(a.serverURL, "/"),
authID.String())
}
func (a *AuthProviderWeb) AuthURL(authID types.AuthID) string {
return fmt.Sprintf(
"%s/auth/%s",
strings.TrimSuffix(a.serverURL, "/"),
authID.String())
}
func (a *AuthProviderWeb) AuthHandler(
writer http.ResponseWriter,
req *http.Request,
) {
authID, err := authIDFromRequest(req)
if err != nil {
httpError(writer, err)
return
}
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
writer.WriteHeader(http.StatusOK)
_, err = writer.Write([]byte(templates.AuthWeb(
"Authentication check",
"Run the command below in the headscale server to approve this authentication request:",
"headscale auth approve --auth-id "+authID.String(),
).Render()))
if err != nil {
log.Error().Err(err).Msg("failed to write auth response")
}
}
func authIDFromRequest(req *http.Request) (types.AuthID, error) {
raw, err := urlParam[string](req, "auth_id")
if err != nil {
return "", NewHTTPError(http.StatusBadRequest, "invalid registration id", fmt.Errorf("parsing auth_id from URL: %w", err))
}
// We need to make sure we dont open for XSS style injections, if the parameter that
// is passed as a key is not parsable/validated as a NodePublic key, then fail to render
// the template and log an error.
registrationId, err := types.AuthIDFromString(raw)
if err != nil {
return "", NewHTTPError(http.StatusBadRequest, "invalid registration id", fmt.Errorf("parsing auth_id from URL: %w", err))
}
return registrationId, nil
registrationId.String())
}
// RegisterHandler shows a simple message in the browser to point to the CLI
@@ -307,20 +261,22 @@ func (a *AuthProviderWeb) RegisterHandler(
writer http.ResponseWriter,
req *http.Request,
) {
registrationId, err := authIDFromRequest(req)
vars := mux.Vars(req)
registrationIdStr := vars["registration_id"]
// We need to make sure we dont open for XSS style injections, if the parameter that
// is passed as a key is not parsable/validated as a NodePublic key, then fail to render
// the template and log an error.
registrationId, err := types.RegistrationIDFromString(registrationIdStr)
if err != nil {
httpError(writer, err)
httpError(writer, NewHTTPError(http.StatusBadRequest, "invalid registration id", err))
return
}
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
writer.WriteHeader(http.StatusOK)
_, err = writer.Write([]byte(templates.AuthWeb(
"Node registration",
"Run the command below in the headscale server to add this node to your network:",
fmt.Sprintf("headscale auth register --auth-id %s --user USERNAME", registrationId.String()),
).Render()))
_, err = writer.Write([]byte(templates.RegisterWeb(registrationId).Render()))
if err != nil {
log.Error().Err(err).Msg("failed to write register response")
}

View File

@@ -95,8 +95,8 @@ var allBatcherFunctions = []batcherTestCase{
}
// emptyCache creates an empty registration cache for testing.
func emptyCache() *zcache.Cache[types.AuthID, types.AuthRequest] {
return zcache.New[types.AuthID, types.AuthRequest](time.Minute, time.Hour)
func emptyCache() *zcache.Cache[types.RegistrationID, types.RegisterNode] {
return zcache.New[types.RegistrationID, types.RegisterNode](time.Minute, time.Hour)
}
// Test configuration constants.

View File

@@ -7,15 +7,10 @@ import (
"fmt"
"io"
"net/http"
"net/url"
"time"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"github.com/go-chi/metrics"
"github.com/gorilla/mux"
"github.com/juanfont/headscale/hscontrol/capver"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"golang.org/x/net/http2"
"tailscale.com/control/controlbase"
@@ -27,15 +22,6 @@ import (
// ErrUnsupportedClientVersion is returned when a client connects with an unsupported protocol version.
var ErrUnsupportedClientVersion = errors.New("unsupported client version")
// ErrMissingURLParameter is returned when a required URL parameter is not provided.
var ErrMissingURLParameter = errors.New("missing URL parameter")
// ErrUnsupportedURLParameterType is returned when a URL parameter has an unsupported type.
var ErrUnsupportedURLParameterType = errors.New("unsupported URL parameter type")
// ErrNoAuthSession is returned when an auth_id does not match any active auth session.
var ErrNoAuthSession = errors.New("no auth session found")
const (
// ts2021UpgradePath is the path that the server listens on for the WebSockets upgrade.
ts2021UpgradePath = "/ts2021"
@@ -83,7 +69,7 @@ func (h *Headscale) NoiseUpgradeHandler(
return
}
ns := noiseServer{
noiseServer := noiseServer{
headscale: h,
challenge: key.NewChallenge(),
}
@@ -93,89 +79,42 @@ func (h *Headscale) NoiseUpgradeHandler(
writer,
req,
*h.noisePrivateKey,
ns.earlyNoise,
noiseServer.earlyNoise,
)
if err != nil {
httpError(writer, fmt.Errorf("upgrading noise connection: %w", err))
return
}
ns.conn = noiseConn
ns.machineKey = ns.conn.Peer()
ns.protocolVersion = ns.conn.ProtocolVersion()
noiseServer.conn = noiseConn
noiseServer.machineKey = noiseServer.conn.Peer()
noiseServer.protocolVersion = noiseServer.conn.ProtocolVersion()
// This router is served only over the Noise connection, and exposes only the new API.
//
// The HTTP2 server that exposes this router is created for
// a single hijacked connection from /ts2021, using netutil.NewOneConnListener
router := mux.NewRouter()
router.Use(prometheusMiddleware)
r := chi.NewRouter()
r.Use(metrics.Collector(metrics.CollectorOpts{
Host: false,
Proto: true,
Skip: func(r *http.Request) bool {
return r.Method != http.MethodOptions
},
}))
r.Use(middleware.RequestID)
r.Use(middleware.RealIP)
r.Use(middleware.RequestLogger(&zerologRequestLogger{}))
r.Use(middleware.Recoverer)
router.HandleFunc("/machine/register", noiseServer.NoiseRegistrationHandler).
Methods(http.MethodPost)
r.Handle("/metrics", metrics.Handler())
// Endpoints outside of the register endpoint must use getAndValidateNode to
// get the node to ensure that the MachineKey matches the Node setting up the
// connection.
router.HandleFunc("/machine/map", noiseServer.NoisePollNetMapHandler)
r.Route("/machine", func(r chi.Router) {
r.Post("/register", ns.RegistrationHandler)
r.Post("/map", ns.PollNetMapHandler)
// SSH Check mode endpoint, consulted to validate if a given SSH connection should be accepted or rejected.
r.Get("/ssh/action/from/{src_node_id}/to/{dst_node_id}", ns.SSHActionHandler)
// Not implemented yet
//
// /whoami is a debug endpoint to validate that the client can communicate over the connection,
// not clear if there is a specific response, it looks like it is just logged.
// https://github.com/tailscale/tailscale/blob/dfba01ca9bd8c4df02c3c32f400d9aeb897c5fc7/cmd/tailscale/cli/debug.go#L1138
r.Get("/whoami", ns.NotImplementedHandler)
// client sends a [tailcfg.SetDNSRequest] to this endpoints and expect
// the server to create or update this DNS record "somewhere".
// It is typically a TXT record for an ACME challenge.
r.Post("/set-dns", ns.NotImplementedHandler)
// A patch of [tailcfg.SetDeviceAttributesRequest] to update device attributes.
// We currently do not support device attributes.
r.Patch("/set-device-attr", ns.NotImplementedHandler)
// A [tailcfg.AuditLogRequest] to send audit log entries to the server.
// The server is expected to store them "somewhere".
// We currently do not support device attributes.
r.Post("/audit-log", ns.NotImplementedHandler)
// handles requests to get an OIDC ID token. Receives a [tailcfg.TokenRequest].
r.Post("/id-token", ns.NotImplementedHandler)
// Asks the server if a feature is available and receive information about how to enable it.
// Gets a [tailcfg.QueryFeatureRequest] and returns a [tailcfg.QueryFeatureResponse].
r.Post("/feature/query", ns.NotImplementedHandler)
r.Post("/update-health", ns.NotImplementedHandler)
r.Route("/webclient", func(r chi.Router) {})
r.Post("/c2n", ns.NotImplementedHandler)
})
ns.httpBaseConfig = &http.Server{
Handler: r,
noiseServer.httpBaseConfig = &http.Server{
Handler: router,
ReadHeaderTimeout: types.HTTPTimeout,
}
ns.http2Server = &http2.Server{}
noiseServer.http2Server = &http2.Server{}
ns.http2Server.ServeConn(
noiseServer.http2Server.ServeConn(
noiseConn,
&http2.ServeConnOpts{
BaseConfig: ns.httpBaseConfig,
BaseConfig: noiseServer.httpBaseConfig,
},
)
}
@@ -250,279 +189,7 @@ func rejectUnsupported(
return false
}
func (ns *noiseServer) NotImplementedHandler(writer http.ResponseWriter, req *http.Request) {
d, _ := io.ReadAll(req.Body)
log.Trace().Caller().Str("path", req.URL.String()).Bytes("body", d).Msgf("not implemented handler hit")
http.Error(writer, "Not implemented yet", http.StatusNotImplemented)
}
func urlParam[T any](req *http.Request, key string) (T, error) {
var zero T
param := chi.URLParam(req, key)
if param == "" {
return zero, fmt.Errorf("%w: %s", ErrMissingURLParameter, key)
}
var value T
switch any(value).(type) {
case string:
v, ok := any(param).(T)
if !ok {
return zero, fmt.Errorf("%w: %T", ErrUnsupportedURLParameterType, value)
}
value = v
case types.NodeID:
id, err := types.ParseNodeID(param)
if err != nil {
return zero, fmt.Errorf("parsing %s: %w", key, err)
}
v, ok := any(id).(T)
if !ok {
return zero, fmt.Errorf("%w: %T", ErrUnsupportedURLParameterType, value)
}
value = v
default:
return zero, fmt.Errorf("%w: %T", ErrUnsupportedURLParameterType, value)
}
return value, nil
}
// SSHActionHandler handles the /ssh-action endpoint, returning a
// [tailcfg.SSHAction] to the client with the verdict of an SSH access
// request.
func (ns *noiseServer) SSHActionHandler(
writer http.ResponseWriter,
req *http.Request,
) {
srcNodeID, err := urlParam[types.NodeID](req, "src_node_id")
if err != nil {
httpError(writer, NewHTTPError(
http.StatusBadRequest,
"Invalid src_node_id",
err,
))
return
}
dstNodeID, err := urlParam[types.NodeID](req, "dst_node_id")
if err != nil {
httpError(writer, NewHTTPError(
http.StatusBadRequest,
"Invalid dst_node_id",
err,
))
return
}
reqLog := log.With().
Uint64("src_node_id", srcNodeID.Uint64()).
Uint64("dst_node_id", dstNodeID.Uint64()).
Str("ssh_user", req.URL.Query().Get("ssh_user")).
Str("local_user", req.URL.Query().Get("local_user")).
Logger()
reqLog.Trace().Caller().Msg("SSH action request")
action, err := ns.sshAction(
reqLog,
srcNodeID, dstNodeID,
req.URL.Query().Get("auth_id"),
)
if err != nil {
httpError(writer, err)
return
}
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
writer.WriteHeader(http.StatusOK)
err = json.NewEncoder(writer).Encode(action)
if err != nil {
reqLog.Error().Caller().Err(err).
Msg("failed to encode SSH action response")
return
}
if flusher, ok := writer.(http.Flusher); ok {
flusher.Flush()
}
}
// sshAction resolves the SSH action for the given request parameters.
// It returns the action to send to the client, or an HTTPError on failure.
//
// Three cases:
// 1. Initial request, auto-approved — source recently authenticated
// within the check period, accept immediately.
// 2. Initial request, needs auth — build a HoldAndDelegate URL and
// wait for the user to authenticate.
// 3. Follow-up request — an auth_id is present, wait for the auth
// verdict and accept or reject.
func (ns *noiseServer) sshAction(
reqLog zerolog.Logger,
srcNodeID, dstNodeID types.NodeID,
authIDStr string,
) (*tailcfg.SSHAction, error) {
action := tailcfg.SSHAction{
AllowAgentForwarding: true,
AllowLocalPortForwarding: true,
AllowRemotePortForwarding: true,
}
// Look up check params from the server's own policy rather than
// trusting URL parameters, which the client could tamper with.
checkPeriod, checkFound := ns.headscale.state.SSHCheckParams(
srcNodeID, dstNodeID,
)
// Follow-up request with auth_id — wait for the auth verdict.
if authIDStr != "" {
return ns.sshActionFollowUp(
reqLog, &action, authIDStr,
srcNodeID, dstNodeID,
checkFound,
)
}
// Initial request — check if auto-approval applies.
if checkFound && checkPeriod > 0 {
if lastAuth, ok := ns.headscale.state.GetLastSSHAuth(
srcNodeID, dstNodeID,
); ok && time.Since(lastAuth) < checkPeriod {
reqLog.Trace().Caller().
Dur("check_period", checkPeriod).
Time("last_auth", lastAuth).
Msg("auto-approved within check period")
action.Accept = true
return &action, nil
}
}
// No auto-approval — create an auth session and hold.
return ns.sshActionHoldAndDelegate(reqLog, &action)
}
// sshActionHoldAndDelegate creates a new auth session and returns a
// HoldAndDelegate action that directs the client to authenticate.
func (ns *noiseServer) sshActionHoldAndDelegate(
reqLog zerolog.Logger,
action *tailcfg.SSHAction,
) (*tailcfg.SSHAction, error) {
holdURL, err := url.Parse(
ns.headscale.cfg.ServerURL +
"/machine/ssh/action/from/$SRC_NODE_ID/to/$DST_NODE_ID" +
"?ssh_user=$SSH_USER&local_user=$LOCAL_USER",
)
if err != nil {
return nil, NewHTTPError(
http.StatusInternalServerError,
"Internal error",
fmt.Errorf("parsing SSH action URL: %w", err),
)
}
authID, err := types.NewAuthID()
if err != nil {
return nil, NewHTTPError(
http.StatusInternalServerError,
"Internal error",
fmt.Errorf("generating auth ID: %w", err),
)
}
ns.headscale.state.SetAuthCacheEntry(authID, types.NewAuthRequest())
authURL := ns.headscale.authProvider.AuthURL(authID)
q := holdURL.Query()
q.Set("auth_id", authID.String())
holdURL.RawQuery = q.Encode()
action.HoldAndDelegate = holdURL.String()
// TODO(kradalby): here we can also send a very tiny mapresponse
// "popping" the url and opening it for the user.
action.Message = fmt.Sprintf(
"# Headscale SSH requires an additional check.\n"+
"# To authenticate, visit: %s\n"+
"# Authentication checked with Headscale SSH.\n",
authURL,
)
reqLog.Info().Caller().
Str("auth_id", authID.String()).
Msg("SSH check pending, waiting for auth")
return action, nil
}
// sshActionFollowUp handles follow-up requests where the client
// provides an auth_id. It blocks until the auth session resolves.
func (ns *noiseServer) sshActionFollowUp(
reqLog zerolog.Logger,
action *tailcfg.SSHAction,
authIDStr string,
srcNodeID, dstNodeID types.NodeID,
checkFound bool,
) (*tailcfg.SSHAction, error) {
authID, err := types.AuthIDFromString(authIDStr)
if err != nil {
return nil, NewHTTPError(
http.StatusBadRequest,
"Invalid auth_id",
fmt.Errorf("parsing auth_id: %w", err),
)
}
reqLog = reqLog.With().Str("auth_id", authID.String()).Logger()
auth, ok := ns.headscale.state.GetAuthCacheEntry(authID)
if !ok {
return nil, NewHTTPError(
http.StatusBadRequest,
"Invalid auth_id",
fmt.Errorf("%w: %s", ErrNoAuthSession, authID),
)
}
reqLog.Trace().Caller().Msg("SSH action follow-up")
verdict := <-auth.WaitForAuth()
if !verdict.Accept() {
action.Reject = true
reqLog.Trace().Caller().Err(verdict.Err).
Msg("authentication rejected")
return action, nil
}
action.Accept = true
// Record the successful auth for future auto-approval.
if checkFound {
ns.headscale.state.SetLastSSHAuth(srcNodeID, dstNodeID)
reqLog.Trace().Caller().
Msg("auth recorded for auto-approval")
}
return action, nil
}
// PollNetMapHandler takes care of /machine/:id/map using the Noise protocol
// NoisePollNetMapHandler takes care of /machine/:id/map using the Noise protocol
//
// This is the busiest endpoint, as it keeps the HTTP long poll that updates
// the clients when something in the network changes.
@@ -531,7 +198,7 @@ func (ns *noiseServer) sshActionFollowUp(
// only after their first request (marked with the ReadOnly field).
//
// At this moment the updates are sent in a quite horrendous way, but they kinda work.
func (ns *noiseServer) PollNetMapHandler(
func (ns *noiseServer) NoisePollNetMapHandler(
writer http.ResponseWriter,
req *http.Request,
) {
@@ -570,8 +237,8 @@ func regErr(err error) *tailcfg.RegisterResponse {
return &tailcfg.RegisterResponse{Error: err.Error()}
}
// RegistrationHandler handles the actual registration process of a node.
func (ns *noiseServer) RegistrationHandler(
// NoiseRegistrationHandler handles the actual registration process of a node.
func (ns *noiseServer) NoiseRegistrationHandler(
writer http.ResponseWriter,
req *http.Request,
) {

View File

@@ -12,6 +12,7 @@ import (
"time"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/gorilla/mux"
"github.com/juanfont/headscale/hscontrol/db"
"github.com/juanfont/headscale/hscontrol/templates"
"github.com/juanfont/headscale/hscontrol/types"
@@ -25,8 +26,8 @@ import (
const (
randomByteSize = 16
defaultOAuthOptionsCount = 3
authCacheExpiration = time.Minute * 15
authCacheCleanup = time.Minute * 20
registerCacheExpiration = time.Minute * 15
registerCacheCleanup = time.Minute * 20
)
var (
@@ -43,21 +44,17 @@ var (
errOIDCUnverifiedEmail = errors.New("authenticated principal has an unverified email")
)
// AuthInfo contains both auth ID and verifier information for OIDC validation.
type AuthInfo struct {
AuthID types.AuthID
Verifier *string
Registration bool
// RegistrationInfo contains both machine key and verifier information for OIDC validation.
type RegistrationInfo struct {
RegistrationID types.RegistrationID
Verifier *string
}
type AuthProviderOIDC struct {
h *Headscale
serverURL string
cfg *types.OIDCConfig
// authCache holds auth information between
// the auth and the callback steps.
authCache *zcache.Cache[string, AuthInfo]
h *Headscale
serverURL string
cfg *types.OIDCConfig
registrationCache *zcache.Cache[string, RegistrationInfo]
oidcProvider *oidc.Provider
oauth2Config *oauth2.Config
@@ -84,63 +81,45 @@ func NewAuthProviderOIDC(
Scopes: cfg.Scope,
}
authCache := zcache.New[string, AuthInfo](
authCacheExpiration,
authCacheCleanup,
registrationCache := zcache.New[string, RegistrationInfo](
registerCacheExpiration,
registerCacheCleanup,
)
return &AuthProviderOIDC{
h: h,
serverURL: serverURL,
cfg: cfg,
authCache: authCache,
h: h,
serverURL: serverURL,
cfg: cfg,
registrationCache: registrationCache,
oidcProvider: oidcProvider,
oauth2Config: oauth2Config,
}, nil
}
func (a *AuthProviderOIDC) AuthURL(authID types.AuthID) string {
return fmt.Sprintf(
"%s/auth/%s",
strings.TrimSuffix(a.serverURL, "/"),
authID.String())
}
func (a *AuthProviderOIDC) AuthHandler(
writer http.ResponseWriter,
req *http.Request,
) {
a.authHandler(writer, req, false)
}
func (a *AuthProviderOIDC) RegisterURL(authID types.AuthID) string {
func (a *AuthProviderOIDC) AuthURL(registrationID types.RegistrationID) string {
return fmt.Sprintf(
"%s/register/%s",
strings.TrimSuffix(a.serverURL, "/"),
authID.String())
registrationID.String())
}
// RegisterHandler registers the OIDC callback handler with the given router.
// It puts NodeKey in cache so the callback can retrieve it using the oidc state param.
// Listens in /register/:auth_id.
// Listens in /register/:registration_id.
func (a *AuthProviderOIDC) RegisterHandler(
writer http.ResponseWriter,
req *http.Request,
) {
a.authHandler(writer, req, true)
}
vars := mux.Vars(req)
registrationIdStr := vars["registration_id"]
// authHandler takes an incoming request that needs to be authenticated and
// validates and prepares it for the OIDC flow.
func (a *AuthProviderOIDC) authHandler(
writer http.ResponseWriter,
req *http.Request,
registration bool,
) {
authID, err := authIDFromRequest(req)
// We need to make sure we dont open for XSS style injections, if the parameter that
// is passed as a key is not parsable/validated as a NodePublic key, then fail to render
// the template and log an error.
registrationId, err := types.RegistrationIDFromString(registrationIdStr)
if err != nil {
httpError(writer, err)
httpError(writer, NewHTTPError(http.StatusBadRequest, "invalid registration id", err))
return
}
@@ -158,9 +137,9 @@ func (a *AuthProviderOIDC) authHandler(
return
}
registrationInfo := AuthInfo{
AuthID: authID,
Registration: registration,
// Initialize registration info with machine key
registrationInfo := RegistrationInfo{
RegistrationID: registrationId,
}
extras := make([]oauth2.AuthCodeOption, 0, len(a.cfg.ExtraParams)+defaultOAuthOptionsCount)
@@ -188,7 +167,7 @@ func (a *AuthProviderOIDC) authHandler(
extras = append(extras, oidc.Nonce(nonce))
// Cache the registration info
a.authCache.Set(state, registrationInfo)
a.registrationCache.Set(state, registrationInfo)
authURL := a.oauth2Config.AuthCodeURL(state, extras...)
log.Debug().Caller().Msgf("redirecting to %s for authentication", authURL)
@@ -323,20 +302,16 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
// If the node exists, then the node should be reauthenticated,
// if the node does not exist, and the machine key exists, then
// this is a new node that should be registered.
authInfo := a.getAuthInfoFromState(state)
if authInfo == nil {
log.Debug().Caller().Str("state", state).Msg("state not found in cache, login session may have expired")
httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", nil))
registrationId := a.getRegistrationIDFromState(state)
return
}
// Register the node if it does not exist.
if registrationId != nil {
verb := "Reauthenticated"
// If this is a registration flow, then we need to register the node.
if authInfo.Registration {
newNode, err := a.handleRegistration(user, authInfo.AuthID, nodeExpiry)
newNode, err := a.handleRegistration(user, *registrationId, nodeExpiry)
if err != nil {
if errors.Is(err, db.ErrNodeNotFoundRegistrationCache) {
log.Debug().Caller().Str("registration_id", authInfo.AuthID.String()).Msg("registration session expired before authorization completed")
log.Debug().Caller().Str("registration_id", registrationId.String()).Msg("registration session expired before authorization completed")
httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", err))
return
@@ -347,7 +322,12 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
return
}
content := renderRegistrationSuccessTemplate(user, newNode)
if newNode {
verb = "Authenticated"
}
// TODO(kradalby): replace with go-elem
content := renderOIDCCallbackTemplate(user, verb)
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
writer.WriteHeader(http.StatusOK)
@@ -359,28 +339,9 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
return
}
// If this is not a registration callback, then its a regular authentication callback
// and we need to send a response and confirm that the access was allowed.
authReq, ok := a.h.state.GetAuthCacheEntry(authInfo.AuthID)
if !ok {
log.Debug().Caller().Str("auth_id", authInfo.AuthID.String()).Msg("auth session expired before authorization completed")
httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", nil))
return
}
// Send a finish auth verdict with no errors to let the CLI know that the authentication was successful.
authReq.FinishAuth(types.AuthVerdict{})
content := renderAuthSuccessTemplate(user)
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
writer.WriteHeader(http.StatusOK)
if _, err := writer.Write(content.Bytes()); err != nil { //nolint:noinlineerr
util.LogErr(err, "Failed to write HTTP response")
}
// Neither node nor machine key was found in the state cache meaning
// that we could not reauth nor register the node.
httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", nil))
}
func (a *AuthProviderOIDC) determineNodeExpiry(idTokenExpiration time.Time) time.Time {
@@ -413,7 +374,7 @@ func (a *AuthProviderOIDC) getOauth2Token(
var exchangeOpts []oauth2.AuthCodeOption
if a.cfg.PKCE.Enabled {
regInfo, ok := a.authCache.Get(state)
regInfo, ok := a.registrationCache.Get(state)
if !ok {
return nil, NewHTTPError(http.StatusNotFound, "registration not found", errNoOIDCRegistrationInfo)
}
@@ -546,14 +507,14 @@ func doOIDCAuthorization(
return nil
}
// getAuthInfoFromState retrieves the registration ID from the state.
func (a *AuthProviderOIDC) getAuthInfoFromState(state string) *AuthInfo {
authInfo, ok := a.authCache.Get(state)
// getRegistrationIDFromState retrieves the registration ID from the state.
func (a *AuthProviderOIDC) getRegistrationIDFromState(state string) *types.RegistrationID {
regInfo, ok := a.registrationCache.Get(state)
if !ok {
return nil
}
return &authInfo
return &regInfo.RegistrationID
}
func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
@@ -601,7 +562,7 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
func (a *AuthProviderOIDC) handleRegistration(
user *types.User,
registrationID types.AuthID,
registrationID types.RegistrationID,
expiry time.Time,
) (bool, error) {
node, nodeChange, err := a.h.state.HandleNodeFromAuthPath(
@@ -636,38 +597,12 @@ func (a *AuthProviderOIDC) handleRegistration(
return !nodeChange.IsEmpty(), nil
}
func renderRegistrationSuccessTemplate(
func renderOIDCCallbackTemplate(
user *types.User,
newNode bool,
verb string,
) *bytes.Buffer {
result := templates.AuthSuccessResult{
Title: "Headscale - Node Reauthenticated",
Heading: "Node reauthenticated",
Verb: "Reauthenticated",
User: user.Display(),
Message: "You can now close this window.",
}
if newNode {
result.Title = "Headscale - Node Registered"
result.Heading = "Node registered"
result.Verb = "Registered"
}
return bytes.NewBufferString(templates.AuthSuccess(result).Render())
}
func renderAuthSuccessTemplate(
user *types.User,
) *bytes.Buffer {
result := templates.AuthSuccessResult{
Title: "Headscale - SSH Session Authorized",
Heading: "SSH session authorized",
Verb: "Authorized",
User: user.Display(),
Message: "You may return to your terminal.",
}
return bytes.NewBufferString(templates.AuthSuccess(result).Render())
html := templates.OIDCCallback(user.Display(), verb).Render()
return bytes.NewBufferString(html)
}
// getCookieName generates a unique cookie name based on a cookie value.

View File

@@ -7,54 +7,35 @@ import (
"github.com/stretchr/testify/assert"
)
func TestAuthSuccessTemplate(t *testing.T) {
func TestOIDCCallbackTemplate(t *testing.T) {
tests := []struct {
name string
result templates.AuthSuccessResult
name string
userName string
verb string
}{
{
name: "node_registered",
result: templates.AuthSuccessResult{
Title: "Headscale - Node Registered",
Heading: "Node registered",
Verb: "Registered",
User: "newuser@example.com",
Message: "You can now close this window.",
},
name: "logged_in_user",
userName: "test@example.com",
verb: "Logged in",
},
{
name: "node_reauthenticated",
result: templates.AuthSuccessResult{
Title: "Headscale - Node Reauthenticated",
Heading: "Node reauthenticated",
Verb: "Reauthenticated",
User: "test@example.com",
Message: "You can now close this window.",
},
},
{
name: "ssh_session_authorized",
result: templates.AuthSuccessResult{
Title: "Headscale - SSH Session Authorized",
Heading: "SSH session authorized",
Verb: "Authorized",
User: "test@example.com",
Message: "You may return to your terminal.",
},
name: "registered_user",
userName: "newuser@example.com",
verb: "Registered",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
html := templates.AuthSuccess(tt.result).Render()
// Render using the elem-go template
html := templates.OIDCCallback(tt.userName, tt.verb).Render()
// Verify the HTML contains expected structural elements
// Verify the HTML contains expected elements
assert.Contains(t, html, "<!DOCTYPE html>")
assert.Contains(t, html, "<title>"+tt.result.Title+"</title>")
assert.Contains(t, html, tt.result.Heading)
assert.Contains(t, html, tt.result.Verb+" as ")
assert.Contains(t, html, tt.result.User)
assert.Contains(t, html, tt.result.Message)
assert.Contains(t, html, "<title>Headscale Authentication Succeeded</title>")
assert.Contains(t, html, tt.verb)
assert.Contains(t, html, tt.userName)
assert.Contains(t, html, "You can now close this window")
// Verify Material for MkDocs design system CSS is present
assert.Contains(t, html, "Material for MkDocs")

View File

@@ -2,7 +2,6 @@ package policy
import (
"net/netip"
"time"
"github.com/juanfont/headscale/hscontrol/policy/matcher"
policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2"
@@ -20,10 +19,7 @@ type PolicyManager interface {
MatchersForNode(node types.NodeView) ([]matcher.Match, error)
// BuildPeerMap constructs peer relationship maps for the given nodes
BuildPeerMap(nodes views.Slice[types.NodeView]) map[types.NodeID][]types.NodeView
SSHPolicy(baseURL string, node types.NodeView) (*tailcfg.SSHPolicy, error)
// SSHCheckParams resolves the SSH check period for a (src, dst) pair
// from the current policy, avoiding trust of client-provided URL params.
SSHCheckParams(srcNodeID, dstNodeID types.NodeID) (time.Duration, bool)
SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, error)
SetPolicy(pol []byte) (bool, error)
SetUsers(users []types.User) (bool, error)
SetNodes(nodes views.Slice[types.NodeView]) (bool, error)

View File

@@ -1188,9 +1188,8 @@ func TestSSHPolicyRules(t *testing.T) {
"root": "",
},
Action: &tailcfg.SSHAction{
Accept: false,
Accept: true,
SessionDuration: 24 * time.Hour,
HoldAndDelegate: "unused-url/machine/ssh/action/from/$SRC_NODE_ID/to/$DST_NODE_ID?ssh_user=$SSH_USER&local_user=$LOCAL_USER",
AllowAgentForwarding: true,
AllowLocalPortForwarding: true,
AllowRemotePortForwarding: true,
@@ -1477,7 +1476,7 @@ func TestSSHPolicyRules(t *testing.T) {
require.NoError(t, err)
got, err := pm.SSHPolicy("unused-url", tt.targetNode.View())
got, err := pm.SSHPolicy(tt.targetNode.View())
require.NoError(t, err)
if diff := cmp.Diff(tt.wantSSH, got); diff != "" {

View File

@@ -319,43 +319,11 @@ func (pol *Policy) compileACLWithAutogroupSelf(
return rules, nil
}
var sshAccept = tailcfg.SSHAction{
Reject: false,
Accept: true,
AllowAgentForwarding: true,
AllowLocalPortForwarding: true,
AllowRemotePortForwarding: true,
}
// checkPeriodFromRule extracts the check period duration from an SSH rule.
// Returns SSHCheckPeriodDefault if no checkPeriod is configured,
// 0 if checkPeriod is "always", or the configured duration otherwise.
func checkPeriodFromRule(rule SSH) time.Duration {
switch {
case rule.CheckPeriod == nil:
return SSHCheckPeriodDefault
case rule.CheckPeriod.Always:
return 0
default:
return rule.CheckPeriod.Duration
}
}
func sshCheck(baseURL string, duration time.Duration) tailcfg.SSHAction {
holdURL := baseURL + "/machine/ssh/action/from/$SRC_NODE_ID/to/$DST_NODE_ID?ssh_user=$SSH_USER&local_user=$LOCAL_USER"
func sshAction(accept bool, duration time.Duration) tailcfg.SSHAction {
return tailcfg.SSHAction{
Reject: false,
Accept: false,
SessionDuration: duration,
// Replaced in the client:
// * $SRC_NODE_IP (URL escaped)
// * $SRC_NODE_ID (Node.ID as int64 string)
// * $DST_NODE_IP (URL escaped)
// * $DST_NODE_ID (Node.ID as int64 string)
// * $SSH_USER (URL escaped, ssh user requested)
// * $LOCAL_USER (URL escaped, local user mapped)
HoldAndDelegate: holdURL,
Reject: !accept,
Accept: accept,
SessionDuration: duration,
AllowAgentForwarding: true,
AllowLocalPortForwarding: true,
AllowRemotePortForwarding: true,
@@ -364,7 +332,6 @@ func sshCheck(baseURL string, duration time.Duration) tailcfg.SSHAction {
//nolint:gocyclo // complex SSH policy compilation logic
func (pol *Policy) compileSSHPolicy(
baseURL string,
users types.Users,
node types.NodeView,
nodes views.Slice[types.NodeView],
@@ -410,9 +377,9 @@ func (pol *Policy) compileSSHPolicy(
switch rule.Action {
case SSHActionAccept:
action = sshAccept
action = sshAction(true, 0)
case SSHActionCheck:
action = sshCheck(baseURL, checkPeriodFromRule(rule))
action = sshAction(true, time.Duration(rule.CheckPeriod))
default:
return nil, fmt.Errorf("parsing SSH policy, unknown action %q, index: %d: %w", rule.Action, index, err)
}
@@ -536,23 +503,6 @@ func (pol *Policy) compileSSHPolicy(
}
}
// Sort rules: check (HoldAndDelegate) before accept, per Tailscale
// evaluation order (most-restrictive first).
slices.SortStableFunc(rules, func(a, b *tailcfg.SSHRule) int {
aIsCheck := a.Action != nil && a.Action.HoldAndDelegate != ""
bIsCheck := b.Action != nil && b.Action.HoldAndDelegate != ""
if aIsCheck == bIsCheck {
return 0
}
if aIsCheck {
return -1
}
return 1
})
return &tailcfg.SSHPolicy{
Rules: rules,
}, nil

View File

@@ -10,6 +10,7 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/prometheus/common/model"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/gorm"
@@ -614,7 +615,7 @@ func TestCompileSSHPolicy_UserMapping(t *testing.T) {
require.NoError(t, err)
// Compile SSH policy
sshPolicy, err := tt.policy.compileSSHPolicy("unused-server-url", users, tt.targetNode.View(), nodes.ViewSlice())
sshPolicy, err := tt.policy.compileSSHPolicy(users, tt.targetNode.View(), nodes.ViewSlice())
require.NoError(t, err)
if tt.wantEmpty {
@@ -679,7 +680,7 @@ func TestCompileSSHPolicy_CheckAction(t *testing.T) {
SSHs: []SSH{
{
Action: "check",
CheckPeriod: &SSHCheckPeriod{Duration: 24 * time.Hour},
CheckPeriod: model.Duration(24 * time.Hour),
Sources: SSHSrcAliases{gp("group:admins")},
Destinations: SSHDstAliases{tp("tag:server")},
Users: []SSHUser{"ssh-it-user"},
@@ -690,7 +691,7 @@ func TestCompileSSHPolicy_CheckAction(t *testing.T) {
err := policy.validate()
require.NoError(t, err)
sshPolicy, err := policy.compileSSHPolicy("unused-server-url", users, nodeTaggedServer.View(), nodes.ViewSlice())
sshPolicy, err := policy.compileSSHPolicy(users, nodeTaggedServer.View(), nodes.ViewSlice())
require.NoError(t, err)
require.NotNil(t, sshPolicy)
require.Len(t, sshPolicy.Rules, 1)
@@ -703,92 +704,9 @@ func TestCompileSSHPolicy_CheckAction(t *testing.T) {
}
assert.Equal(t, expectedUsers, rule.SSHUsers)
// Verify check action: Accept is false, HoldAndDelegate is set
assert.False(t, rule.Action.Accept)
assert.False(t, rule.Action.Reject)
assert.NotEmpty(t, rule.Action.HoldAndDelegate)
assert.Contains(t, rule.Action.HoldAndDelegate, "/machine/ssh/action/")
// Verify check action with session duration
assert.True(t, rule.Action.Accept)
assert.Equal(t, 24*time.Hour, rule.Action.SessionDuration)
// Verify check params are NOT encoded in the URL (looked up server-side).
assert.NotContains(t, rule.Action.HoldAndDelegate, "check_explicit")
assert.NotContains(t, rule.Action.HoldAndDelegate, "check_period")
}
// TestCompileSSHPolicy_CheckBeforeAcceptOrdering verifies that check
// (HoldAndDelegate) rules are sorted before accept rules, even when
// the accept rule appears first in the policy definition.
func TestCompileSSHPolicy_CheckBeforeAcceptOrdering(t *testing.T) {
users := types.Users{
{Name: "user1", Model: gorm.Model{ID: 1}},
{Name: "user2", Model: gorm.Model{ID: 2}},
}
nodeTaggedServer := types.Node{
Hostname: "tagged-server",
IPv4: createAddr("100.64.0.1"),
UserID: new(users[0].ID),
User: new(users[0]),
Tags: []string{"tag:server"},
}
nodeUser2 := types.Node{
Hostname: "user2-device",
IPv4: createAddr("100.64.0.2"),
UserID: new(users[1].ID),
User: new(users[1]),
}
nodes := types.Nodes{&nodeTaggedServer, &nodeUser2}
// Accept rule appears BEFORE check rule in policy definition.
policy := &Policy{
TagOwners: TagOwners{
Tag("tag:server"): Owners{up("user1@")},
},
Groups: Groups{
Group("group:admins"): []Username{Username("user2@")},
},
SSHs: []SSH{
{
Action: "accept",
Sources: SSHSrcAliases{gp("group:admins")},
Destinations: SSHDstAliases{tp("tag:server")},
Users: []SSHUser{"root"},
},
{
Action: "check",
CheckPeriod: &SSHCheckPeriod{Duration: 24 * time.Hour},
Sources: SSHSrcAliases{gp("group:admins")},
Destinations: SSHDstAliases{tp("tag:server")},
Users: []SSHUser{"ssh-it-user"},
},
},
}
err := policy.validate()
require.NoError(t, err)
sshPolicy, err := policy.compileSSHPolicy(
"unused-server-url",
users,
nodeTaggedServer.View(),
nodes.ViewSlice(),
)
require.NoError(t, err)
require.NotNil(t, sshPolicy)
require.Len(t, sshPolicy.Rules, 2)
// First rule must be the check rule (HoldAndDelegate set).
assert.NotEmpty(t, sshPolicy.Rules[0].Action.HoldAndDelegate,
"first rule should be check (HoldAndDelegate)")
assert.False(t, sshPolicy.Rules[0].Action.Accept,
"first rule should not be accept")
// Second rule must be the accept rule.
assert.True(t, sshPolicy.Rules[1].Action.Accept,
"second rule should be accept")
assert.Empty(t, sshPolicy.Rules[1].Action.HoldAndDelegate,
"second rule should not have HoldAndDelegate")
}
// TestSSHIntegrationReproduction reproduces the exact scenario from the integration test
@@ -838,7 +756,7 @@ func TestSSHIntegrationReproduction(t *testing.T) {
require.NoError(t, err)
// Test SSH policy compilation for node2 (owned by user2, who is in the group)
sshPolicy, err := policy.compileSSHPolicy("unused-server-url", users, node2.View(), nodes.ViewSlice())
sshPolicy, err := policy.compileSSHPolicy(users, node2.View(), nodes.ViewSlice())
require.NoError(t, err)
require.NotNil(t, sshPolicy)
require.Len(t, sshPolicy.Rules, 1)
@@ -888,7 +806,7 @@ func TestSSHJSONSerialization(t *testing.T) {
err := policy.validate()
require.NoError(t, err)
sshPolicy, err := policy.compileSSHPolicy("unused-server-url", users, node.View(), nodes.ViewSlice())
sshPolicy, err := policy.compileSSHPolicy(users, node.View(), nodes.ViewSlice())
require.NoError(t, err)
require.NotNil(t, sshPolicy)
@@ -1495,7 +1413,7 @@ func TestSSHWithAutogroupSelfInDestination(t *testing.T) {
// Test for user1's first node
node1 := nodes[0].View()
sshPolicy, err := policy.compileSSHPolicy("unused-server-url", users, node1, nodes.ViewSlice())
sshPolicy, err := policy.compileSSHPolicy(users, node1, nodes.ViewSlice())
require.NoError(t, err)
require.NotNil(t, sshPolicy)
require.Len(t, sshPolicy.Rules, 1)
@@ -1514,7 +1432,7 @@ func TestSSHWithAutogroupSelfInDestination(t *testing.T) {
// Test for user2's first node
node3 := nodes[2].View()
sshPolicy2, err := policy.compileSSHPolicy("unused-server-url", users, node3, nodes.ViewSlice())
sshPolicy2, err := policy.compileSSHPolicy(users, node3, nodes.ViewSlice())
require.NoError(t, err)
require.NotNil(t, sshPolicy2)
require.Len(t, sshPolicy2.Rules, 1)
@@ -1533,7 +1451,7 @@ func TestSSHWithAutogroupSelfInDestination(t *testing.T) {
// Test for tagged node (should have no SSH rules)
node5 := nodes[4].View()
sshPolicy3, err := policy.compileSSHPolicy("unused-server-url", users, node5, nodes.ViewSlice())
sshPolicy3, err := policy.compileSSHPolicy(users, node5, nodes.ViewSlice())
require.NoError(t, err)
if sshPolicy3 != nil {
@@ -1573,7 +1491,7 @@ func TestSSHWithAutogroupSelfAndSpecificUser(t *testing.T) {
// For user1's node: should allow SSH from user1's devices
node1 := nodes[0].View()
sshPolicy, err := policy.compileSSHPolicy("unused-server-url", users, node1, nodes.ViewSlice())
sshPolicy, err := policy.compileSSHPolicy(users, node1, nodes.ViewSlice())
require.NoError(t, err)
require.NotNil(t, sshPolicy)
require.Len(t, sshPolicy.Rules, 1)
@@ -1590,7 +1508,7 @@ func TestSSHWithAutogroupSelfAndSpecificUser(t *testing.T) {
// For user2's node: should have no rules (user1's devices can't match user2's self)
node3 := nodes[2].View()
sshPolicy2, err := policy.compileSSHPolicy("unused-server-url", users, node3, nodes.ViewSlice())
sshPolicy2, err := policy.compileSSHPolicy(users, node3, nodes.ViewSlice())
require.NoError(t, err)
if sshPolicy2 != nil {
@@ -1633,7 +1551,7 @@ func TestSSHWithAutogroupSelfAndGroup(t *testing.T) {
// For user1's node: should allow SSH from user1's devices only (not user2's)
node1 := nodes[0].View()
sshPolicy, err := policy.compileSSHPolicy("unused-server-url", users, node1, nodes.ViewSlice())
sshPolicy, err := policy.compileSSHPolicy(users, node1, nodes.ViewSlice())
require.NoError(t, err)
require.NotNil(t, sshPolicy)
require.Len(t, sshPolicy.Rules, 1)
@@ -1650,7 +1568,7 @@ func TestSSHWithAutogroupSelfAndGroup(t *testing.T) {
// For user3's node: should have no rules (not in group:admins)
node5 := nodes[4].View()
sshPolicy2, err := policy.compileSSHPolicy("unused-server-url", users, node5, nodes.ViewSlice())
sshPolicy2, err := policy.compileSSHPolicy(users, node5, nodes.ViewSlice())
require.NoError(t, err)
if sshPolicy2 != nil {
@@ -1692,7 +1610,7 @@ func TestSSHWithAutogroupSelfExcludesTaggedDevices(t *testing.T) {
// For untagged node: should only get principals from other untagged nodes
node1 := nodes[0].View()
sshPolicy, err := policy.compileSSHPolicy("unused-server-url", users, node1, nodes.ViewSlice())
sshPolicy, err := policy.compileSSHPolicy(users, node1, nodes.ViewSlice())
require.NoError(t, err)
require.NotNil(t, sshPolicy)
require.Len(t, sshPolicy.Rules, 1)
@@ -1710,7 +1628,7 @@ func TestSSHWithAutogroupSelfExcludesTaggedDevices(t *testing.T) {
// For tagged node: should get no SSH rules
node3 := nodes[2].View()
sshPolicy2, err := policy.compileSSHPolicy("unused-server-url", users, node3, nodes.ViewSlice())
sshPolicy2, err := policy.compileSSHPolicy(users, node3, nodes.ViewSlice())
require.NoError(t, err)
if sshPolicy2 != nil {
@@ -1753,7 +1671,7 @@ func TestSSHWithAutogroupSelfAndMixedDestinations(t *testing.T) {
// Test 1: Compile for user1's device (should only match autogroup:self destination)
node1 := nodes[0].View()
sshPolicy1, err := policy.compileSSHPolicy("unused-server-url", users, node1, nodes.ViewSlice())
sshPolicy1, err := policy.compileSSHPolicy(users, node1, nodes.ViewSlice())
require.NoError(t, err)
require.NotNil(t, sshPolicy1)
require.Len(t, sshPolicy1.Rules, 1, "user1's device should have 1 SSH rule (autogroup:self)")
@@ -1772,7 +1690,7 @@ func TestSSHWithAutogroupSelfAndMixedDestinations(t *testing.T) {
// Test 2: Compile for router (should only match tag:router destination)
routerNode := nodes[3].View() // user2-router
sshPolicyRouter, err := policy.compileSSHPolicy("unused-server-url", users, routerNode, nodes.ViewSlice())
sshPolicyRouter, err := policy.compileSSHPolicy(users, routerNode, nodes.ViewSlice())
require.NoError(t, err)
require.NotNil(t, sshPolicyRouter)
require.Len(t, sshPolicyRouter.Rules, 1, "router should have 1 SSH rule (tag:router)")
@@ -2170,257 +2088,3 @@ func TestMergeFilterRules(t *testing.T) {
})
}
}
func TestCompileSSHPolicy_CheckPeriodVariants(t *testing.T) {
users := types.Users{
{Name: "user1", Model: gorm.Model{ID: 1}},
}
node := types.Node{
Hostname: "device",
IPv4: createAddr("100.64.0.1"),
UserID: new(users[0].ID),
User: new(users[0]),
}
nodes := types.Nodes{&node}
tests := []struct {
name string
checkPeriod *SSHCheckPeriod
wantDuration time.Duration
}{
{
name: "nil period defaults to 12h",
checkPeriod: nil,
wantDuration: SSHCheckPeriodDefault,
},
{
name: "always period uses 0",
checkPeriod: &SSHCheckPeriod{Always: true},
wantDuration: 0,
},
{
name: "explicit 2h",
checkPeriod: &SSHCheckPeriod{Duration: 2 * time.Hour},
wantDuration: 2 * time.Hour,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
policy := &Policy{
SSHs: []SSH{
{
Action: SSHActionCheck,
Sources: SSHSrcAliases{up("user1@")},
Destinations: SSHDstAliases{agp("autogroup:member")},
Users: SSHUsers{"root"},
CheckPeriod: tt.checkPeriod,
},
},
}
err := policy.validate()
require.NoError(t, err)
sshPolicy, err := policy.compileSSHPolicy(
"http://test",
users,
node.View(),
nodes.ViewSlice(),
)
require.NoError(t, err)
require.NotNil(t, sshPolicy)
require.Len(t, sshPolicy.Rules, 1)
rule := sshPolicy.Rules[0]
assert.Equal(t, tt.wantDuration, rule.Action.SessionDuration)
// Check params must NOT be in the URL; they are
// resolved server-side via SSHCheckParams.
assert.NotContains(t, rule.Action.HoldAndDelegate, "check_explicit")
assert.NotContains(t, rule.Action.HoldAndDelegate, "check_period")
})
}
}
func TestSSHCheckParams(t *testing.T) {
users := types.Users{
{Name: "user1", Model: gorm.Model{ID: 1}},
{Name: "user2", Model: gorm.Model{ID: 2}},
}
nodeUser1 := types.Node{
ID: 1,
Hostname: "user1-device",
IPv4: createAddr("100.64.0.1"),
UserID: new(users[0].ID),
User: new(users[0]),
}
nodeUser2 := types.Node{
ID: 2,
Hostname: "user2-device",
IPv4: createAddr("100.64.0.2"),
UserID: new(users[1].ID),
User: new(users[1]),
}
nodeTaggedServer := types.Node{
ID: 3,
Hostname: "tagged-server",
IPv4: createAddr("100.64.0.3"),
UserID: new(users[0].ID),
User: new(users[0]),
Tags: []string{"tag:server"},
}
nodes := types.Nodes{&nodeUser1, &nodeUser2, &nodeTaggedServer}
tests := []struct {
name string
policy []byte
srcID types.NodeID
dstID types.NodeID
wantPeriod time.Duration
wantOK bool
}{
{
name: "explicit check period for tagged destination",
policy: []byte(`{
"tagOwners": {"tag:server": ["user1@"]},
"ssh": [{
"action": "check",
"checkPeriod": "2h",
"src": ["user2@"],
"dst": ["tag:server"],
"users": ["autogroup:nonroot"]
}]
}`),
srcID: types.NodeID(2),
dstID: types.NodeID(3),
wantPeriod: 2 * time.Hour,
wantOK: true,
},
{
name: "default period when checkPeriod omitted",
policy: []byte(`{
"tagOwners": {"tag:server": ["user1@"]},
"ssh": [{
"action": "check",
"src": ["user2@"],
"dst": ["tag:server"],
"users": ["autogroup:nonroot"]
}]
}`),
srcID: types.NodeID(2),
dstID: types.NodeID(3),
wantPeriod: SSHCheckPeriodDefault,
wantOK: true,
},
{
name: "always check (checkPeriod always)",
policy: []byte(`{
"tagOwners": {"tag:server": ["user1@"]},
"ssh": [{
"action": "check",
"checkPeriod": "always",
"src": ["user2@"],
"dst": ["tag:server"],
"users": ["autogroup:nonroot"]
}]
}`),
srcID: types.NodeID(2),
dstID: types.NodeID(3),
wantPeriod: 0,
wantOK: true,
},
{
name: "no match when src not in rule",
policy: []byte(`{
"tagOwners": {"tag:server": ["user1@"]},
"ssh": [{
"action": "check",
"src": ["user1@"],
"dst": ["tag:server"],
"users": ["autogroup:nonroot"]
}]
}`),
srcID: types.NodeID(2),
dstID: types.NodeID(3),
wantOK: false,
},
{
name: "no match when dst not in rule",
policy: []byte(`{
"tagOwners": {"tag:server": ["user1@"]},
"ssh": [{
"action": "check",
"src": ["user2@"],
"dst": ["tag:server"],
"users": ["autogroup:nonroot"]
}]
}`),
srcID: types.NodeID(2),
dstID: types.NodeID(1),
wantOK: false,
},
{
name: "accept rule is not returned",
policy: []byte(`{
"tagOwners": {"tag:server": ["user1@"]},
"ssh": [{
"action": "accept",
"src": ["user2@"],
"dst": ["tag:server"],
"users": ["autogroup:nonroot"]
}]
}`),
srcID: types.NodeID(2),
dstID: types.NodeID(3),
wantOK: false,
},
{
name: "autogroup:self matches same-user pair",
policy: []byte(`{
"ssh": [{
"action": "check",
"checkPeriod": "6h",
"src": ["user1@"],
"dst": ["autogroup:self"],
"users": ["autogroup:nonroot"]
}]
}`),
srcID: types.NodeID(1),
dstID: types.NodeID(1),
wantPeriod: 6 * time.Hour,
wantOK: true,
},
{
name: "autogroup:self rejects cross-user pair",
policy: []byte(`{
"ssh": [{
"action": "check",
"src": ["user1@"],
"dst": ["autogroup:self"],
"users": ["autogroup:nonroot"]
}]
}`),
srcID: types.NodeID(1),
dstID: types.NodeID(2),
wantOK: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
pm, err := NewPolicyManager(tt.policy, users, nodes.ViewSlice())
require.NoError(t, err)
period, ok := pm.SSHCheckParams(tt.srcID, tt.dstID)
assert.Equal(t, tt.wantOK, ok, "ok mismatch")
if tt.wantOK {
assert.Equal(t, tt.wantPeriod, period, "period mismatch")
}
})
}
}

View File

@@ -9,7 +9,6 @@ import (
"slices"
"strings"
"sync"
"time"
"github.com/juanfont/headscale/hscontrol/policy/matcher"
"github.com/juanfont/headscale/hscontrol/policy/policyutil"
@@ -223,7 +222,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
return true, nil
}
func (pm *PolicyManager) SSHPolicy(baseURL string, node types.NodeView) (*tailcfg.SSHPolicy, error) {
func (pm *PolicyManager) SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, error) {
pm.mu.Lock()
defer pm.mu.Unlock()
@@ -231,7 +230,7 @@ func (pm *PolicyManager) SSHPolicy(baseURL string, node types.NodeView) (*tailcf
return sshPol, nil
}
sshPol, err := pm.pol.compileSSHPolicy(baseURL, pm.users, node, pm.nodes)
sshPol, err := pm.pol.compileSSHPolicy(pm.users, node, pm.nodes)
if err != nil {
return nil, fmt.Errorf("compiling SSH policy: %w", err)
}
@@ -241,84 +240,6 @@ func (pm *PolicyManager) SSHPolicy(baseURL string, node types.NodeView) (*tailcf
return sshPol, nil
}
// SSHCheckParams resolves the SSH check period for a source-destination
// node pair by looking up the current policy. This avoids trusting URL
// parameters that a client could tamper with.
// It returns the check period duration and whether a matching check
// rule was found.
func (pm *PolicyManager) SSHCheckParams(
srcNodeID, dstNodeID types.NodeID,
) (time.Duration, bool) {
pm.mu.Lock()
defer pm.mu.Unlock()
if pm.pol == nil || len(pm.pol.SSHs) == 0 {
return 0, false
}
// Find the source and destination node views.
var srcNode, dstNode types.NodeView
for _, n := range pm.nodes.All() {
nid := n.ID()
if nid == srcNodeID {
srcNode = n
}
if nid == dstNodeID {
dstNode = n
}
if srcNode.Valid() && dstNode.Valid() {
break
}
}
if !srcNode.Valid() || !dstNode.Valid() {
return 0, false
}
// Iterate SSH rules to find the first matching check rule.
for _, rule := range pm.pol.SSHs {
if rule.Action != SSHActionCheck {
continue
}
// Resolve sources and check if src node matches.
srcIPs, err := rule.Sources.Resolve(pm.pol, pm.users, pm.nodes)
if err != nil || srcIPs == nil {
continue
}
if !slices.ContainsFunc(srcNode.IPs(), srcIPs.Contains) {
continue
}
// Check if dst node matches any destination.
for _, dst := range rule.Destinations {
if ag, isAG := dst.(*AutoGroup); isAG && ag.Is(AutoGroupSelf) {
if !srcNode.IsTagged() && !dstNode.IsTagged() &&
srcNode.User().ID() == dstNode.User().ID() {
return checkPeriodFromRule(rule), true
}
continue
}
dstIPs, err := dst.Resolve(pm.pol, pm.users, pm.nodes)
if err != nil || dstIPs == nil {
continue
}
if slices.ContainsFunc(dstNode.IPs(), dstIPs.Contains) {
return checkPeriodFromRule(rule), true
}
}
}
return 0, false
}
func (pm *PolicyManager) SetPolicy(polB []byte) (bool, error) {
if len(polB) == 0 {
return false, nil

View File

@@ -7,7 +7,6 @@ import (
"slices"
"strconv"
"strings"
"time"
"github.com/go-json-experiment/json"
"github.com/juanfont/headscale/hscontrol/types"
@@ -44,17 +43,6 @@ var (
ErrSSHAutogroupSelfRequiresUserSource = errors.New("autogroup:self destination requires source to contain only users or groups, not tags or autogroup:tagged")
ErrSSHTagSourceToAutogroupMember = errors.New("tags in SSH source cannot access autogroup:member (user-owned devices)")
ErrSSHWildcardDestination = errors.New("wildcard (*) is not supported as SSH destination")
ErrSSHCheckPeriodBelowMin = errors.New("checkPeriod below minimum of 1 minute")
ErrSSHCheckPeriodAboveMax = errors.New("checkPeriod above maximum of 168 hours (1 week)")
ErrSSHCheckPeriodOnNonCheck = errors.New("checkPeriod is only valid with action \"check\"")
)
// SSH check period constants per Tailscale docs:
// https://tailscale.com/kb/1193/tailscale-ssh
const (
SSHCheckPeriodDefault = 12 * time.Hour
SSHCheckPeriodMin = time.Minute
SSHCheckPeriodMax = 168 * time.Hour
)
// ACL validation errors.
@@ -2031,19 +2019,6 @@ func (p *Policy) validate() error {
if err != nil {
errs = append(errs, err)
}
// Validate checkPeriod
if ssh.CheckPeriod != nil {
switch {
case ssh.Action != SSHActionCheck:
errs = append(errs, ErrSSHCheckPeriodOnNonCheck)
default:
err := ssh.CheckPeriod.Validate()
if err != nil {
errs = append(errs, err)
}
}
}
}
for _, tagOwners := range p.TagOwners {
@@ -2122,75 +2097,13 @@ func (p *Policy) validate() error {
return nil
}
// SSHCheckPeriod represents the check period for SSH "check" mode rules.
// nil means not specified (runtime default of 12h applies).
// Always=true means "always" (check on every request).
// Duration is an explicit period (min 1m, max 168h).
type SSHCheckPeriod struct {
Always bool
Duration time.Duration
}
// UnmarshalJSON implements JSON unmarshaling for SSHCheckPeriod.
func (p *SSHCheckPeriod) UnmarshalJSON(b []byte) error {
str := strings.Trim(string(b), `"`)
if str == "always" {
p.Always = true
return nil
}
d, err := model.ParseDuration(str)
if err != nil {
return fmt.Errorf("parsing checkPeriod %q: %w", str, err)
}
p.Duration = time.Duration(d)
return nil
}
// MarshalJSON implements JSON marshaling for SSHCheckPeriod.
func (p SSHCheckPeriod) MarshalJSON() ([]byte, error) {
if p.Always {
return []byte(`"always"`), nil
}
return fmt.Appendf(nil, "%q", p.Duration.String()), nil
}
// Validate checks that the SSHCheckPeriod is within allowed bounds.
func (p *SSHCheckPeriod) Validate() error {
if p.Always {
return nil
}
if p.Duration < SSHCheckPeriodMin {
return fmt.Errorf(
"%w: got %s",
ErrSSHCheckPeriodBelowMin,
p.Duration,
)
}
if p.Duration > SSHCheckPeriodMax {
return fmt.Errorf(
"%w: got %s",
ErrSSHCheckPeriodAboveMax,
p.Duration,
)
}
return nil
}
// SSH controls who can ssh into which machines.
type SSH struct {
Action SSHAction `json:"action"`
Sources SSHSrcAliases `json:"src"`
Destinations SSHDstAliases `json:"dst"`
Users SSHUsers `json:"users"`
CheckPeriod *SSHCheckPeriod `json:"checkPeriod,omitempty"`
Action SSHAction `json:"action"`
Sources SSHSrcAliases `json:"src"`
Destinations SSHDstAliases `json:"dst"`
Users SSHUsers `json:"users"`
CheckPeriod model.Duration `json:"checkPeriod,omitempty"`
}
// SSHSrcAliases is a list of aliases that can be used as sources in an SSH rule.

View File

@@ -11,6 +11,7 @@ import (
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/prometheus/common/model"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go4.org/netipx"
@@ -710,7 +711,7 @@ func TestUnmarshalPolicy(t *testing.T) {
},
"ssh": [
{
"action": "check",
"action": "accept",
"src": [
"group:admins"
],
@@ -729,7 +730,7 @@ func TestUnmarshalPolicy(t *testing.T) {
},
SSHs: []SSH{
{
Action: "check",
Action: "accept",
Sources: SSHSrcAliases{
gp("group:admins"),
},
@@ -739,7 +740,7 @@ func TestUnmarshalPolicy(t *testing.T) {
Users: []SSHUser{
SSHUser("root"),
},
CheckPeriod: &SSHCheckPeriod{Duration: 24 * time.Hour},
CheckPeriod: model.Duration(24 * time.Hour),
},
},
},
@@ -3826,218 +3827,3 @@ func TestFlattenTagOwners(t *testing.T) {
})
}
}
func TestSSHCheckPeriodUnmarshal(t *testing.T) {
tests := []struct {
name string
input string
want *SSHCheckPeriod
wantErr bool
}{
{
name: "always",
input: `"always"`,
want: &SSHCheckPeriod{Always: true},
},
{
name: "1h",
input: `"1h"`,
want: &SSHCheckPeriod{Duration: time.Hour},
},
{
name: "30m",
input: `"30m"`,
want: &SSHCheckPeriod{Duration: 30 * time.Minute},
},
{
name: "168h",
input: `"168h"`,
want: &SSHCheckPeriod{Duration: 168 * time.Hour},
},
{
name: "invalid",
input: `"notaduration"`,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var got SSHCheckPeriod
err := json.Unmarshal([]byte(tt.input), &got)
if tt.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
assert.Equal(t, *tt.want, got)
})
}
}
func TestSSHCheckPeriodRoundTrip(t *testing.T) {
tests := []struct {
name string
input SSHCheckPeriod
}{
{
name: "always",
input: SSHCheckPeriod{Always: true},
},
{
name: "2h",
input: SSHCheckPeriod{Duration: 2 * time.Hour},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
data, err := json.Marshal(tt.input)
require.NoError(t, err)
var got SSHCheckPeriod
err = json.Unmarshal(data, &got)
require.NoError(t, err)
assert.Equal(t, tt.input, got)
})
}
}
func TestSSHCheckPeriodNilInSSH(t *testing.T) {
input := `{
"action": "check",
"src": ["user@"],
"dst": ["autogroup:member"],
"users": ["root"]
}`
var ssh SSH
err := json.Unmarshal([]byte(input), &ssh)
require.NoError(t, err)
assert.Nil(t, ssh.CheckPeriod)
}
func TestSSHCheckPeriodValidate(t *testing.T) {
tests := []struct {
name string
period SSHCheckPeriod
wantErr error
}{
{
name: "always is valid",
period: SSHCheckPeriod{Always: true},
},
{
name: "1m minimum valid",
period: SSHCheckPeriod{Duration: time.Minute},
},
{
name: "168h maximum valid",
period: SSHCheckPeriod{Duration: 168 * time.Hour},
},
{
name: "30s below minimum",
period: SSHCheckPeriod{Duration: 30 * time.Second},
wantErr: ErrSSHCheckPeriodBelowMin,
},
{
name: "169h above maximum",
period: SSHCheckPeriod{Duration: 169 * time.Hour},
wantErr: ErrSSHCheckPeriodAboveMax,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.period.Validate()
if tt.wantErr != nil {
require.ErrorIs(t, err, tt.wantErr)
return
}
require.NoError(t, err)
})
}
}
func TestSSHCheckPeriodPolicyValidation(t *testing.T) {
tests := []struct {
name string
ssh SSH
wantErr error
}{
{
name: "check with nil period is valid",
ssh: SSH{
Action: SSHActionCheck,
Sources: SSHSrcAliases{up("user@")},
Destinations: SSHDstAliases{agp("autogroup:member")},
Users: SSHUsers{"root"},
},
},
{
name: "check with always is valid",
ssh: SSH{
Action: SSHActionCheck,
Sources: SSHSrcAliases{up("user@")},
Destinations: SSHDstAliases{agp("autogroup:member")},
Users: SSHUsers{"root"},
CheckPeriod: &SSHCheckPeriod{Always: true},
},
},
{
name: "check with 1h is valid",
ssh: SSH{
Action: SSHActionCheck,
Sources: SSHSrcAliases{up("user@")},
Destinations: SSHDstAliases{agp("autogroup:member")},
Users: SSHUsers{"root"},
CheckPeriod: &SSHCheckPeriod{Duration: time.Hour},
},
},
{
name: "accept with checkPeriod is invalid",
ssh: SSH{
Action: SSHActionAccept,
Sources: SSHSrcAliases{up("user@")},
Destinations: SSHDstAliases{agp("autogroup:member")},
Users: SSHUsers{"root"},
CheckPeriod: &SSHCheckPeriod{Duration: time.Hour},
},
wantErr: ErrSSHCheckPeriodOnNonCheck,
},
{
name: "check with 30s is invalid",
ssh: SSH{
Action: SSHActionCheck,
Sources: SSHSrcAliases{up("user@")},
Destinations: SSHDstAliases{agp("autogroup:member")},
Users: SSHUsers{"root"},
CheckPeriod: &SSHCheckPeriod{Duration: 30 * time.Second},
},
wantErr: ErrSSHCheckPeriodBelowMin,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
pol := &Policy{SSHs: []SSH{tt.ssh}}
err := pol.validate()
if tt.wantErr != nil {
require.ErrorIs(t, err, tt.wantErr)
return
}
require.NoError(t, err)
})
}
}

View File

@@ -1,103 +0,0 @@
package state
import (
"sync"
"testing"
"time"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func newTestStateForSSHCheck() *State {
return &State{
sshCheckAuth: make(map[sshCheckPair]time.Time),
}
}
func TestSSHCheckAuth(t *testing.T) {
s := newTestStateForSSHCheck()
src := types.NodeID(1)
dst := types.NodeID(2)
otherDst := types.NodeID(3)
otherSrc := types.NodeID(4)
// No record initially
_, ok := s.GetLastSSHAuth(src, dst)
require.False(t, ok)
// Record auth for (src, dst)
s.SetLastSSHAuth(src, dst)
// Same src+dst: found
authTime, ok := s.GetLastSSHAuth(src, dst)
require.True(t, ok)
assert.WithinDuration(t, time.Now(), authTime, time.Second)
// Same src, different dst: not found (auth is per-pair)
_, ok = s.GetLastSSHAuth(src, otherDst)
require.False(t, ok)
// Different src: not found
_, ok = s.GetLastSSHAuth(otherSrc, dst)
require.False(t, ok)
}
func TestSSHCheckAuthClear(t *testing.T) {
s := newTestStateForSSHCheck()
s.SetLastSSHAuth(types.NodeID(1), types.NodeID(2))
s.SetLastSSHAuth(types.NodeID(1), types.NodeID(3))
_, ok := s.GetLastSSHAuth(types.NodeID(1), types.NodeID(2))
require.True(t, ok)
_, ok = s.GetLastSSHAuth(types.NodeID(1), types.NodeID(3))
require.True(t, ok)
// Clear
s.ClearSSHCheckAuth()
_, ok = s.GetLastSSHAuth(types.NodeID(1), types.NodeID(2))
require.False(t, ok)
_, ok = s.GetLastSSHAuth(types.NodeID(1), types.NodeID(3))
require.False(t, ok)
}
func TestSSHCheckAuthConcurrent(t *testing.T) {
s := newTestStateForSSHCheck()
var wg sync.WaitGroup
for i := range 100 {
wg.Go(func() {
src := types.NodeID(uint64(i % 10)) //nolint:gosec
dst := types.NodeID(uint64(i%5 + 10)) //nolint:gosec
s.SetLastSSHAuth(src, dst)
s.GetLastSSHAuth(src, dst)
})
}
wg.Wait()
// Clear concurrently with reads
wg.Add(2)
go func() {
defer wg.Done()
s.ClearSSHCheckAuth()
}()
go func() {
defer wg.Done()
s.GetLastSSHAuth(types.NodeID(1), types.NodeID(2))
}()
wg.Wait()
}

View File

@@ -64,16 +64,6 @@ var ErrNodeNotInNodeStore = errors.New("node no longer exists in NodeStore")
// ErrNodeNameNotUnique is returned when a node name is not unique.
var ErrNodeNameNotUnique = errors.New("node name is not unique")
// ErrRegistrationExpired is returned when a registration has expired.
var ErrRegistrationExpired = errors.New("registration expired")
// sshCheckPair identifies a (source, destination) node pair for
// SSH check auth tracking.
type sshCheckPair struct {
Src types.NodeID
Dst types.NodeID
}
// State manages Headscale's core state, coordinating between database, policy management,
// IP allocation, and DERP routing. All methods are thread-safe.
type State struct {
@@ -92,31 +82,10 @@ type State struct {
derpMap atomic.Pointer[tailcfg.DERPMap]
// polMan handles policy evaluation and management
polMan policy.PolicyManager
// authCache caches any pending authentication requests, from either auth type (Web and OIDC).
authCache *zcache.Cache[types.AuthID, types.AuthRequest]
// registrationCache caches node registration data to reduce database load
registrationCache *zcache.Cache[types.RegistrationID, types.RegisterNode]
// primaryRoutes tracks primary route assignments for nodes
primaryRoutes *routes.PrimaryRoutes
// sshCheckAuth tracks when source nodes last completed SSH check auth.
//
// For rules without explicit checkPeriod (default 12h), auth covers any
// destination — keyed by (src, Dst=0) where 0 is a sentinel meaning "any".
// Ref: "Once re-authenticated to a destination, the user can access the
// device and any other device in the tailnet without re-verification
// for the next 12 hours." — https://tailscale.com/kb/1193/tailscale-ssh
//
// For rules with explicit checkPeriod, auth covers only that specific
// destination — keyed by (src, dst).
// Ref: "If a different check period is specified for the connection,
// then the user can access specifically this device without
// re-verification for the duration of the check period."
//
// Ref: https://github.com/tailscale/tailscale/issues/10480
// Ref: https://github.com/tailscale/tailscale/issues/7125
sshCheckAuth map[sshCheckPair]time.Time
sshCheckMu sync.RWMutex
}
// NewState creates and initializes a new State instance, setting up the database,
@@ -132,20 +101,20 @@ func NewState(cfg *types.Config) (*State, error) {
cacheCleanup = cfg.Tuning.RegisterCacheCleanup
}
authCache := zcache.New[types.AuthID, types.AuthRequest](
registrationCache := zcache.New[types.RegistrationID, types.RegisterNode](
cacheExpiration,
cacheCleanup,
)
authCache.OnEvicted(
func(id types.AuthID, rn types.AuthRequest) {
rn.FinishAuth(types.AuthVerdict{Err: ErrRegistrationExpired})
registrationCache.OnEvicted(
func(id types.RegistrationID, rn types.RegisterNode) {
rn.SendAndClose(nil)
},
)
db, err := hsdb.NewHeadscaleDatabase(
cfg,
authCache,
registrationCache,
)
if err != nil {
return nil, fmt.Errorf("initializing database: %w", err)
@@ -209,14 +178,12 @@ func NewState(cfg *types.Config) (*State, error) {
return &State{
cfg: cfg,
db: db,
ipAlloc: ipAlloc,
polMan: polMan,
authCache: authCache,
primaryRoutes: routes.New(),
nodeStore: nodeStore,
sshCheckAuth: make(map[sshCheckPair]time.Time),
db: db,
ipAlloc: ipAlloc,
polMan: polMan,
registrationCache: registrationCache,
primaryRoutes: routes.New(),
nodeStore: nodeStore,
}, nil
}
@@ -255,10 +222,6 @@ func (s *State) ReloadPolicy() ([]change.Change, error) {
return nil, fmt.Errorf("setting policy: %w", err)
}
// Clear SSH check auth times when policy changes to ensure stale
// approvals don't persist if checkPeriod rules are modified or removed.
s.ClearSSHCheckAuth()
// Rebuild peer maps after policy changes because the peersFunc in NodeStore
// uses the PolicyManager's filters. Without this, nodes won't see newly allowed
// peers until a node is added/removed, causing autogroup:self policies to not
@@ -903,15 +866,7 @@ func (s *State) ExpireExpiredNodes(lastCheck time.Time) (time.Time, []change.Cha
// SSHPolicy returns the SSH access policy for a node.
func (s *State) SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, error) {
return s.polMan.SSHPolicy(s.cfg.ServerURL, node)
}
// SSHCheckParams resolves the SSH check period for a source-destination
// node pair from the current policy.
func (s *State) SSHCheckParams(
srcNodeID, dstNodeID types.NodeID,
) (time.Duration, bool) {
return s.polMan.SSHCheckParams(srcNodeID, dstNodeID)
return s.polMan.SSHPolicy(node)
}
// Filter returns the current network filter rules and matches.
@@ -936,15 +891,7 @@ func (s *State) NodeCanHaveTag(node types.NodeView, tag string) bool {
// SetPolicy updates the policy configuration.
func (s *State) SetPolicy(pol []byte) (bool, error) {
changed, err := s.polMan.SetPolicy(pol)
if err != nil {
return changed, err
}
// Clear SSH check auth times when policy changes.
s.ClearSSHCheckAuth()
return changed, nil
return s.polMan.SetPolicy(pol)
}
// AutoApproveRoutes checks if a node's routes should be auto-approved.
@@ -1110,9 +1057,9 @@ func (s *State) DeletePreAuthKey(id uint64) error {
return s.db.DeletePreAuthKey(id)
}
// GetAuthCacheEntry retrieves a node registration from cache.
func (s *State) GetAuthCacheEntry(id types.AuthID) (*types.AuthRequest, bool) {
entry, found := s.authCache.Get(id)
// GetRegistrationCacheEntry retrieves a node registration from cache.
func (s *State) GetRegistrationCacheEntry(id types.RegistrationID) (*types.RegisterNode, bool) {
entry, found := s.registrationCache.Get(id)
if !found {
return nil, false
}
@@ -1120,53 +1067,26 @@ func (s *State) GetAuthCacheEntry(id types.AuthID) (*types.AuthRequest, bool) {
return &entry, true
}
// SetAuthCacheEntry stores a node registration in cache.
func (s *State) SetAuthCacheEntry(id types.AuthID, entry types.AuthRequest) {
s.authCache.Set(id, entry)
}
// SetLastSSHAuth records a successful SSH check authentication
// for the given (src, dst) node pair.
func (s *State) SetLastSSHAuth(src, dst types.NodeID) {
s.sshCheckMu.Lock()
defer s.sshCheckMu.Unlock()
s.sshCheckAuth[sshCheckPair{Src: src, Dst: dst}] = time.Now()
}
// GetLastSSHAuth returns when src last authenticated for SSH check
// to dst.
func (s *State) GetLastSSHAuth(src, dst types.NodeID) (time.Time, bool) {
s.sshCheckMu.RLock()
defer s.sshCheckMu.RUnlock()
t, ok := s.sshCheckAuth[sshCheckPair{Src: src, Dst: dst}]
return t, ok
}
// ClearSSHCheckAuth clears all recorded SSH check auth times.
// Called when the policy changes to ensure stale auth times don't grant access.
func (s *State) ClearSSHCheckAuth() {
s.sshCheckMu.Lock()
defer s.sshCheckMu.Unlock()
s.sshCheckAuth = make(map[sshCheckPair]time.Time)
// SetRegistrationCacheEntry stores a node registration in cache.
func (s *State) SetRegistrationCacheEntry(id types.RegistrationID, entry types.RegisterNode) {
s.registrationCache.Set(id, entry)
}
// logHostinfoValidation logs warnings when hostinfo is nil or has empty hostname.
func logHostinfoValidation(nv types.NodeView, username, hostname string) {
if !nv.Hostinfo().Valid() {
func logHostinfoValidation(machineKey, nodeKey, username, hostname string, hostinfo *tailcfg.Hostinfo) {
if hostinfo == nil {
log.Warn().
Caller().
EmbedObject(nv).
Str(zf.MachineKey, machineKey).
Str(zf.NodeKey, nodeKey).
Str(zf.UserName, username).
Str(zf.GeneratedHostname, hostname).
Msg("Registration had nil hostinfo, generated default hostname")
} else if nv.Hostinfo().Hostname() == "" {
} else if hostinfo.Hostname == "" {
log.Warn().
Caller().
EmbedObject(nv).
Str(zf.MachineKey, machineKey).
Str(zf.NodeKey, nodeKey).
Str(zf.UserName, username).
Str(zf.GeneratedHostname, hostname).
Msg("Registration had empty hostname, generated default")
@@ -1208,7 +1128,7 @@ type authNodeUpdateParams struct {
// Node to update; must be valid and in NodeStore.
ExistingNode types.NodeView
// Client data: keys, hostinfo, endpoints.
RegEntry *types.AuthRequest
RegEntry *types.RegisterNode
// Pre-validated hostinfo; NetInfo preserved from ExistingNode.
ValidHostinfo *tailcfg.Hostinfo
// Hostname from hostinfo, or generated from keys if client omits it.
@@ -1227,7 +1147,6 @@ type authNodeUpdateParams struct {
// an existing node. It updates the node in NodeStore, processes RequestTags, and
// persists changes to the database.
func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView, error) {
regNv := params.RegEntry.Node()
// Log the operation type
if params.IsConvertFromTag {
log.Info().
@@ -1236,16 +1155,16 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView
Msg("Converting tagged node to user-owned node")
} else {
log.Info().
Object("existing", params.ExistingNode).
Object("incoming", regNv).
EmbedObject(params.ExistingNode).
Interface("hostinfo", params.RegEntry.Node.Hostinfo).
Msg("Updating existing node registration via reauth")
}
// Process RequestTags during reauth (#2979)
// Due to json:",omitempty", we treat empty/nil as "clear tags"
var requestTags []string
if regNv.Hostinfo().Valid() {
requestTags = regNv.Hostinfo().RequestTags().AsSlice()
if params.RegEntry.Node.Hostinfo != nil {
requestTags = params.RegEntry.Node.Hostinfo.RequestTags
}
oldTags := params.ExistingNode.Tags().AsSlice()
@@ -1263,8 +1182,8 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView
// Update existing node in NodeStore - validation passed, safe to mutate
updatedNodeView, ok := s.nodeStore.UpdateNode(params.ExistingNode.ID(), func(node *types.Node) {
node.NodeKey = regNv.NodeKey()
node.DiscoKey = regNv.DiscoKey()
node.NodeKey = params.RegEntry.Node.NodeKey
node.DiscoKey = params.RegEntry.Node.DiscoKey
node.Hostname = params.Hostname
// Preserve NetInfo from existing node when re-registering
@@ -1275,7 +1194,7 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView
params.ValidHostinfo,
)
node.Endpoints = regNv.Endpoints().AsSlice()
node.Endpoints = params.RegEntry.Node.Endpoints
node.IsOnline = new(false)
node.LastSeen = new(time.Now())
@@ -1284,7 +1203,7 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView
if params.IsConvertFromTag {
node.RegisterMethod = params.RegisterMethod
} else {
node.RegisterMethod = regNv.RegisterMethod()
node.RegisterMethod = params.RegEntry.Node.RegisterMethod
}
// Track tagged status BEFORE processing tags
@@ -1304,7 +1223,7 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView
if params.Expiry != nil {
node.Expiry = params.Expiry
} else {
node.Expiry = regNv.Expiry().Clone()
node.Expiry = params.RegEntry.Node.Expiry
}
case !wasTagged && isTagged:
// Personal → Tagged: clear expiry (tagged nodes don't expire)
@@ -1314,14 +1233,14 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView
if params.Expiry != nil {
node.Expiry = params.Expiry
} else {
node.Expiry = regNv.Expiry().Clone()
node.Expiry = params.RegEntry.Node.Expiry
}
case !isTagged:
// Personal → Personal: update expiry from client
if params.Expiry != nil {
node.Expiry = params.Expiry
} else {
node.Expiry = regNv.Expiry().Clone()
node.Expiry = params.RegEntry.Node.Expiry
}
}
// Tagged → Tagged: keep existing expiry (nil) - no action needed
@@ -1608,13 +1527,13 @@ func (s *State) processReauthTags(
// HandleNodeFromAuthPath handles node registration through authentication flow (like OIDC).
func (s *State) HandleNodeFromAuthPath(
authID types.AuthID,
registrationID types.RegistrationID,
userID types.UserID,
expiry *time.Time,
registrationMethod string,
) (types.NodeView, change.Change, error) {
// Get the registration entry from cache
regEntry, ok := s.GetAuthCacheEntry(authID)
regEntry, ok := s.GetRegistrationCacheEntry(registrationID)
if !ok {
return types.NodeView{}, change.Change{}, hsdb.ErrNodeNotFoundRegistrationCache
}
@@ -1627,27 +1546,25 @@ func (s *State) HandleNodeFromAuthPath(
// Ensure we have a valid hostname from the registration cache entry
hostname := util.EnsureHostname(
regEntry.Node().Hostinfo(),
regEntry.Node().MachineKey().String(),
regEntry.Node().NodeKey().String(),
regEntry.Node.Hostinfo,
regEntry.Node.MachineKey.String(),
regEntry.Node.NodeKey.String(),
)
// Ensure we have valid hostinfo
hostinfo := &tailcfg.Hostinfo{}
if regEntry.Node().Hostinfo().Valid() {
hostinfo = regEntry.Node().Hostinfo().AsStruct()
}
hostinfo.Hostname = hostname
validHostinfo := cmp.Or(regEntry.Node.Hostinfo, &tailcfg.Hostinfo{})
validHostinfo.Hostname = hostname
logHostinfoValidation(
regEntry.Node(),
regEntry.Node.MachineKey.ShortString(),
regEntry.Node.NodeKey.String(),
user.Name,
hostname,
regEntry.Node.Hostinfo,
)
// Lookup existing nodes
machineKey := regEntry.Node().MachineKey()
machineKey := regEntry.Node.MachineKey
existingNodeSameUser, _ := s.nodeStore.GetNodeByMachineKey(machineKey, types.UserID(user.ID))
existingNodeAnyUser, _ := s.nodeStore.GetNodeByMachineKeyAnyUser(machineKey)
@@ -1661,7 +1578,7 @@ func (s *State) HandleNodeFromAuthPath(
// Create logger with common fields for all auth operations
logger := log.With().
Str(zf.RegistrationID, authID.String()).
Str(zf.RegistrationID, registrationID.String()).
Str(zf.UserName, user.Name).
Str(zf.MachineKey, machineKey.ShortString()).
Str(zf.Method, registrationMethod).
@@ -1670,7 +1587,7 @@ func (s *State) HandleNodeFromAuthPath(
// Common params for update operations
updateParams := authNodeUpdateParams{
RegEntry: regEntry,
ValidHostinfo: hostinfo,
ValidHostinfo: validHostinfo,
Hostname: hostname,
User: user,
Expiry: expiry,
@@ -1704,7 +1621,7 @@ func (s *State) HandleNodeFromAuthPath(
Msg("Creating new node for different user (same machine key exists for another user)")
finalNode, err = s.createNewNodeFromAuth(
logger, user, regEntry, hostname, hostinfo,
logger, user, regEntry, hostname, validHostinfo,
expiry, registrationMethod, existingNodeAnyUser,
)
if err != nil {
@@ -1712,7 +1629,7 @@ func (s *State) HandleNodeFromAuthPath(
}
} else {
finalNode, err = s.createNewNodeFromAuth(
logger, user, regEntry, hostname, hostinfo,
logger, user, regEntry, hostname, validHostinfo,
expiry, registrationMethod, types.NodeView{},
)
if err != nil {
@@ -1721,10 +1638,10 @@ func (s *State) HandleNodeFromAuthPath(
}
// Signal to waiting clients
regEntry.FinishAuth(types.AuthVerdict{Node: finalNode})
regEntry.SendAndClose(finalNode.AsStruct())
// Delete from registration cache
s.authCache.Delete(authID)
s.registrationCache.Delete(registrationID)
// Update policy managers
usersChange, err := s.updatePolicyManagerUsers()
@@ -1753,7 +1670,7 @@ func (s *State) HandleNodeFromAuthPath(
func (s *State) createNewNodeFromAuth(
logger zerolog.Logger,
user *types.User,
regEntry *types.AuthRequest,
regEntry *types.RegisterNode,
hostname string,
validHostinfo *tailcfg.Hostinfo,
expiry *time.Time,
@@ -1766,13 +1683,13 @@ func (s *State) createNewNodeFromAuth(
return s.createAndSaveNewNode(newNodeParams{
User: *user,
MachineKey: regEntry.Node().MachineKey(),
NodeKey: regEntry.Node().NodeKey(),
DiscoKey: regEntry.Node().DiscoKey(),
MachineKey: regEntry.Node.MachineKey,
NodeKey: regEntry.Node.NodeKey,
DiscoKey: regEntry.Node.DiscoKey,
Hostname: hostname,
Hostinfo: validHostinfo,
Endpoints: regEntry.Node().Endpoints().AsSlice(),
Expiry: cmp.Or(expiry, regEntry.Node().Expiry().Clone()),
Endpoints: regEntry.Node.Endpoints,
Expiry: cmp.Or(expiry, regEntry.Node.Expiry),
RegisterMethod: registrationMethod,
ExistingNodeForNetinfo: existingNodeForNetinfo,
})
@@ -1867,7 +1784,7 @@ func (s *State) HandleNodeFromPreAuthKey(
// Ensure we have a valid hostname - handle nil/empty cases
hostname := util.EnsureHostname(
regReq.Hostinfo.View(),
regReq.Hostinfo,
machineKey.String(),
regReq.NodeKey.String(),
)
@@ -1876,6 +1793,14 @@ func (s *State) HandleNodeFromPreAuthKey(
validHostinfo := cmp.Or(regReq.Hostinfo, &tailcfg.Hostinfo{})
validHostinfo.Hostname = hostname
logHostinfoValidation(
machineKey.ShortString(),
regReq.NodeKey.ShortString(),
pakUsername(),
hostname,
regReq.Hostinfo,
)
log.Debug().
Caller().
Str(zf.NodeName, hostname).

View File

@@ -1,62 +0,0 @@
package templates
import (
"github.com/chasefleming/elem-go"
)
// AuthSuccessResult contains the text content for an authentication success page.
// Each field controls a distinct piece of user-facing text so that every auth
// flow (node registration, reauthentication, SSH check, …) can clearly
// communicate what just happened.
type AuthSuccessResult struct {
// Title is the browser tab / page title,
// e.g. "Headscale - Node Registered".
Title string
// Heading is the bold green text inside the success box,
// e.g. "Node registered".
Heading string
// Verb is the action prefix in the body text before "as <user>",
// e.g. "Registered", "Reauthenticated", "Authorized".
Verb string
// User is the display name shown in bold in the body text,
// e.g. "user@example.com".
User string
// Message is the follow-up instruction shown after the user name,
// e.g. "You can now close this window."
Message string
}
// AuthSuccess renders an authentication / authorisation success page.
// The caller controls every user-visible string via [AuthSuccessResult] so the
// page clearly describes what succeeded (registration, reauth, SSH check, …).
func AuthSuccess(result AuthSuccessResult) *elem.Element {
box := successBox(
result.Heading,
elem.Text(result.Verb+" as "),
elem.Strong(nil, elem.Text(result.User)),
elem.Text(". "+result.Message),
)
return HtmlStructure(
elem.Title(nil, elem.Text(result.Title)),
mdTypesetBody(
headscaleLogo(),
box,
H2(elem.Text("Getting started")),
P(elem.Text("Check out the documentation to learn more about headscale and Tailscale:")),
Ul(
elem.Li(nil,
externalLink("https://headscale.net/stable/", "Headscale documentation"),
),
elem.Li(nil,
externalLink("https://tailscale.com/kb/", "Tailscale knowledge base"),
),
),
pageFooter(),
),
)
}

View File

@@ -1,21 +0,0 @@
package templates
import (
"github.com/chasefleming/elem-go"
)
// AuthWeb renders a page that instructs an administrator to run a CLI command
// to complete an authentication or registration flow.
// It is used by both the registration and auth-approve web handlers.
func AuthWeb(title, description, command string) *elem.Element {
return HtmlStructure(
elem.Title(nil, elem.Text(title+" - Headscale")),
mdTypesetBody(
headscaleLogo(),
H1(elem.Text(title)),
P(elem.Text(description)),
Pre(PreCode(command)),
pageFooter(),
),
)
}

View File

@@ -365,47 +365,6 @@ func orDivider() *elem.Element {
)
}
// successBox creates a green success feedback box with a checkmark icon.
// The heading is displayed as bold green text, and children are rendered below it.
// Pairs with warningBox for consistent feedback styling.
//
//nolint:unused // Used in auth_success.go template.
func successBox(heading string, children ...elem.Node) *elem.Element {
return elem.Div(attrs.Props{
attrs.Style: styles.Props{
styles.Display: "flex",
styles.AlignItems: "center",
styles.Gap: spaceM,
styles.Padding: spaceL,
styles.BackgroundColor: colorSuccessLight,
styles.Border: "1px solid " + colorSuccess,
styles.BorderRadius: "0.5rem",
styles.MarginBottom: spaceXL,
}.ToInline(),
},
checkboxIcon(),
elem.Div(nil,
append([]elem.Node{
elem.Strong(attrs.Props{
attrs.Style: styles.Props{
styles.Display: "block",
styles.Color: colorSuccess,
styles.FontSize: fontSizeH3,
styles.MarginBottom: spaceXS,
}.ToInline(),
}, elem.Text(heading)),
}, children...)...,
),
)
}
// checkboxIcon returns the success checkbox SVG icon as raw HTML.
func checkboxIcon() elem.Node {
return elem.Raw(`<svg id="checkbox" aria-hidden="true" xmlns="http://www.w3.org/2000/svg" width="48" height="48" viewBox="0 0 512 512">
<path d="M256 32C132.3 32 32 132.3 32 256s100.3 224 224 224 224-100.3 224-224S379.7 32 256 32zm114.9 149.1L231.8 359.6c-1.1 1.1-2.9 3.5-5.1 3.5-2.3 0-3.8-1.6-5.1-2.9-1.3-1.3-78.9-75.9-78.9-75.9l-1.5-1.5c-.6-.9-1.1-2-1.1-3.2 0-1.2.5-2.3 1.1-3.2.4-.4.7-.7 1.1-1.2 7.7-8.1 23.3-24.5 24.3-25.5 1.3-1.3 2.4-3 4.8-3 2.5 0 4.1 2.1 5.3 3.3 1.2 1.2 45 43.3 45 43.3l111.3-143c1-.8 2.2-1.4 3.5-1.4 1.3 0 2.5.5 3.5 1.3l30.6 24.1c.8 1 1.3 2.2 1.3 3.5.1 1.3-.4 2.4-1 3.3z"></path>
</svg>`)
}
// warningBox creates a warning message box with icon and content.
//
//nolint:unused // Used in apple.go template.

View File

@@ -0,0 +1,69 @@
package templates
import (
"github.com/chasefleming/elem-go"
"github.com/chasefleming/elem-go/attrs"
"github.com/chasefleming/elem-go/styles"
)
// checkboxIcon returns the success checkbox SVG icon as raw HTML.
func checkboxIcon() elem.Node {
return elem.Raw(`<svg id="checkbox" aria-hidden="true" xmlns="http://www.w3.org/2000/svg" width="48" height="48" viewBox="0 0 512 512">
<path d="M256 32C132.3 32 32 132.3 32 256s100.3 224 224 224 224-100.3 224-224S379.7 32 256 32zm114.9 149.1L231.8 359.6c-1.1 1.1-2.9 3.5-5.1 3.5-2.3 0-3.8-1.6-5.1-2.9-1.3-1.3-78.9-75.9-78.9-75.9l-1.5-1.5c-.6-.9-1.1-2-1.1-3.2 0-1.2.5-2.3 1.1-3.2.4-.4.7-.7 1.1-1.2 7.7-8.1 23.3-24.5 24.3-25.5 1.3-1.3 2.4-3 4.8-3 2.5 0 4.1 2.1 5.3 3.3 1.2 1.2 45 43.3 45 43.3l111.3-143c1-.8 2.2-1.4 3.5-1.4 1.3 0 2.5.5 3.5 1.3l30.6 24.1c.8 1 1.3 2.2 1.3 3.5.1 1.3-.4 2.4-1 3.3z"></path>
</svg>`)
}
// OIDCCallback renders the OIDC authentication success callback page.
func OIDCCallback(user, verb string) *elem.Element {
// Success message box
successBox := elem.Div(attrs.Props{
attrs.Style: styles.Props{
styles.Display: "flex",
styles.AlignItems: "center",
styles.Gap: spaceM,
styles.Padding: spaceL,
styles.BackgroundColor: colorSuccessLight,
styles.Border: "1px solid " + colorSuccess,
styles.BorderRadius: "0.5rem",
styles.MarginBottom: spaceXL,
}.ToInline(),
},
checkboxIcon(),
elem.Div(nil,
elem.Strong(attrs.Props{
attrs.Style: styles.Props{
styles.Display: "block",
styles.Color: colorSuccess,
styles.FontSize: fontSizeH3,
styles.MarginBottom: spaceXS,
}.ToInline(),
}, elem.Text("Signed in successfully")),
elem.P(attrs.Props{
attrs.Style: styles.Props{
styles.Margin: "0",
styles.Color: colorTextPrimary,
styles.FontSize: fontSizeBase,
}.ToInline(),
}, elem.Text(verb), elem.Text(" as "), elem.Strong(nil, elem.Text(user)), elem.Text(". You can now close this window.")),
),
)
return HtmlStructure(
elem.Title(nil, elem.Text("Headscale Authentication Succeeded")),
mdTypesetBody(
headscaleLogo(),
successBox,
H2(elem.Text("Getting started")),
P(elem.Text("Check out the documentation to learn more about headscale and Tailscale:")),
Ul(
elem.Li(nil,
externalLink("https://headscale.net/stable/", "Headscale documentation"),
),
elem.Li(nil,
externalLink("https://tailscale.com/kb/", "Tailscale knowledge base"),
),
),
pageFooter(),
),
)
}

View File

@@ -0,0 +1,21 @@
package templates
import (
"fmt"
"github.com/chasefleming/elem-go"
"github.com/juanfont/headscale/hscontrol/types"
)
func RegisterWeb(registrationID types.RegistrationID) *elem.Element {
return HtmlStructure(
elem.Title(nil, elem.Text("Registration - Headscale")),
mdTypesetBody(
headscaleLogo(),
H1(elem.Text("Machine registration")),
P(elem.Text("Run the command below in the headscale server to add this machine to your network:")),
Pre(PreCode(fmt.Sprintf("headscale nodes register --key %s --user USERNAME", registrationID.String()))),
pageFooter(),
),
)
}

View File

@@ -5,6 +5,7 @@ import (
"testing"
"github.com/juanfont/headscale/hscontrol/templates"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/stretchr/testify/assert"
)
@@ -15,30 +16,12 @@ func TestTemplateHTMLConsistency(t *testing.T) {
html string
}{
{
name: "Auth Success",
html: templates.AuthSuccess(templates.AuthSuccessResult{
Title: "Headscale - Node Registered",
Heading: "Node registered",
Verb: "Registered",
User: "test@example.com",
Message: "You can now close this window.",
}).Render(),
name: "OIDC Callback",
html: templates.OIDCCallback("test@example.com", "Logged in").Render(),
},
{
name: "Auth Web Register",
html: templates.AuthWeb(
"Machine registration",
"Run the command below in the headscale server to add this machine to your network:",
"headscale auth register --auth-id test-key-123 --user USERNAME",
).Render(),
},
{
name: "Auth Web Approve",
html: templates.AuthWeb(
"Authentication check",
"Run the command below in the headscale server to approve this authentication request:",
"headscale auth approve --auth-id test-key-123",
).Render(),
name: "Register Web",
html: templates.RegisterWeb(types.RegistrationID("test-key-123")).Render(),
},
{
name: "Windows Config",
@@ -89,30 +72,12 @@ func TestTemplateModernHTMLFeatures(t *testing.T) {
html string
}{
{
name: "Auth Success",
html: templates.AuthSuccess(templates.AuthSuccessResult{
Title: "Headscale - Node Registered",
Heading: "Node registered",
Verb: "Registered",
User: "test@example.com",
Message: "You can now close this window.",
}).Render(),
name: "OIDC Callback",
html: templates.OIDCCallback("test@example.com", "Logged in").Render(),
},
{
name: "Auth Web Register",
html: templates.AuthWeb(
"Machine registration",
"Run the command below in the headscale server to add this machine to your network:",
"headscale auth register --auth-id test-key-123 --user USERNAME",
).Render(),
},
{
name: "Auth Web Approve",
html: templates.AuthWeb(
"Authentication check",
"Run the command below in the headscale server to approve this authentication request:",
"headscale auth approve --auth-id test-key-123",
).Render(),
name: "Register Web",
html: templates.RegisterWeb(types.RegistrationID("test-key-123")).Render(),
},
{
name: "Windows Config",
@@ -151,35 +116,16 @@ func TestTemplateExternalLinkSecurity(t *testing.T) {
externalURLs []string // URLs that should have security attributes
}{
{
name: "Auth Success",
html: templates.AuthSuccess(templates.AuthSuccessResult{
Title: "Headscale - Node Registered",
Heading: "Node registered",
Verb: "Registered",
User: "test@example.com",
Message: "You can now close this window.",
}).Render(),
name: "OIDC Callback",
html: templates.OIDCCallback("test@example.com", "Logged in").Render(),
externalURLs: []string{
"https://headscale.net/stable/",
"https://tailscale.com/kb/",
},
},
{
name: "Auth Web Register",
html: templates.AuthWeb(
"Machine registration",
"Run the command below in the headscale server to add this machine to your network:",
"headscale auth register --auth-id test-key-123 --user USERNAME",
).Render(),
externalURLs: []string{}, // No external links
},
{
name: "Auth Web Approve",
html: templates.AuthWeb(
"Authentication check",
"Run the command below in the headscale server to approve this authentication request:",
"headscale auth approve --auth-id test-key-123",
).Render(),
name: "Register Web",
html: templates.RegisterWeb(types.RegistrationID("test-key-123")).Render(),
externalURLs: []string{}, // No external links
},
{
@@ -239,30 +185,12 @@ func TestTemplateAccessibilityAttributes(t *testing.T) {
html string
}{
{
name: "Auth Success",
html: templates.AuthSuccess(templates.AuthSuccessResult{
Title: "Headscale - Node Registered",
Heading: "Node registered",
Verb: "Registered",
User: "test@example.com",
Message: "You can now close this window.",
}).Render(),
name: "OIDC Callback",
html: templates.OIDCCallback("test@example.com", "Logged in").Render(),
},
{
name: "Auth Web Register",
html: templates.AuthWeb(
"Machine registration",
"Run the command below in the headscale server to add this machine to your network:",
"headscale auth register --auth-id test-key-123 --user USERNAME",
).Render(),
},
{
name: "Auth Web Approve",
html: templates.AuthWeb(
"Authentication check",
"Run the command below in the headscale server to approve this authentication request:",
"headscale auth approve --auth-id test-key-123",
).Render(),
name: "Register Web",
html: templates.RegisterWeb(types.RegistrationID("test-key-123")).Render(),
},
{
name: "Windows Config",

View File

@@ -7,7 +7,6 @@ import (
"errors"
"fmt"
"runtime"
"strings"
"sync/atomic"
"time"
@@ -23,9 +22,8 @@ const (
// Common errors.
var (
ErrCannotParsePrefix = errors.New("cannot parse prefix")
ErrInvalidAuthIDLength = errors.New("auth ID has invalid length")
ErrInvalidAuthIDPrefix = errors.New("auth ID has invalid prefix")
ErrCannotParsePrefix = errors.New("cannot parse prefix")
ErrInvalidRegistrationIDLength = errors.New("registration ID has invalid length")
)
type StateUpdateType int
@@ -161,26 +159,21 @@ func UpdateExpire(nodeID NodeID, expiry time.Time) StateUpdate {
}
}
const (
authIDPrefix = "hskey-authreq-"
authIDRandomLength = 24
// AuthIDLength is the total length of an AuthID: 14 (prefix) + 24 (random).
AuthIDLength = 38
)
const RegistrationIDLength = 24
type AuthID string
type RegistrationID string
func NewAuthID() (AuthID, error) {
rid, err := util.GenerateRandomStringURLSafe(authIDRandomLength)
func NewRegistrationID() (RegistrationID, error) {
rid, err := util.GenerateRandomStringURLSafe(RegistrationIDLength)
if err != nil {
return "", err
}
return AuthID(authIDPrefix + rid), nil
return RegistrationID(rid), nil
}
func MustAuthID() AuthID {
rid, err := NewAuthID()
func MustRegistrationID() RegistrationID {
rid, err := NewRegistrationID()
if err != nil {
panic(err)
}
@@ -188,106 +181,43 @@ func MustAuthID() AuthID {
return rid
}
func AuthIDFromString(str string) (AuthID, error) {
r := AuthID(str)
err := r.Validate()
if err != nil {
return "", err
func RegistrationIDFromString(str string) (RegistrationID, error) {
if len(str) != RegistrationIDLength {
return "", fmt.Errorf("%w: expected %d, got %d", ErrInvalidRegistrationIDLength, RegistrationIDLength, len(str))
}
return r, nil
return RegistrationID(str), nil
}
func (r AuthID) String() string {
func (r RegistrationID) String() string {
return string(r)
}
func (r AuthID) Validate() error {
if !strings.HasPrefix(string(r), authIDPrefix) {
return fmt.Errorf(
"%w: expected prefix %q",
ErrInvalidAuthIDPrefix, authIDPrefix,
)
}
if len(r) != AuthIDLength {
return fmt.Errorf(
"%w: expected %d, got %d",
ErrInvalidAuthIDLength, AuthIDLength, len(r),
)
}
return nil
type RegisterNode struct {
Node Node
Registered chan *Node
closed *atomic.Bool
}
// AuthRequest represent a pending authentication request from a user or a node.
// If it is a registration request, the node field will be populate with the node that is trying to register.
// When the authentication process is finished, the node that has been authenticated will be sent through the Finished channel.
// The closed field is used to ensure that the Finished channel is only closed once, and that no more nodes are sent through it after it has been closed.
type AuthRequest struct {
node *Node
finished chan AuthVerdict
closed *atomic.Bool
}
func NewAuthRequest() AuthRequest {
return AuthRequest{
finished: make(chan AuthVerdict),
closed: &atomic.Bool{},
func NewRegisterNode(node Node) RegisterNode {
return RegisterNode{
Node: node,
Registered: make(chan *Node),
closed: &atomic.Bool{},
}
}
func NewRegisterAuthRequest(node Node) AuthRequest {
return AuthRequest{
node: &node,
finished: make(chan AuthVerdict),
closed: &atomic.Bool{},
}
}
// Node returns the node that is trying to register.
// It will panic if the AuthRequest is not a registration request.
// Can _only_ be used in the registration path.
func (rn *AuthRequest) Node() NodeView {
if rn.node == nil {
panic("Node can only be used in registration requests")
}
return rn.node.View()
}
func (rn *AuthRequest) FinishAuth(verdict AuthVerdict) {
func (rn *RegisterNode) SendAndClose(node *Node) {
if rn.closed.Swap(true) {
return
}
select {
case rn.finished <- verdict:
case rn.Registered <- node:
default:
}
close(rn.finished)
}
func (rn *AuthRequest) WaitForAuth() <-chan AuthVerdict {
return rn.finished
}
type AuthVerdict struct {
// Err is the error that occurred during the authentication process, if any.
// If Err is nil, the authentication process has succeeded.
// If Err is not nil, the authentication process has failed and the node should not be authenticated.
Err error
// Node is the node that has been authenticated.
// Node is only valid if the auth request was a registration request
// and the authentication process has succeeded.
Node NodeView
}
func (v AuthVerdict) Accept() bool {
return v.Err == nil
close(rn.Registered)
}
// DefaultBatcherWorkers returns the default number of batcher workers.

View File

@@ -295,8 +295,8 @@ func IsCI() bool {
// 3. If normalisation fails → generate invalid-<random> replacement
//
// Returns the guaranteed-valid hostname to use.
func EnsureHostname(hostinfo tailcfg.HostinfoView, machineKey, nodeKey string) string {
if !hostinfo.Valid() || hostinfo.Hostname() == "" {
func EnsureHostname(hostinfo *tailcfg.Hostinfo, machineKey, nodeKey string) string {
if hostinfo == nil || hostinfo.Hostname == "" {
key := cmp.Or(machineKey, nodeKey)
if key == "" {
return "unknown-node"
@@ -310,7 +310,7 @@ func EnsureHostname(hostinfo tailcfg.HostinfoView, machineKey, nodeKey string) s
return "node-" + keyPrefix
}
lowercased := strings.ToLower(hostinfo.Hostname())
lowercased := strings.ToLower(hostinfo.Hostname)
err := ValidateHostname(lowercased)
if err == nil {

View File

@@ -1070,7 +1070,7 @@ func TestEnsureHostname(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := EnsureHostname(tt.hostinfo.View(), tt.machineKey, tt.nodeKey)
got := EnsureHostname(tt.hostinfo, tt.machineKey, tt.nodeKey)
// For invalid hostnames, we just check the prefix since the random part varies
if strings.HasPrefix(tt.want, "invalid-") {
if !strings.HasPrefix(got, "invalid-") {
@@ -1255,7 +1255,7 @@ func TestEnsureHostnameWithHostinfo(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
gotHostname := EnsureHostname(tt.hostinfo.View(), tt.machineKey, tt.nodeKey)
gotHostname := EnsureHostname(tt.hostinfo, tt.machineKey, tt.nodeKey)
// For invalid hostnames, we just check the prefix since the random part varies
if strings.HasPrefix(tt.wantHostname, "invalid-") {
if !strings.HasPrefix(gotHostname, "invalid-") {
@@ -1284,7 +1284,7 @@ func TestEnsureHostname_DNSLabelLimit(t *testing.T) {
hostinfo := &tailcfg.Hostinfo{Hostname: hostname}
result := EnsureHostname(hostinfo.View(), "mkey", "nkey")
result := EnsureHostname(hostinfo, "mkey", "nkey")
if len(result) > 63 {
t.Errorf("test case %d: hostname length = %d, want <= 63", i, len(result))
}
@@ -1300,8 +1300,8 @@ func TestEnsureHostname_Idempotent(t *testing.T) {
OS: "linux",
}
hostname1 := EnsureHostname(originalHostinfo.View(), "mkey", "nkey")
hostname2 := EnsureHostname(originalHostinfo.View(), "mkey", "nkey")
hostname1 := EnsureHostname(originalHostinfo, "mkey", "nkey")
hostname2 := EnsureHostname(originalHostinfo, "mkey", "nkey")
if hostname1 != hostname2 {
t.Errorf("hostnames not equal: %v != %v", hostname1, hostname2)

View File

@@ -312,7 +312,7 @@ func TestAuthWebFlowLogoutAndReloginNewUser(t *testing.T) {
}
// Register all clients as user1 (this is where cross-user registration happens)
// This simulates: headscale auth register --auth-id <id> --user user1
// This simulates: headscale nodes register --user user1 --key <key>
_ = scenario.runHeadscaleRegister("user1", body)
}

View File

@@ -1065,11 +1065,11 @@ func TestNodeCommand(t *testing.T) {
require.NoError(t, err)
regIDs := []string{
types.MustAuthID().String(),
types.MustAuthID().String(),
types.MustAuthID().String(),
types.MustAuthID().String(),
types.MustAuthID().String(),
types.MustRegistrationID().String(),
types.MustRegistrationID().String(),
types.MustRegistrationID().String(),
types.MustRegistrationID().String(),
types.MustRegistrationID().String(),
}
nodes := make([]*v1.Node, len(regIDs))
@@ -1100,11 +1100,11 @@ func TestNodeCommand(t *testing.T) {
headscale,
[]string{
"headscale",
"auth",
"register",
"nodes",
"--user",
"node-user",
"--auth-id",
"register",
"--key",
regID,
"--output",
"json",
@@ -1153,8 +1153,8 @@ func TestNodeCommand(t *testing.T) {
assert.Equal(t, "node-5", listAll[4].GetName())
otherUserRegIDs := []string{
types.MustAuthID().String(),
types.MustAuthID().String(),
types.MustRegistrationID().String(),
types.MustRegistrationID().String(),
}
otherUserMachines := make([]*v1.Node, len(otherUserRegIDs))
@@ -1185,11 +1185,11 @@ func TestNodeCommand(t *testing.T) {
headscale,
[]string{
"headscale",
"auth",
"register",
"nodes",
"--user",
"other-user",
"--auth-id",
"register",
"--key",
regID,
"--output",
"json",
@@ -1326,11 +1326,11 @@ func TestNodeExpireCommand(t *testing.T) {
require.NoError(t, err)
regIDs := []string{
types.MustAuthID().String(),
types.MustAuthID().String(),
types.MustAuthID().String(),
types.MustAuthID().String(),
types.MustAuthID().String(),
types.MustRegistrationID().String(),
types.MustRegistrationID().String(),
types.MustRegistrationID().String(),
types.MustRegistrationID().String(),
types.MustRegistrationID().String(),
}
nodes := make([]*v1.Node, len(regIDs))
@@ -1359,11 +1359,11 @@ func TestNodeExpireCommand(t *testing.T) {
headscale,
[]string{
"headscale",
"auth",
"register",
"nodes",
"--user",
"node-expire-user",
"--auth-id",
"register",
"--key",
regID,
"--output",
"json",
@@ -1461,11 +1461,11 @@ func TestNodeRenameCommand(t *testing.T) {
require.NoError(t, err)
regIDs := []string{
types.MustAuthID().String(),
types.MustAuthID().String(),
types.MustAuthID().String(),
types.MustAuthID().String(),
types.MustAuthID().String(),
types.MustRegistrationID().String(),
types.MustRegistrationID().String(),
types.MustRegistrationID().String(),
types.MustRegistrationID().String(),
types.MustRegistrationID().String(),
}
nodes := make([]*v1.Node, len(regIDs))
@@ -1496,11 +1496,11 @@ func TestNodeRenameCommand(t *testing.T) {
headscale,
[]string{
"headscale",
"auth",
"register",
"nodes",
"--user",
"node-rename-command",
"--auth-id",
"register",
"--key",
regID,
"--output",
"json",

View File

@@ -16,7 +16,6 @@ import (
type ControlServer interface {
Shutdown() (string, string, error)
SaveLog(path string) (string, string, error)
ReadLog() (string, string, error)
SaveProfile(path string) error
Execute(command []string) (string, error)
WriteFile(path string, content []byte) error

View File

@@ -699,18 +699,6 @@ func (t *HeadscaleInContainer) WriteLogs(stdout, stderr io.Writer) error {
return dockertestutil.WriteLog(t.pool, t.container, stdout, stderr)
}
// ReadLog returns the current stdout and stderr logs from the headscale container.
func (t *HeadscaleInContainer) ReadLog() (string, string, error) {
var stdout, stderr bytes.Buffer
err := dockertestutil.WriteLog(t.pool, t.container, &stdout, &stderr)
if err != nil {
return "", "", fmt.Errorf("reading container logs: %w", err)
}
return stdout.String(), stderr.String(), nil
}
// SaveLog saves the current stdout log of the container to a path
// on the host system.
func (t *HeadscaleInContainer) SaveLog(path string) (string, string, error) {

View File

@@ -141,12 +141,6 @@ type ScenarioSpec struct {
// Versions is specific list of versions to use for the test.
Versions []string
// OIDCSkipUserCreation, if true, skips creating users via headscale CLI
// during environment setup. Useful for OIDC tests where the SSH policy
// references users by name, since OIDC login creates users automatically
// and pre-creating them via CLI causes duplicate user records.
OIDCSkipUserCreation bool
// OIDCUsers, if populated, will start a Mock OIDC server and populate
// the user login stack with the given users.
// If the NodesPerUser is set, it should align with this list to ensure
@@ -872,18 +866,9 @@ func (s *Scenario) createHeadscaleEnvWithTags(
}
for _, user := range s.spec.Users {
var u *v1.User
if s.spec.OIDCSkipUserCreation {
// Only register locally — OIDC login will create the headscale user.
s.mu.Lock()
s.users[user] = &User{Clients: make(map[string]TailscaleClient)}
s.mu.Unlock()
} else {
u, err = s.CreateUser(user)
if err != nil {
return err
}
u, err := s.CreateUser(user)
if err != nil {
return err
}
var userOpts []tsic.Option
@@ -1184,7 +1169,7 @@ func (s *Scenario) runHeadscaleRegister(userStr string, body string) error {
return errParseAuthPage
}
keySep := strings.Split(codeSep[0], "--auth-id ")
keySep := strings.Split(codeSep[0], "key ")
if len(keySep) != 2 {
return errParseAuthPage
}
@@ -1195,7 +1180,7 @@ func (s *Scenario) runHeadscaleRegister(userStr string, body string) error {
if headscale, err := s.Headscale(); err == nil { //nolint:noinlineerr
_, err = headscale.Execute(
[]string{"headscale", "auth", "register", "--user", userStr, "--auth-id", key},
[]string{"headscale", "nodes", "register", "--user", userStr, "--key", key},
)
if err != nil {
log.Printf("registering node: %s", err)

View File

@@ -3,16 +3,13 @@ package integration
import (
"fmt"
"log"
"net/url"
"strings"
"testing"
"time"
policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2"
"github.com/juanfont/headscale/integration/dockertestutil"
"github.com/juanfont/headscale/integration/hsic"
"github.com/juanfont/headscale/integration/tsic"
"github.com/oauth2-proxy/mockoidc"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"tailscale.com/tailcfg"
@@ -582,676 +579,3 @@ func TestSSHAutogroupSelf(t *testing.T) {
}
}
}
type sshCheckResult struct {
stdout string
stderr string
err error
}
// doSSHCheck runs SSH in a goroutine with a longer timeout, returning a channel
// for the result. The SSH command will block while waiting for auth approval in
// check mode.
func doSSHCheck(
t *testing.T,
client TailscaleClient,
peer TailscaleClient,
) chan sshCheckResult {
t.Helper()
peerFQDN, _ := peer.FQDN()
command := []string{
"/usr/bin/ssh", "-o StrictHostKeyChecking=no", "-o ConnectTimeout=30",
fmt.Sprintf("%s@%s", "ssh-it-user", peerFQDN),
"'hostname'",
}
log.Printf(
"[SSH check] Running from %s to %s",
client.Hostname(),
peer.Hostname(),
)
ch := make(chan sshCheckResult, 1)
go func() {
stdout, stderr, err := client.Execute(
command,
dockertestutil.ExecuteCommandTimeout(60*time.Second),
)
ch <- sshCheckResult{stdout, stderr, err}
}()
return ch
}
// findSSHCheckAuthID polls headscale container logs for the SSH action auth-id.
// The SSH action handler logs "SSH action follow-up" with the auth_id on the
// follow-up request (where auth_id is non-empty).
func findSSHCheckAuthID(t *testing.T, headscale ControlServer) string {
t.Helper()
var authID string
assert.EventuallyWithT(t, func(c *assert.CollectT) {
_, stderr, err := headscale.ReadLog()
assert.NoError(c, err)
for line := range strings.SplitSeq(stderr, "\n") {
if !strings.Contains(line, "SSH action follow-up") {
continue
}
if idx := strings.Index(line, "auth_id="); idx != -1 {
start := idx + len("auth_id=")
end := strings.IndexByte(line[start:], ' ')
if end == -1 {
end = len(line[start:])
}
authID = line[start : start+end]
}
}
assert.NotEmpty(c, authID, "auth-id not found in headscale logs")
}, 10*time.Second, 500*time.Millisecond, "waiting for SSH check auth-id in headscale logs")
return authID
}
// sshCheckPolicy returns a policy with SSH "check" mode for group:integration-test
// targeting autogroup:member and autogroup:tagged destinations.
func sshCheckPolicy() *policyv2.Policy {
return &policyv2.Policy{
Groups: policyv2.Groups{
policyv2.Group("group:integration-test"): []policyv2.Username{
policyv2.Username("user1@"),
},
},
ACLs: []policyv2.ACL{
{
Action: "accept",
Protocol: "tcp",
Sources: []policyv2.Alias{wildcard()},
Destinations: []policyv2.AliasWithPorts{
aliasWithPorts(wildcard(), tailcfg.PortRangeAny),
},
},
},
SSHs: []policyv2.SSH{
{
Action: "check",
Sources: policyv2.SSHSrcAliases{groupp("group:integration-test")},
Destinations: policyv2.SSHDstAliases{
new(policyv2.AutoGroupMember),
new(policyv2.AutoGroupTagged),
},
Users: []policyv2.SSHUser{policyv2.SSHUser("ssh-it-user")},
},
},
}
}
// sshCheckPolicyWithPeriod returns a policy with SSH "check" mode and a
// specified checkPeriod for session duration.
func sshCheckPolicyWithPeriod(period time.Duration) *policyv2.Policy {
return &policyv2.Policy{
Groups: policyv2.Groups{
policyv2.Group("group:integration-test"): []policyv2.Username{
policyv2.Username("user1@"),
},
},
ACLs: []policyv2.ACL{
{
Action: "accept",
Protocol: "tcp",
Sources: []policyv2.Alias{wildcard()},
Destinations: []policyv2.AliasWithPorts{
aliasWithPorts(wildcard(), tailcfg.PortRangeAny),
},
},
},
SSHs: []policyv2.SSH{
{
Action: "check",
Sources: policyv2.SSHSrcAliases{groupp("group:integration-test")},
Destinations: policyv2.SSHDstAliases{
new(policyv2.AutoGroupMember),
new(policyv2.AutoGroupTagged),
},
Users: []policyv2.SSHUser{policyv2.SSHUser("ssh-it-user")},
CheckPeriod: &policyv2.SSHCheckPeriod{Duration: period},
},
},
}
}
// findNewSSHCheckAuthID polls headscale logs for an SSH check auth-id
// that differs from excludeID. Used to verify re-authentication after
// session expiry.
func findNewSSHCheckAuthID(
t *testing.T,
headscale ControlServer,
excludeID string,
) string {
t.Helper()
var authID string
assert.EventuallyWithT(t, func(c *assert.CollectT) {
_, stderr, err := headscale.ReadLog()
assert.NoError(c, err)
for line := range strings.SplitSeq(stderr, "\n") {
if !strings.Contains(line, "SSH action follow-up") {
continue
}
if idx := strings.Index(line, "auth_id="); idx != -1 {
start := idx + len("auth_id=")
end := strings.IndexByte(line[start:], ' ')
if end == -1 {
end = len(line[start:])
}
id := line[start : start+end]
if id != excludeID {
authID = id
}
}
}
assert.NotEmpty(c, authID, "new auth-id not found in headscale logs")
}, 10*time.Second, 500*time.Millisecond, "waiting for new SSH check auth-id")
return authID
}
func TestSSHOneUserToOneCheckModeCLI(t *testing.T) {
IntegrationSkip(t)
scenario := sshScenario(t, sshCheckPolicy(), 1)
// defer scenario.ShutdownAssertNoPanics(t)
allClients, err := scenario.ListTailscaleClients()
requireNoErrListClients(t, err)
user1Clients, err := scenario.ListTailscaleClients("user1")
requireNoErrListClients(t, err)
user2Clients, err := scenario.ListTailscaleClients("user2")
requireNoErrListClients(t, err)
headscale, err := scenario.Headscale()
require.NoError(t, err)
err = scenario.WaitForTailscaleSync()
requireNoErrSync(t, err)
_, err = scenario.ListTailscaleClientsFQDNs()
requireNoErrListFQDN(t, err)
// user1 can SSH (via check) to all peers
for _, client := range user1Clients {
for _, peer := range allClients {
if client.Hostname() == peer.Hostname() {
continue
}
// Start SSH — will block waiting for check auth
sshResult := doSSHCheck(t, client, peer)
// Find the auth-id from headscale logs
authID := findSSHCheckAuthID(t, headscale)
// Approve via CLI
_, err := headscale.Execute(
[]string{
"headscale", "auth", "approve",
"--auth-id", authID,
},
)
require.NoError(t, err)
// Wait for SSH to complete
select {
case result := <-sshResult:
require.NoError(t, result.err)
require.Contains(
t,
peer.ContainerID(),
strings.ReplaceAll(result.stdout, "\n", ""),
)
case <-time.After(30 * time.Second):
t.Fatal("SSH did not complete after auth approval")
}
}
}
// user2 cannot SSH — not in the check policy group
for _, client := range user2Clients {
for _, peer := range allClients {
if client.Hostname() == peer.Hostname() {
continue
}
assertSSHPermissionDenied(t, client, peer)
}
}
}
func TestSSHOneUserToOneCheckModeOIDC(t *testing.T) {
IntegrationSkip(t)
spec := ScenarioSpec{
NodesPerUser: 1,
Users: []string{"user1", "user2"},
OIDCSkipUserCreation: true,
OIDCUsers: []mockoidc.MockUser{
// First 2: consumed during node registration
oidcMockUser("user1", true),
oidcMockUser("user2", true),
// Extra: consumed during SSH check auth flows.
// Each SSH check pops one user from the queue.
oidcMockUser("user1", true),
},
}
scenario, err := NewScenario(spec)
require.NoError(t, err)
// defer scenario.ShutdownAssertNoPanics(t)
oidcMap := map[string]string{
"HEADSCALE_OIDC_ISSUER": scenario.mockOIDC.Issuer(),
"HEADSCALE_OIDC_CLIENT_ID": scenario.mockOIDC.ClientID(),
"CREDENTIALS_DIRECTORY_TEST": "/tmp",
"HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret",
}
err = scenario.CreateHeadscaleEnvWithLoginURL(
[]tsic.Option{
tsic.WithSSH(),
tsic.WithNetfilter("off"),
tsic.WithPackages("openssh"),
tsic.WithExtraCommands("adduser ssh-it-user"),
tsic.WithDockerWorkdir("/"),
},
hsic.WithACLPolicy(sshCheckPolicy()),
hsic.WithTestName("sshcheckoidc"),
hsic.WithConfigEnv(oidcMap),
hsic.WithTLS(),
hsic.WithFileInContainer(
"/tmp/hs_client_oidc_secret",
[]byte(scenario.mockOIDC.ClientSecret()),
),
)
require.NoError(t, err)
allClients, err := scenario.ListTailscaleClients()
requireNoErrListClients(t, err)
user1Clients, err := scenario.ListTailscaleClients("user1")
requireNoErrListClients(t, err)
user2Clients, err := scenario.ListTailscaleClients("user2")
requireNoErrListClients(t, err)
headscale, err := scenario.Headscale()
require.NoError(t, err)
err = scenario.WaitForTailscaleSync()
requireNoErrSync(t, err)
_, err = scenario.ListTailscaleClientsFQDNs()
requireNoErrListFQDN(t, err)
// user1 can SSH (via check) to all peers
for _, client := range user1Clients {
for _, peer := range allClients {
if client.Hostname() == peer.Hostname() {
continue
}
// Start SSH — will block waiting for check auth
sshResult := doSSHCheck(t, client, peer)
// Find the auth-id from headscale logs
authID := findSSHCheckAuthID(t, headscale)
// Build auth URL and visit it to trigger OIDC flow.
// The mock OIDC server auto-authenticates from the user queue.
authURL := headscale.GetEndpoint() + "/auth/" + authID
parsedURL, err := url.Parse(authURL)
require.NoError(t, err)
_, err = doLoginURL("ssh-check-oidc", parsedURL)
require.NoError(t, err)
// Wait for SSH to complete
select {
case result := <-sshResult:
require.NoError(t, result.err)
require.Contains(
t,
peer.ContainerID(),
strings.ReplaceAll(result.stdout, "\n", ""),
)
case <-time.After(30 * time.Second):
t.Fatal("SSH did not complete after OIDC auth")
}
}
}
// user2 cannot SSH — not in the check policy group
for _, client := range user2Clients {
for _, peer := range allClients {
if client.Hostname() == peer.Hostname() {
continue
}
assertSSHPermissionDenied(t, client, peer)
}
}
}
// TestSSHCheckModeUnapprovedTimeout verifies that SSH in check mode is rejected
// when nobody approves the auth request and the registration cache entry expires.
func TestSSHCheckModeUnapprovedTimeout(t *testing.T) {
IntegrationSkip(t)
spec := ScenarioSpec{
NodesPerUser: 1,
Users: []string{"user1", "user2"},
}
scenario, err := NewScenario(spec)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
err = scenario.CreateHeadscaleEnv(
[]tsic.Option{
tsic.WithSSH(),
tsic.WithNetfilter("off"),
tsic.WithPackages("openssh"),
tsic.WithExtraCommands("adduser ssh-it-user"),
tsic.WithDockerWorkdir("/"),
},
hsic.WithACLPolicy(sshCheckPolicy()),
hsic.WithTestName("sshchecktimeout"),
hsic.WithConfigEnv(map[string]string{
"HEADSCALE_TUNING_REGISTER_CACHE_EXPIRATION": "15s",
"HEADSCALE_TUNING_REGISTER_CACHE_CLEANUP": "5s",
}),
)
require.NoError(t, err)
allClients, err := scenario.ListTailscaleClients()
requireNoErrListClients(t, err)
user1Clients, err := scenario.ListTailscaleClients("user1")
requireNoErrListClients(t, err)
user2Clients, err := scenario.ListTailscaleClients("user2")
requireNoErrListClients(t, err)
headscale, err := scenario.Headscale()
require.NoError(t, err)
err = scenario.WaitForTailscaleSync()
requireNoErrSync(t, err)
_, err = scenario.ListTailscaleClientsFQDNs()
requireNoErrListFQDN(t, err)
// user1 attempts SSH — enters check flow, but nobody approves
for _, client := range user1Clients {
for _, peer := range allClients {
if client.Hostname() == peer.Hostname() {
continue
}
sshResult := doSSHCheck(t, client, peer)
// Confirm the check flow was entered
_ = findSSHCheckAuthID(t, headscale)
// Do NOT approve — wait for cache expiry and SSH rejection
select {
case result := <-sshResult:
require.Error(t, result.err, "SSH should be rejected when unapproved")
assert.Empty(t, result.stdout, "no command output expected on rejection")
case <-time.After(60 * time.Second):
t.Fatal("SSH did not complete after cache expiry timeout")
}
}
}
// user2 still gets immediate Permission Denied
for _, client := range user2Clients {
for _, peer := range allClients {
if client.Hostname() == peer.Hostname() {
continue
}
assertSSHPermissionDenied(t, client, peer)
}
}
}
// TestSSHCheckModeCheckPeriodCLI verifies that after approval with a short
// checkPeriod, the session expires and the next SSH connection requires
// re-authentication via a new check flow.
func TestSSHCheckModeCheckPeriodCLI(t *testing.T) {
IntegrationSkip(t)
// 1 minute is the documented minimum checkPeriod
scenario := sshScenario(t, sshCheckPolicyWithPeriod(time.Minute), 1)
defer scenario.ShutdownAssertNoPanics(t)
allClients, err := scenario.ListTailscaleClients()
requireNoErrListClients(t, err)
user1Clients, err := scenario.ListTailscaleClients("user1")
requireNoErrListClients(t, err)
headscale, err := scenario.Headscale()
require.NoError(t, err)
err = scenario.WaitForTailscaleSync()
requireNoErrSync(t, err)
_, err = scenario.ListTailscaleClientsFQDNs()
requireNoErrListFQDN(t, err)
// === Phase 1: First SSH check — approve, verify success ===
for _, client := range user1Clients {
for _, peer := range allClients {
if client.Hostname() == peer.Hostname() {
continue
}
sshResult := doSSHCheck(t, client, peer)
firstAuthID := findSSHCheckAuthID(t, headscale)
_, err := headscale.Execute(
[]string{
"headscale", "auth", "approve",
"--auth-id", firstAuthID,
},
)
require.NoError(t, err)
select {
case result := <-sshResult:
require.NoError(t, result.err, "first SSH should succeed after approval")
require.Contains(
t,
peer.ContainerID(),
strings.ReplaceAll(result.stdout, "\n", ""),
)
case <-time.After(30 * time.Second):
t.Fatal("first SSH did not complete after auth approval")
}
// === Phase 2: Wait for checkPeriod to expire ===
//nolint:forbidigo // Intentional sleep: waiting for the check period session
// to expire. This is a time-based expiry, not a pollable condition — the
// Tailscale client caches the approval for SessionDuration and only
// re-triggers the check flow after it elapses.
time.Sleep(70 * time.Second)
// === Phase 3: Second SSH — must re-authenticate ===
sshResult2 := doSSHCheck(t, client, peer)
secondAuthID := findNewSSHCheckAuthID(t, headscale, firstAuthID)
require.NotEqual(
t,
firstAuthID,
secondAuthID,
"second SSH should trigger a new auth flow after checkPeriod expiry",
)
_, err = headscale.Execute(
[]string{
"headscale", "auth", "approve",
"--auth-id", secondAuthID,
},
)
require.NoError(t, err)
select {
case result := <-sshResult2:
require.NoError(t, result.err, "second SSH should succeed after re-approval")
require.Contains(
t,
peer.ContainerID(),
strings.ReplaceAll(result.stdout, "\n", ""),
)
case <-time.After(30 * time.Second):
t.Fatal("second SSH did not complete after re-auth approval")
}
}
}
}
// TestSSHCheckModeAutoApprove verifies that after SSH check approval, a second
// SSH within the checkPeriod is auto-approved without requiring manual approval.
func TestSSHCheckModeAutoApprove(t *testing.T) {
IntegrationSkip(t)
// 5 minute checkPeriod — long enough not to expire during test
scenario := sshScenario(t, sshCheckPolicyWithPeriod(5*time.Minute), 1)
defer scenario.ShutdownAssertNoPanics(t)
allClients, err := scenario.ListTailscaleClients()
requireNoErrListClients(t, err)
user1Clients, err := scenario.ListTailscaleClients("user1")
requireNoErrListClients(t, err)
headscale, err := scenario.Headscale()
require.NoError(t, err)
err = scenario.WaitForTailscaleSync()
requireNoErrSync(t, err)
_, err = scenario.ListTailscaleClientsFQDNs()
requireNoErrListFQDN(t, err)
// === Phase 1: First SSH check — approve, verify success ===
for _, client := range user1Clients {
for _, peer := range allClients {
if client.Hostname() == peer.Hostname() {
continue
}
sshResult := doSSHCheck(t, client, peer)
firstAuthID := findSSHCheckAuthID(t, headscale)
_, err := headscale.Execute(
[]string{
"headscale", "auth", "approve",
"--auth-id", firstAuthID,
},
)
require.NoError(t, err)
select {
case result := <-sshResult:
require.NoError(t, result.err, "first SSH should succeed after approval")
require.Contains(
t,
peer.ContainerID(),
strings.ReplaceAll(result.stdout, "\n", ""),
)
case <-time.After(30 * time.Second):
t.Fatal("first SSH did not complete after auth approval")
}
// === Phase 2: Immediate retry — should auto-approve ===
result, _, err := doSSH(t, client, peer)
require.NoError(t, err, "second SSH should auto-approve without manual auth")
require.Contains(
t,
peer.ContainerID(),
strings.ReplaceAll(result, "\n", ""),
)
}
}
}
// TestSSHCheckModeNegativeCLI verifies that `headscale auth reject`
// properly denies an SSH check.
func TestSSHCheckModeNegativeCLI(t *testing.T) {
IntegrationSkip(t)
scenario := sshScenario(t, sshCheckPolicy(), 1)
defer scenario.ShutdownAssertNoPanics(t)
allClients, err := scenario.ListTailscaleClients()
requireNoErrListClients(t, err)
user1Clients, err := scenario.ListTailscaleClients("user1")
requireNoErrListClients(t, err)
headscale, err := scenario.Headscale()
require.NoError(t, err)
err = scenario.WaitForTailscaleSync()
requireNoErrSync(t, err)
_, err = scenario.ListTailscaleClientsFQDNs()
requireNoErrListFQDN(t, err)
for _, client := range user1Clients {
for _, peer := range allClients {
if client.Hostname() == peer.Hostname() {
continue
}
sshResult := doSSHCheck(t, client, peer)
authID := findSSHCheckAuthID(t, headscale)
// Reject via CLI
_, err := headscale.Execute(
[]string{
"headscale", "auth", "reject",
"--auth-id", authID,
},
)
require.NoError(t, err)
select {
case result := <-sshResult:
require.Error(t, result.err, "SSH should be rejected")
assert.Empty(t, result.stdout, "no command output expected on rejection")
case <-time.After(30 * time.Second):
t.Fatal("SSH did not complete after auth rejection")
}
}
}
}

View File

@@ -3122,7 +3122,7 @@ func TestTagsAuthKeyWithoutUserRejectsAdvertisedTags(t *testing.T) {
// TestTagsAuthKeyConvertToUserViaCLIRegister reproduces the panic from
// issue #3038: register a node with a tags-only preauthkey (no user), then
// convert it to a user-owned node via "headscale auth register --auth-id <id> --user <user>".
// convert it to a user-owned node via "headscale nodes register --user <user> --key ...".
// The crash happens in the mapper's generateUserProfiles when node.User is nil
// after the tag→user conversion in processReauthTags.
//

View File

@@ -1,26 +0,0 @@
syntax = "proto3";
package headscale.v1;
option go_package = "github.com/juanfont/headscale/gen/go/v1";
import "headscale/v1/node.proto";
message AuthRegisterRequest {
string user = 1;
string auth_id = 2;
}
message AuthRegisterResponse {
Node node = 1;
}
message AuthApproveRequest {
string auth_id = 1;
}
message AuthApproveResponse {}
message AuthRejectRequest {
string auth_id = 1;
}
message AuthRejectResponse {}

View File

@@ -8,7 +8,6 @@ import "headscale/v1/user.proto";
import "headscale/v1/preauthkey.proto";
import "headscale/v1/node.proto";
import "headscale/v1/apikey.proto";
import "headscale/v1/auth.proto";
import "headscale/v1/policy.proto";
service HeadscaleService {
@@ -140,29 +139,6 @@ service HeadscaleService {
// --- Node end ---
// --- Auth start ---
rpc AuthRegister(AuthRegisterRequest) returns (AuthRegisterResponse) {
option (google.api.http) = {
post : "/api/v1/auth/register"
body : "*"
};
}
rpc AuthApprove(AuthApproveRequest) returns (AuthApproveResponse) {
option (google.api.http) = {
post : "/api/v1/auth/approve"
body : "*"
};
}
rpc AuthReject(AuthRejectRequest) returns (AuthRejectResponse) {
option (google.api.http) = {
post : "/api/v1/auth/reject"
body : "*"
};
}
// --- Auth end ---
// --- ApiKeys start ---
rpc CreateApiKey(CreateApiKeyRequest) returns (CreateApiKeyResponse) {
option (google.api.http) = {