policy: allow tags to own tags (#2930)

This commit is contained in:
Kristoffer Dalby
2025-12-06 10:23:35 +01:00
committed by GitHub
parent eb788cd007
commit 15c84b34e0
5 changed files with 404 additions and 32 deletions

View File

@@ -58,6 +58,7 @@ release.
- Detect when only node endpoints or DERP region changed and send
PeerChangedPatch responses instead of full map updates, reducing bandwidth
and improving performance
- Tags can now be tagOwner of other tags [#2930](https://github.com/juanfont/headscale/pull/2930)
## 0.27.2 (2025-xx-xx)

View File

@@ -3234,7 +3234,7 @@ func TestIssue2830_ExistingNodeReregistersWithExpiredKey(t *testing.T) {
// Create a valid key (will expire it later)
expiry := time.Now().Add(1 * time.Hour)
pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), false, false, &expiry, nil)
pak, err := app.state.CreatePreAuthKey(user.TypedID(), false, false, &expiry, nil)
require.NoError(t, err)
machineKey := key.NewMachine()

View File

@@ -464,14 +464,14 @@ func TestAutogroupSelfWithOtherRules(t *testing.T) {
// test-2 has a router device with tag:node-router
test2RouterNode := &types.Node{
ID: 2,
Hostname: "test-2-router",
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: ptr.To(users[1]),
UserID: ptr.To(users[1].ID),
Tags: []string{"tag:node-router"},
Hostinfo: &tailcfg.Hostinfo{},
ID: 2,
Hostname: "test-2-router",
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: ptr.To(users[1]),
UserID: ptr.To(users[1].ID),
Tags: []string{"tag:node-router"},
Hostinfo: &tailcfg.Hostinfo{},
}
nodes := types.Nodes{test1Node, test2RouterNode}
@@ -537,8 +537,8 @@ func TestAutogroupSelfPolicyUpdateTriggersMapResponse(t *testing.T) {
Hostname: "test-1-device",
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: users[0],
UserID: users[0].ID,
User: ptr.To(users[0]),
UserID: ptr.To(users[0].ID),
Hostinfo: &tailcfg.Hostinfo{},
}
@@ -547,8 +547,8 @@ func TestAutogroupSelfPolicyUpdateTriggersMapResponse(t *testing.T) {
Hostname: "test-2-device",
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: users[1],
UserID: users[1].ID,
User: ptr.To(users[1]),
UserID: ptr.To(users[1].ID),
Hostinfo: &tailcfg.Hostinfo{},
}

View File

@@ -1,6 +1,7 @@
package v2
import (
"cmp"
"errors"
"fmt"
"net/netip"
@@ -9,7 +10,6 @@ import (
"strings"
"github.com/go-json-experiment/json"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/prometheus/common/model"
@@ -34,6 +34,10 @@ const Wildcard = Asterix(0)
var ErrAutogroupSelfRequiresPerNodeResolution = errors.New("autogroup:self requires per-node resolution and cannot be resolved in this context")
var ErrCircularReference = errors.New("circular reference detected")
var ErrUndefinedTagReference = errors.New("references undefined tag")
type Asterix int
func (a Asterix) Validate() error {
@@ -341,6 +345,10 @@ func (t Tag) CanBeAutoApprover() bool {
return true
}
func (t Tag) CanBeTagOwner() bool {
return true
}
func (t Tag) String() string {
return string(t)
}
@@ -915,6 +923,7 @@ func (ve *AutoApproverEnc) UnmarshalJSON(b []byte) error {
type Owner interface {
CanBeTagOwner() bool
UnmarshalJSON([]byte) error
String() string
}
// OwnerEnc is used to deserialize a Owner.
@@ -963,6 +972,8 @@ func (o Owners) MarshalJSON() ([]byte, error) {
owners[i] = string(*v)
case *Group:
owners[i] = string(*v)
case *Tag:
owners[i] = string(*v)
default:
return nil, fmt.Errorf("unknown owner type: %T", v)
}
@@ -977,6 +988,8 @@ func parseOwner(s string) (Owner, error) {
return ptr.To(Username(s)), nil
case isGroup(s):
return ptr.To(Group(s)), nil
case isTag(s):
return ptr.To(Tag(s)), nil
}
return nil, fmt.Errorf(`Invalid Owner %q. An alias must be one of the following types:
@@ -1134,6 +1147,8 @@ func (to TagOwners) MarshalJSON() ([]byte, error) {
ownerStrs[i] = string(*v)
case *Group:
ownerStrs[i] = string(*v)
case *Tag:
ownerStrs[i] = string(*v)
default:
return nil, fmt.Errorf("unknown owner type: %T", v)
}
@@ -1167,23 +1182,38 @@ func (to TagOwners) Contains(tagOwner *Tag) error {
// It is intended for internal use in a PolicyManager.
func resolveTagOwners(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (map[Tag]*netipx.IPSet, error) {
if p == nil {
return nil, nil
return make(map[Tag]*netipx.IPSet), nil
}
if len(p.TagOwners) == 0 {
return make(map[Tag]*netipx.IPSet), nil
}
ret := make(map[Tag]*netipx.IPSet)
for tag, owners := range p.TagOwners {
tagOwners, err := flattenTagOwners(p.TagOwners)
if err != nil {
return nil, err
}
for tag, owners := range tagOwners {
var ips netipx.IPSetBuilder
for _, owner := range owners {
o, ok := owner.(Alias)
if !ok {
switch o := owner.(type) {
case *Tag:
// After flattening, Tag types should not appear in the owners list.
// If they do, skip them as they represent already-resolved references.
case Alias:
// If it does not resolve, that means the tag is not associated with any IP addresses.
resolved, _ := o.Resolve(p, users, nodes)
ips.AddSet(resolved)
default:
// Should never happen
return nil, fmt.Errorf("owner %v is not an Alias", owner)
}
// If it does not resolve, that means the tag is not associated with any IP addresses.
resolved, _ := o.Resolve(p, users, nodes)
ips.AddSet(resolved)
}
ipSet, err := ips.IPSet()
@@ -1197,6 +1227,79 @@ func resolveTagOwners(p *Policy, users types.Users, nodes views.Slice[types.Node
return ret, nil
}
// flattenTags flattens the TagOwners by resolving nested tags and detecting cycles.
// It will return a Owners list where all the Tag types have been resolved to their underlying Owners.
func flattenTags(tagOwners TagOwners, tag Tag, visiting map[Tag]bool, chain []Tag) (Owners, error) {
if visiting[tag] {
cycleStart := 0
for i, t := range chain {
if t == tag {
cycleStart = i
break
}
}
cycleTags := make([]string, len(chain[cycleStart:]))
for i, t := range chain[cycleStart:] {
cycleTags[i] = string(t)
}
slices.Sort(cycleTags)
return nil, fmt.Errorf("%w: %s", ErrCircularReference, strings.Join(cycleTags, " -> "))
}
visiting[tag] = true
chain = append(chain, tag)
defer delete(visiting, tag)
var result Owners
for _, owner := range tagOwners[tag] {
switch o := owner.(type) {
case *Tag:
if _, ok := tagOwners[*o]; !ok {
return nil, fmt.Errorf("tag %q %w %q", tag, ErrUndefinedTagReference, *o)
}
nested, err := flattenTags(tagOwners, *o, visiting, chain)
if err != nil {
return nil, err
}
result = append(result, nested...)
default:
result = append(result, owner)
}
}
return result, nil
}
// flattenTagOwners flattens all TagOwners by resolving nested tags and detecting cycles.
// It will return a new TagOwners map where all the Tag types have been resolved to their underlying Owners.
func flattenTagOwners(tagOwners TagOwners) (TagOwners, error) {
ret := make(TagOwners)
for tag := range tagOwners {
flattened, err := flattenTags(tagOwners, tag, make(map[Tag]bool), nil)
if err != nil {
return nil, err
}
slices.SortFunc(flattened, func(a, b Owner) int {
return cmp.Compare(a.String(), b.String())
})
ret[tag] = slices.CompactFunc(flattened, func(a, b Owner) bool {
return a.String() == b.String()
})
}
return ret, nil
}
type AutoApproverPolicy struct {
Routes map[netip.Prefix]AutoApprovers `json:"routes,omitempty"`
ExitNode AutoApprovers `json:"exitNode,omitempty"`
@@ -1844,10 +1947,23 @@ func (p *Policy) validate() error {
if err := p.Groups.Contains(g); err != nil {
errs = append(errs, err)
}
case *Tag:
t := tagOwner
err := p.TagOwners.Contains(t)
if err != nil {
errs = append(errs, err)
}
}
}
}
// Validate tag ownership chains for circular references and undefined tags.
_, err := flattenTagOwners(p.TagOwners)
if err != nil {
errs = append(errs, err)
}
for _, approvers := range p.AutoApprovers.Routes {
for _, approver := range approvers {
switch approver := approver.(type) {

View File

@@ -1470,6 +1470,57 @@ func TestUnmarshalPolicy(t *testing.T) {
},
},
},
{
name: "tags-can-own-other-tags",
input: `
{
"tagOwners": {
"tag:bigbrother": [],
"tag:smallbrother": ["tag:bigbrother"],
},
"acls": [
{
"action": "accept",
"proto": "tcp",
"src": ["*"],
"dst": ["tag:smallbrother:9000"]
}
]
}
`,
want: &Policy{
TagOwners: TagOwners{
Tag("tag:bigbrother"): {},
Tag("tag:smallbrother"): {ptr.To(Tag("tag:bigbrother"))},
},
ACLs: []ACL{
{
Action: "accept",
Protocol: "tcp",
Sources: Aliases{
Wildcard,
},
Destinations: []AliasWithPorts{
{
Alias: ptr.To(Tag("tag:smallbrother")),
Ports: []tailcfg.PortRange{{First: 9000, Last: 9000}},
},
},
},
},
},
},
{
name: "tag-owner-references-undefined-tag",
input: `
{
"tagOwners": {
"tag:child": ["tag:nonexistent"],
},
}
`,
wantErr: `tag "tag:child" references undefined tag "tag:nonexistent"`,
},
}
cmps := append(util.Comparers,
@@ -1596,7 +1647,7 @@ func TestResolvePolicy(t *testing.T) {
{
User: ptr.To(testuser),
Tags: []string{"tag:anything"},
IPv4: ap("100.100.101.2"),
IPv4: ap("100.100.101.2"),
},
// not matching because it's tagged (tags copied from AuthKey)
{
@@ -1628,7 +1679,7 @@ func TestResolvePolicy(t *testing.T) {
{
User: ptr.To(groupuser),
Tags: []string{"tag:anything"},
IPv4: ap("100.100.101.5"),
IPv4: ap("100.100.101.5"),
},
// not matching because it's tagged (tags copied from AuthKey)
{
@@ -1665,7 +1716,7 @@ func TestResolvePolicy(t *testing.T) {
// Not matching forced tags
{
Tags: []string{"tag:anything"},
IPv4: ap("100.100.101.10"),
IPv4: ap("100.100.101.10"),
},
// not matching pak tag
{
@@ -1677,7 +1728,7 @@ func TestResolvePolicy(t *testing.T) {
// Not matching forced tags
{
Tags: []string{"tag:test"},
IPv4: ap("100.100.101.234"),
IPv4: ap("100.100.101.234"),
},
// matching tag (tags copied from AuthKey during registration)
{
@@ -1689,6 +1740,52 @@ func TestResolvePolicy(t *testing.T) {
pol: &Policy{},
want: []netip.Prefix{mp("100.100.101.234/32"), mp("100.100.101.239/32")},
},
{
name: "tag-owned-by-tag-call-child",
toResolve: tp("tag:smallbrother"),
pol: &Policy{
TagOwners: TagOwners{
Tag("tag:bigbrother"): {},
Tag("tag:smallbrother"): {ptr.To(Tag("tag:bigbrother"))},
},
},
nodes: types.Nodes{
// Should not match as we resolve the "child" tag.
{
Tags: []string{"tag:bigbrother"},
IPv4: ap("100.100.101.234"),
},
// Should match.
{
Tags: []string{"tag:smallbrother"},
IPv4: ap("100.100.101.239"),
},
},
want: []netip.Prefix{mp("100.100.101.239/32")},
},
{
name: "tag-owned-by-tag-call-parent",
toResolve: tp("tag:bigbrother"),
pol: &Policy{
TagOwners: TagOwners{
Tag("tag:bigbrother"): {},
Tag("tag:smallbrother"): {ptr.To(Tag("tag:bigbrother"))},
},
},
nodes: types.Nodes{
// Should match - we are resolving "tag:bigbrother" which this node has.
{
Tags: []string{"tag:bigbrother"},
IPv4: ap("100.100.101.234"),
},
// Should not match - this node has "tag:smallbrother", not the tag we're resolving.
{
Tags: []string{"tag:smallbrother"},
IPv4: ap("100.100.101.239"),
},
},
want: []netip.Prefix{mp("100.100.101.234/32")},
},
{
name: "empty-policy",
toResolve: pp("100.100.101.101/32"),
@@ -1747,7 +1844,7 @@ func TestResolvePolicy(t *testing.T) {
nodes: types.Nodes{
{
Tags: []string{"tag:test"},
IPv4: ap("100.100.101.234"),
IPv4: ap("100.100.101.234"),
},
},
},
@@ -1774,7 +1871,7 @@ func TestResolvePolicy(t *testing.T) {
{
User: ptr.To(testuser),
Tags: []string{"tag:test"},
IPv4: ap("100.100.101.2"),
IPv4: ap("100.100.101.2"),
},
// Node with allowed requested tag (should be excluded)
{
@@ -1833,7 +1930,7 @@ func TestResolvePolicy(t *testing.T) {
{
User: ptr.To(testuser),
Tags: []string{"tag:test"},
IPv4: ap("100.100.101.2"),
IPv4: ap("100.100.101.2"),
},
// Node with allowed requested tag (should be included)
{
@@ -1871,7 +1968,7 @@ func TestResolvePolicy(t *testing.T) {
{
User: ptr.To(testuser),
Tags: []string{"tag:test", "tag:other"},
IPv4: ap("100.100.101.7"),
IPv4: ap("100.100.101.7"),
},
},
pol: &Policy{
@@ -1900,7 +1997,7 @@ func TestResolvePolicy(t *testing.T) {
{
User: ptr.To(testuser),
Tags: []string{"tag:test"},
IPv4: ap("100.100.101.3"),
IPv4: ap("100.100.101.3"),
},
{
User: ptr.To(testuser2),
@@ -1976,11 +2073,11 @@ func TestResolveAutoApprovers(t *testing.T) {
User: &users[2],
},
{
IPv4: ap("100.64.0.4"),
IPv4: ap("100.64.0.4"),
Tags: []string{"tag:testtag"},
},
{
IPv4: ap("100.64.0.5"),
IPv4: ap("100.64.0.5"),
Tags: []string{"tag:exittest"},
},
}
@@ -2474,6 +2571,20 @@ func TestResolveTagOwners(t *testing.T) {
},
wantErr: false,
},
{
name: "tag-owns-tag",
policy: &Policy{
TagOwners: TagOwners{
Tag("tag:bigbrother"): Owners{ptr.To(Username("user1@"))},
Tag("tag:smallbrother"): Owners{ptr.To(Tag("tag:bigbrother"))},
},
},
want: map[Tag]*netipx.IPSet{
Tag("tag:bigbrother"): mustIPSet("100.64.0.1/32"),
Tag("tag:smallbrother"): mustIPSet("100.64.0.1/32"),
},
wantErr: false,
},
}
cmps := append(util.Comparers, cmp.Comparer(ipSetComparer))
@@ -2936,3 +3047,147 @@ func mustParseAlias(s string) Alias {
}
return alias
}
func TestFlattenTagOwners(t *testing.T) {
tests := []struct {
name string
input TagOwners
want TagOwners
wantErr string
}{
{
name: "tag-owns-tag",
input: TagOwners{
Tag("tag:bigbrother"): Owners{ptr.To(Group("group:user1"))},
Tag("tag:smallbrother"): Owners{ptr.To(Tag("tag:bigbrother"))},
},
want: TagOwners{
Tag("tag:bigbrother"): Owners{ptr.To(Group("group:user1"))},
Tag("tag:smallbrother"): Owners{ptr.To(Group("group:user1"))},
},
wantErr: "",
},
{
name: "circular-reference",
input: TagOwners{
Tag("tag:a"): Owners{ptr.To(Tag("tag:b"))},
Tag("tag:b"): Owners{ptr.To(Tag("tag:a"))},
},
want: nil,
wantErr: "circular reference detected: tag:a -> tag:b",
},
{
name: "mixed-owners",
input: TagOwners{
Tag("tag:x"): Owners{ptr.To(Username("user1@")), ptr.To(Tag("tag:y"))},
Tag("tag:y"): Owners{ptr.To(Username("user2@"))},
},
want: TagOwners{
Tag("tag:x"): Owners{ptr.To(Username("user1@")), ptr.To(Username("user2@"))},
Tag("tag:y"): Owners{ptr.To(Username("user2@"))},
},
wantErr: "",
},
{
name: "mixed-dupe-owners",
input: TagOwners{
Tag("tag:x"): Owners{ptr.To(Username("user1@")), ptr.To(Tag("tag:y"))},
Tag("tag:y"): Owners{ptr.To(Username("user1@"))},
},
want: TagOwners{
Tag("tag:x"): Owners{ptr.To(Username("user1@"))},
Tag("tag:y"): Owners{ptr.To(Username("user1@"))},
},
wantErr: "",
},
{
name: "no-tag-owners",
input: TagOwners{
Tag("tag:solo"): Owners{ptr.To(Username("user1@"))},
},
want: TagOwners{
Tag("tag:solo"): Owners{ptr.To(Username("user1@"))},
},
wantErr: "",
},
{
name: "tag-long-owner-chain",
input: TagOwners{
Tag("tag:a"): Owners{ptr.To(Group("group:user1"))},
Tag("tag:b"): Owners{ptr.To(Tag("tag:a"))},
Tag("tag:c"): Owners{ptr.To(Tag("tag:b"))},
Tag("tag:d"): Owners{ptr.To(Tag("tag:c"))},
Tag("tag:e"): Owners{ptr.To(Tag("tag:d"))},
Tag("tag:f"): Owners{ptr.To(Tag("tag:e"))},
Tag("tag:g"): Owners{ptr.To(Tag("tag:f"))},
},
want: TagOwners{
Tag("tag:a"): Owners{ptr.To(Group("group:user1"))},
Tag("tag:b"): Owners{ptr.To(Group("group:user1"))},
Tag("tag:c"): Owners{ptr.To(Group("group:user1"))},
Tag("tag:d"): Owners{ptr.To(Group("group:user1"))},
Tag("tag:e"): Owners{ptr.To(Group("group:user1"))},
Tag("tag:f"): Owners{ptr.To(Group("group:user1"))},
Tag("tag:g"): Owners{ptr.To(Group("group:user1"))},
},
wantErr: "",
},
{
name: "tag-long-circular-chain",
input: TagOwners{
Tag("tag:a"): Owners{ptr.To(Tag("tag:g"))},
Tag("tag:b"): Owners{ptr.To(Tag("tag:a"))},
Tag("tag:c"): Owners{ptr.To(Tag("tag:b"))},
Tag("tag:d"): Owners{ptr.To(Tag("tag:c"))},
Tag("tag:e"): Owners{ptr.To(Tag("tag:d"))},
Tag("tag:f"): Owners{ptr.To(Tag("tag:e"))},
Tag("tag:g"): Owners{ptr.To(Tag("tag:f"))},
},
wantErr: "circular reference detected: tag:a -> tag:b -> tag:c -> tag:d -> tag:e -> tag:f -> tag:g",
},
{
name: "undefined-tag-reference",
input: TagOwners{
Tag("tag:a"): Owners{ptr.To(Tag("tag:nonexistent"))},
},
wantErr: `tag "tag:a" references undefined tag "tag:nonexistent"`,
},
{
name: "tag-with-empty-owners-is-valid",
input: TagOwners{
Tag("tag:a"): Owners{ptr.To(Tag("tag:b"))},
Tag("tag:b"): Owners{}, // empty owners but exists
},
want: TagOwners{
Tag("tag:a"): nil,
Tag("tag:b"): nil,
},
wantErr: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := flattenTagOwners(tt.input)
if tt.wantErr != "" {
if err == nil {
t.Fatalf("flattenTagOwners() expected error %q, got nil", tt.wantErr)
}
if err.Error() != tt.wantErr {
t.Fatalf("flattenTagOwners() expected error %q, got %q", tt.wantErr, err.Error())
}
return
}
if err != nil {
t.Fatalf("flattenTagOwners() unexpected error: %v", err)
}
if diff := cmp.Diff(tt.want, got); diff != "" {
t.Errorf("flattenTagOwners() mismatch (-want +got):\n%s", diff)
}
})
}
}