diff --git a/CHANGELOG.md b/CHANGELOG.md index e5a4163b..2cd56c09 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -58,6 +58,7 @@ release. - Detect when only node endpoints or DERP region changed and send PeerChangedPatch responses instead of full map updates, reducing bandwidth and improving performance +- Tags can now be tagOwner of other tags [#2930](https://github.com/juanfont/headscale/pull/2930) ## 0.27.2 (2025-xx-xx) diff --git a/hscontrol/auth_test.go b/hscontrol/auth_test.go index 23e1f226..e1e25821 100644 --- a/hscontrol/auth_test.go +++ b/hscontrol/auth_test.go @@ -3234,7 +3234,7 @@ func TestIssue2830_ExistingNodeReregistersWithExpiredKey(t *testing.T) { // Create a valid key (will expire it later) expiry := time.Now().Add(1 * time.Hour) - pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), false, false, &expiry, nil) + pak, err := app.state.CreatePreAuthKey(user.TypedID(), false, false, &expiry, nil) require.NoError(t, err) machineKey := key.NewMachine() diff --git a/hscontrol/policy/v2/policy_test.go b/hscontrol/policy/v2/policy_test.go index 94e631e7..b348c0ab 100644 --- a/hscontrol/policy/v2/policy_test.go +++ b/hscontrol/policy/v2/policy_test.go @@ -464,14 +464,14 @@ func TestAutogroupSelfWithOtherRules(t *testing.T) { // test-2 has a router device with tag:node-router test2RouterNode := &types.Node{ - ID: 2, - Hostname: "test-2-router", - IPv4: ap("100.64.0.2"), - IPv6: ap("fd7a:115c:a1e0::2"), - User: ptr.To(users[1]), - UserID: ptr.To(users[1].ID), - Tags: []string{"tag:node-router"}, - Hostinfo: &tailcfg.Hostinfo{}, + ID: 2, + Hostname: "test-2-router", + IPv4: ap("100.64.0.2"), + IPv6: ap("fd7a:115c:a1e0::2"), + User: ptr.To(users[1]), + UserID: ptr.To(users[1].ID), + Tags: []string{"tag:node-router"}, + Hostinfo: &tailcfg.Hostinfo{}, } nodes := types.Nodes{test1Node, test2RouterNode} @@ -537,8 +537,8 @@ func TestAutogroupSelfPolicyUpdateTriggersMapResponse(t *testing.T) { Hostname: "test-1-device", IPv4: ap("100.64.0.1"), IPv6: ap("fd7a:115c:a1e0::1"), - User: users[0], - UserID: users[0].ID, + User: ptr.To(users[0]), + UserID: ptr.To(users[0].ID), Hostinfo: &tailcfg.Hostinfo{}, } @@ -547,8 +547,8 @@ func TestAutogroupSelfPolicyUpdateTriggersMapResponse(t *testing.T) { Hostname: "test-2-device", IPv4: ap("100.64.0.2"), IPv6: ap("fd7a:115c:a1e0::2"), - User: users[1], - UserID: users[1].ID, + User: ptr.To(users[1]), + UserID: ptr.To(users[1].ID), Hostinfo: &tailcfg.Hostinfo{}, } diff --git a/hscontrol/policy/v2/types.go b/hscontrol/policy/v2/types.go index 0635a557..0c4aec38 100644 --- a/hscontrol/policy/v2/types.go +++ b/hscontrol/policy/v2/types.go @@ -1,6 +1,7 @@ package v2 import ( + "cmp" "errors" "fmt" "net/netip" @@ -9,7 +10,6 @@ import ( "strings" "github.com/go-json-experiment/json" - "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/prometheus/common/model" @@ -34,6 +34,10 @@ const Wildcard = Asterix(0) var ErrAutogroupSelfRequiresPerNodeResolution = errors.New("autogroup:self requires per-node resolution and cannot be resolved in this context") +var ErrCircularReference = errors.New("circular reference detected") + +var ErrUndefinedTagReference = errors.New("references undefined tag") + type Asterix int func (a Asterix) Validate() error { @@ -341,6 +345,10 @@ func (t Tag) CanBeAutoApprover() bool { return true } +func (t Tag) CanBeTagOwner() bool { + return true +} + func (t Tag) String() string { return string(t) } @@ -915,6 +923,7 @@ func (ve *AutoApproverEnc) UnmarshalJSON(b []byte) error { type Owner interface { CanBeTagOwner() bool UnmarshalJSON([]byte) error + String() string } // OwnerEnc is used to deserialize a Owner. @@ -963,6 +972,8 @@ func (o Owners) MarshalJSON() ([]byte, error) { owners[i] = string(*v) case *Group: owners[i] = string(*v) + case *Tag: + owners[i] = string(*v) default: return nil, fmt.Errorf("unknown owner type: %T", v) } @@ -977,6 +988,8 @@ func parseOwner(s string) (Owner, error) { return ptr.To(Username(s)), nil case isGroup(s): return ptr.To(Group(s)), nil + case isTag(s): + return ptr.To(Tag(s)), nil } return nil, fmt.Errorf(`Invalid Owner %q. An alias must be one of the following types: @@ -1134,6 +1147,8 @@ func (to TagOwners) MarshalJSON() ([]byte, error) { ownerStrs[i] = string(*v) case *Group: ownerStrs[i] = string(*v) + case *Tag: + ownerStrs[i] = string(*v) default: return nil, fmt.Errorf("unknown owner type: %T", v) } @@ -1167,23 +1182,38 @@ func (to TagOwners) Contains(tagOwner *Tag) error { // It is intended for internal use in a PolicyManager. func resolveTagOwners(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (map[Tag]*netipx.IPSet, error) { if p == nil { - return nil, nil + return make(map[Tag]*netipx.IPSet), nil + } + + if len(p.TagOwners) == 0 { + return make(map[Tag]*netipx.IPSet), nil } ret := make(map[Tag]*netipx.IPSet) - for tag, owners := range p.TagOwners { + tagOwners, err := flattenTagOwners(p.TagOwners) + if err != nil { + return nil, err + } + + for tag, owners := range tagOwners { var ips netipx.IPSetBuilder for _, owner := range owners { - o, ok := owner.(Alias) - if !ok { + switch o := owner.(type) { + case *Tag: + // After flattening, Tag types should not appear in the owners list. + // If they do, skip them as they represent already-resolved references. + + case Alias: + // If it does not resolve, that means the tag is not associated with any IP addresses. + resolved, _ := o.Resolve(p, users, nodes) + ips.AddSet(resolved) + + default: // Should never happen return nil, fmt.Errorf("owner %v is not an Alias", owner) } - // If it does not resolve, that means the tag is not associated with any IP addresses. - resolved, _ := o.Resolve(p, users, nodes) - ips.AddSet(resolved) } ipSet, err := ips.IPSet() @@ -1197,6 +1227,79 @@ func resolveTagOwners(p *Policy, users types.Users, nodes views.Slice[types.Node return ret, nil } +// flattenTags flattens the TagOwners by resolving nested tags and detecting cycles. +// It will return a Owners list where all the Tag types have been resolved to their underlying Owners. +func flattenTags(tagOwners TagOwners, tag Tag, visiting map[Tag]bool, chain []Tag) (Owners, error) { + if visiting[tag] { + cycleStart := 0 + + for i, t := range chain { + if t == tag { + cycleStart = i + break + } + } + + cycleTags := make([]string, len(chain[cycleStart:])) + for i, t := range chain[cycleStart:] { + cycleTags[i] = string(t) + } + + slices.Sort(cycleTags) + + return nil, fmt.Errorf("%w: %s", ErrCircularReference, strings.Join(cycleTags, " -> ")) + } + + visiting[tag] = true + + chain = append(chain, tag) + defer delete(visiting, tag) + + var result Owners + + for _, owner := range tagOwners[tag] { + switch o := owner.(type) { + case *Tag: + if _, ok := tagOwners[*o]; !ok { + return nil, fmt.Errorf("tag %q %w %q", tag, ErrUndefinedTagReference, *o) + } + + nested, err := flattenTags(tagOwners, *o, visiting, chain) + if err != nil { + return nil, err + } + + result = append(result, nested...) + default: + result = append(result, owner) + } + } + + return result, nil +} + +// flattenTagOwners flattens all TagOwners by resolving nested tags and detecting cycles. +// It will return a new TagOwners map where all the Tag types have been resolved to their underlying Owners. +func flattenTagOwners(tagOwners TagOwners) (TagOwners, error) { + ret := make(TagOwners) + + for tag := range tagOwners { + flattened, err := flattenTags(tagOwners, tag, make(map[Tag]bool), nil) + if err != nil { + return nil, err + } + + slices.SortFunc(flattened, func(a, b Owner) int { + return cmp.Compare(a.String(), b.String()) + }) + ret[tag] = slices.CompactFunc(flattened, func(a, b Owner) bool { + return a.String() == b.String() + }) + } + + return ret, nil +} + type AutoApproverPolicy struct { Routes map[netip.Prefix]AutoApprovers `json:"routes,omitempty"` ExitNode AutoApprovers `json:"exitNode,omitempty"` @@ -1844,10 +1947,23 @@ func (p *Policy) validate() error { if err := p.Groups.Contains(g); err != nil { errs = append(errs, err) } + case *Tag: + t := tagOwner + + err := p.TagOwners.Contains(t) + if err != nil { + errs = append(errs, err) + } } } } + // Validate tag ownership chains for circular references and undefined tags. + _, err := flattenTagOwners(p.TagOwners) + if err != nil { + errs = append(errs, err) + } + for _, approvers := range p.AutoApprovers.Routes { for _, approver := range approvers { switch approver := approver.(type) { diff --git a/hscontrol/policy/v2/types_test.go b/hscontrol/policy/v2/types_test.go index 2d379b4d..a5e5e8d2 100644 --- a/hscontrol/policy/v2/types_test.go +++ b/hscontrol/policy/v2/types_test.go @@ -1470,6 +1470,57 @@ func TestUnmarshalPolicy(t *testing.T) { }, }, }, + { + name: "tags-can-own-other-tags", + input: ` +{ + "tagOwners": { + "tag:bigbrother": [], + "tag:smallbrother": ["tag:bigbrother"], + }, + "acls": [ + { + "action": "accept", + "proto": "tcp", + "src": ["*"], + "dst": ["tag:smallbrother:9000"] + } + ] +} +`, + want: &Policy{ + TagOwners: TagOwners{ + Tag("tag:bigbrother"): {}, + Tag("tag:smallbrother"): {ptr.To(Tag("tag:bigbrother"))}, + }, + ACLs: []ACL{ + { + Action: "accept", + Protocol: "tcp", + Sources: Aliases{ + Wildcard, + }, + Destinations: []AliasWithPorts{ + { + Alias: ptr.To(Tag("tag:smallbrother")), + Ports: []tailcfg.PortRange{{First: 9000, Last: 9000}}, + }, + }, + }, + }, + }, + }, + { + name: "tag-owner-references-undefined-tag", + input: ` +{ + "tagOwners": { + "tag:child": ["tag:nonexistent"], + }, +} +`, + wantErr: `tag "tag:child" references undefined tag "tag:nonexistent"`, + }, } cmps := append(util.Comparers, @@ -1596,7 +1647,7 @@ func TestResolvePolicy(t *testing.T) { { User: ptr.To(testuser), Tags: []string{"tag:anything"}, - IPv4: ap("100.100.101.2"), + IPv4: ap("100.100.101.2"), }, // not matching because it's tagged (tags copied from AuthKey) { @@ -1628,7 +1679,7 @@ func TestResolvePolicy(t *testing.T) { { User: ptr.To(groupuser), Tags: []string{"tag:anything"}, - IPv4: ap("100.100.101.5"), + IPv4: ap("100.100.101.5"), }, // not matching because it's tagged (tags copied from AuthKey) { @@ -1665,7 +1716,7 @@ func TestResolvePolicy(t *testing.T) { // Not matching forced tags { Tags: []string{"tag:anything"}, - IPv4: ap("100.100.101.10"), + IPv4: ap("100.100.101.10"), }, // not matching pak tag { @@ -1677,7 +1728,7 @@ func TestResolvePolicy(t *testing.T) { // Not matching forced tags { Tags: []string{"tag:test"}, - IPv4: ap("100.100.101.234"), + IPv4: ap("100.100.101.234"), }, // matching tag (tags copied from AuthKey during registration) { @@ -1689,6 +1740,52 @@ func TestResolvePolicy(t *testing.T) { pol: &Policy{}, want: []netip.Prefix{mp("100.100.101.234/32"), mp("100.100.101.239/32")}, }, + { + name: "tag-owned-by-tag-call-child", + toResolve: tp("tag:smallbrother"), + pol: &Policy{ + TagOwners: TagOwners{ + Tag("tag:bigbrother"): {}, + Tag("tag:smallbrother"): {ptr.To(Tag("tag:bigbrother"))}, + }, + }, + nodes: types.Nodes{ + // Should not match as we resolve the "child" tag. + { + Tags: []string{"tag:bigbrother"}, + IPv4: ap("100.100.101.234"), + }, + // Should match. + { + Tags: []string{"tag:smallbrother"}, + IPv4: ap("100.100.101.239"), + }, + }, + want: []netip.Prefix{mp("100.100.101.239/32")}, + }, + { + name: "tag-owned-by-tag-call-parent", + toResolve: tp("tag:bigbrother"), + pol: &Policy{ + TagOwners: TagOwners{ + Tag("tag:bigbrother"): {}, + Tag("tag:smallbrother"): {ptr.To(Tag("tag:bigbrother"))}, + }, + }, + nodes: types.Nodes{ + // Should match - we are resolving "tag:bigbrother" which this node has. + { + Tags: []string{"tag:bigbrother"}, + IPv4: ap("100.100.101.234"), + }, + // Should not match - this node has "tag:smallbrother", not the tag we're resolving. + { + Tags: []string{"tag:smallbrother"}, + IPv4: ap("100.100.101.239"), + }, + }, + want: []netip.Prefix{mp("100.100.101.234/32")}, + }, { name: "empty-policy", toResolve: pp("100.100.101.101/32"), @@ -1747,7 +1844,7 @@ func TestResolvePolicy(t *testing.T) { nodes: types.Nodes{ { Tags: []string{"tag:test"}, - IPv4: ap("100.100.101.234"), + IPv4: ap("100.100.101.234"), }, }, }, @@ -1774,7 +1871,7 @@ func TestResolvePolicy(t *testing.T) { { User: ptr.To(testuser), Tags: []string{"tag:test"}, - IPv4: ap("100.100.101.2"), + IPv4: ap("100.100.101.2"), }, // Node with allowed requested tag (should be excluded) { @@ -1833,7 +1930,7 @@ func TestResolvePolicy(t *testing.T) { { User: ptr.To(testuser), Tags: []string{"tag:test"}, - IPv4: ap("100.100.101.2"), + IPv4: ap("100.100.101.2"), }, // Node with allowed requested tag (should be included) { @@ -1871,7 +1968,7 @@ func TestResolvePolicy(t *testing.T) { { User: ptr.To(testuser), Tags: []string{"tag:test", "tag:other"}, - IPv4: ap("100.100.101.7"), + IPv4: ap("100.100.101.7"), }, }, pol: &Policy{ @@ -1900,7 +1997,7 @@ func TestResolvePolicy(t *testing.T) { { User: ptr.To(testuser), Tags: []string{"tag:test"}, - IPv4: ap("100.100.101.3"), + IPv4: ap("100.100.101.3"), }, { User: ptr.To(testuser2), @@ -1976,11 +2073,11 @@ func TestResolveAutoApprovers(t *testing.T) { User: &users[2], }, { - IPv4: ap("100.64.0.4"), + IPv4: ap("100.64.0.4"), Tags: []string{"tag:testtag"}, }, { - IPv4: ap("100.64.0.5"), + IPv4: ap("100.64.0.5"), Tags: []string{"tag:exittest"}, }, } @@ -2474,6 +2571,20 @@ func TestResolveTagOwners(t *testing.T) { }, wantErr: false, }, + { + name: "tag-owns-tag", + policy: &Policy{ + TagOwners: TagOwners{ + Tag("tag:bigbrother"): Owners{ptr.To(Username("user1@"))}, + Tag("tag:smallbrother"): Owners{ptr.To(Tag("tag:bigbrother"))}, + }, + }, + want: map[Tag]*netipx.IPSet{ + Tag("tag:bigbrother"): mustIPSet("100.64.0.1/32"), + Tag("tag:smallbrother"): mustIPSet("100.64.0.1/32"), + }, + wantErr: false, + }, } cmps := append(util.Comparers, cmp.Comparer(ipSetComparer)) @@ -2936,3 +3047,147 @@ func mustParseAlias(s string) Alias { } return alias } + +func TestFlattenTagOwners(t *testing.T) { + tests := []struct { + name string + input TagOwners + want TagOwners + wantErr string + }{ + { + name: "tag-owns-tag", + input: TagOwners{ + Tag("tag:bigbrother"): Owners{ptr.To(Group("group:user1"))}, + Tag("tag:smallbrother"): Owners{ptr.To(Tag("tag:bigbrother"))}, + }, + want: TagOwners{ + Tag("tag:bigbrother"): Owners{ptr.To(Group("group:user1"))}, + Tag("tag:smallbrother"): Owners{ptr.To(Group("group:user1"))}, + }, + wantErr: "", + }, + { + name: "circular-reference", + input: TagOwners{ + Tag("tag:a"): Owners{ptr.To(Tag("tag:b"))}, + Tag("tag:b"): Owners{ptr.To(Tag("tag:a"))}, + }, + want: nil, + wantErr: "circular reference detected: tag:a -> tag:b", + }, + { + name: "mixed-owners", + input: TagOwners{ + Tag("tag:x"): Owners{ptr.To(Username("user1@")), ptr.To(Tag("tag:y"))}, + Tag("tag:y"): Owners{ptr.To(Username("user2@"))}, + }, + want: TagOwners{ + Tag("tag:x"): Owners{ptr.To(Username("user1@")), ptr.To(Username("user2@"))}, + Tag("tag:y"): Owners{ptr.To(Username("user2@"))}, + }, + wantErr: "", + }, + { + name: "mixed-dupe-owners", + input: TagOwners{ + Tag("tag:x"): Owners{ptr.To(Username("user1@")), ptr.To(Tag("tag:y"))}, + Tag("tag:y"): Owners{ptr.To(Username("user1@"))}, + }, + want: TagOwners{ + Tag("tag:x"): Owners{ptr.To(Username("user1@"))}, + Tag("tag:y"): Owners{ptr.To(Username("user1@"))}, + }, + wantErr: "", + }, + { + name: "no-tag-owners", + input: TagOwners{ + Tag("tag:solo"): Owners{ptr.To(Username("user1@"))}, + }, + want: TagOwners{ + Tag("tag:solo"): Owners{ptr.To(Username("user1@"))}, + }, + wantErr: "", + }, + { + name: "tag-long-owner-chain", + input: TagOwners{ + Tag("tag:a"): Owners{ptr.To(Group("group:user1"))}, + Tag("tag:b"): Owners{ptr.To(Tag("tag:a"))}, + Tag("tag:c"): Owners{ptr.To(Tag("tag:b"))}, + Tag("tag:d"): Owners{ptr.To(Tag("tag:c"))}, + Tag("tag:e"): Owners{ptr.To(Tag("tag:d"))}, + Tag("tag:f"): Owners{ptr.To(Tag("tag:e"))}, + Tag("tag:g"): Owners{ptr.To(Tag("tag:f"))}, + }, + want: TagOwners{ + Tag("tag:a"): Owners{ptr.To(Group("group:user1"))}, + Tag("tag:b"): Owners{ptr.To(Group("group:user1"))}, + Tag("tag:c"): Owners{ptr.To(Group("group:user1"))}, + Tag("tag:d"): Owners{ptr.To(Group("group:user1"))}, + Tag("tag:e"): Owners{ptr.To(Group("group:user1"))}, + Tag("tag:f"): Owners{ptr.To(Group("group:user1"))}, + Tag("tag:g"): Owners{ptr.To(Group("group:user1"))}, + }, + wantErr: "", + }, + { + name: "tag-long-circular-chain", + input: TagOwners{ + Tag("tag:a"): Owners{ptr.To(Tag("tag:g"))}, + Tag("tag:b"): Owners{ptr.To(Tag("tag:a"))}, + Tag("tag:c"): Owners{ptr.To(Tag("tag:b"))}, + Tag("tag:d"): Owners{ptr.To(Tag("tag:c"))}, + Tag("tag:e"): Owners{ptr.To(Tag("tag:d"))}, + Tag("tag:f"): Owners{ptr.To(Tag("tag:e"))}, + Tag("tag:g"): Owners{ptr.To(Tag("tag:f"))}, + }, + wantErr: "circular reference detected: tag:a -> tag:b -> tag:c -> tag:d -> tag:e -> tag:f -> tag:g", + }, + { + name: "undefined-tag-reference", + input: TagOwners{ + Tag("tag:a"): Owners{ptr.To(Tag("tag:nonexistent"))}, + }, + wantErr: `tag "tag:a" references undefined tag "tag:nonexistent"`, + }, + { + name: "tag-with-empty-owners-is-valid", + input: TagOwners{ + Tag("tag:a"): Owners{ptr.To(Tag("tag:b"))}, + Tag("tag:b"): Owners{}, // empty owners but exists + }, + want: TagOwners{ + Tag("tag:a"): nil, + Tag("tag:b"): nil, + }, + wantErr: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := flattenTagOwners(tt.input) + if tt.wantErr != "" { + if err == nil { + t.Fatalf("flattenTagOwners() expected error %q, got nil", tt.wantErr) + } + + if err.Error() != tt.wantErr { + t.Fatalf("flattenTagOwners() expected error %q, got %q", tt.wantErr, err.Error()) + } + + return + } + + if err != nil { + t.Fatalf("flattenTagOwners() unexpected error: %v", err) + } + + if diff := cmp.Diff(tt.want, got); diff != "" { + t.Errorf("flattenTagOwners() mismatch (-want +got):\n%s", diff) + } + }) + } +}