From f01052c85f8026613aa89c7a1647b3678a567bff Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Wed, 4 Mar 2026 16:16:40 +0100 Subject: [PATCH] speculative new datastruct, fix ip range return Signed-off-by: Kristoffer Dalby --- hscontrol/policy/v2/filter.go | 63 ++++-- hscontrol/policy/v2/policy.go | 4 +- .../policy/v2/tailscale_grants_compat_test.go | 85 -------- hscontrol/policy/v2/types.go | 194 +++++++++++++++--- 4 files changed, 217 insertions(+), 129 deletions(-) diff --git a/hscontrol/policy/v2/filter.go b/hscontrol/policy/v2/filter.go index aa2d5355..3d366334 100644 --- a/hscontrol/policy/v2/filter.go +++ b/hscontrol/policy/v2/filter.go @@ -45,7 +45,7 @@ func (pol *Policy) compileFilterRules( log.Trace().Caller().Err(err).Msgf("resolving source ips") } - if srcIPs == nil || len(srcIPs.Prefixes()) == 0 { + if srcIPs.Empty() { continue } @@ -54,7 +54,7 @@ func (pol *Policy) compileFilterRules( if len(destPorts) > 0 { rules = append(rules, tailcfg.FilterRule{ - SrcIPs: ipSetToPrefixStringList(srcIPs), + SrcIPs: srcIPs.Strings(), DstPorts: destPorts, IPProto: ipp.Protocol.toIANAProtocolNumbers(), }) @@ -77,7 +77,7 @@ func (pol *Policy) compileFilterRules( } rules = append(rules, tailcfg.FilterRule{ - SrcIPs: ipSetToPrefixStringList(srcIPs), + SrcIPs: srcIPs.Strings(), CapGrant: capGrants, }) } @@ -131,6 +131,10 @@ func (pol *Policy) destinationsToNetPortRange( IP: pref.String(), Ports: port, } + // Drop the prefix bits if its a single IP. + if pref.IsSingleIP() { + pr.IP = pref.Addr().String() + } ret = append(ret, pr) } } @@ -197,7 +201,7 @@ func (pol *Policy) compileGrantWithAutogroupSelf( var rules []tailcfg.FilterRule - var resolvedSrcIPs []*netipx.IPSet + var resolvedSrcs []ResolvedAddresses for _, src := range grant.Sources { if ag, ok := src.(*AutoGroup); ok && ag.Is(AutoGroupSelf) { @@ -210,11 +214,11 @@ func (pol *Policy) compileGrantWithAutogroupSelf( } if ips != nil { - resolvedSrcIPs = append(resolvedSrcIPs, ips) + resolvedSrcs = append(resolvedSrcs, ips) } } - if len(resolvedSrcIPs) == 0 { + if len(resolvedSrcs) == 0 { return rules, nil } @@ -235,7 +239,7 @@ func (pol *Policy) compileGrantWithAutogroupSelf( // Filter sources to only same-user untagged devices var srcIPs netipx.IPSetBuilder - for _, ips := range resolvedSrcIPs { + for _, ips := range resolvedSrcs { for _, n := range sameUserNodes { // Check if any of this node's IPs are in the source set if slices.ContainsFunc(n.IPs(), ips.Contains) { @@ -244,12 +248,12 @@ func (pol *Policy) compileGrantWithAutogroupSelf( } } - srcSet, err := srcIPs.IPSet() + srcResolved, err := newResolved(&srcIPs) if err != nil { return nil, err } - if srcSet != nil && len(srcSet.Prefixes()) > 0 { + if !srcResolved.Empty() { var destPorts []tailcfg.NetPortRange for _, n := range sameUserNodes { @@ -265,7 +269,7 @@ func (pol *Policy) compileGrantWithAutogroupSelf( if len(destPorts) > 0 { rules = append(rules, tailcfg.FilterRule{ - SrcIPs: ipSetToPrefixStringList(srcSet), + SrcIPs: srcResolved.Strings(), DstPorts: destPorts, IPProto: ipp.Protocol.toIANAProtocolNumbers(), }) @@ -277,21 +281,23 @@ func (pol *Policy) compileGrantWithAutogroupSelf( if len(otherDests) > 0 { var srcIPs netipx.IPSetBuilder - for _, ips := range resolvedSrcIPs { - srcIPs.AddSet(ips) + for _, ips := range resolvedSrcs { + for _, pref := range ips.Prefixes() { + srcIPs.AddPrefix(pref) + } } - srcSet, err := srcIPs.IPSet() + srcResolved, err := newResolved(&srcIPs) if err != nil { return nil, err } - if srcSet != nil && len(srcSet.Prefixes()) > 0 { + if !srcResolved.Empty() { destPorts := pol.destinationsToNetPortRange(users, nodes, otherDests, ipp.Ports) if len(destPorts) > 0 { rules = append(rules, tailcfg.FilterRule{ - SrcIPs: ipSetToPrefixStringList(srcSet), + SrcIPs: srcResolved.Strings(), DstPorts: destPorts, IPProto: ipp.Protocol.toIANAProtocolNumbers(), }) @@ -474,7 +480,9 @@ func (pol *Policy) compileSSHPolicy( } if ips != nil { - dest.AddSet(ips) + for _, pref := range ips.Prefixes() { + dest.AddPrefix(pref) + } } } @@ -497,7 +505,7 @@ func (pol *Policy) compileSSHPolicy( appendRules(taggedPrincipals, 0, false) } } else { - if principals := ipSetToPrincipals(srcIPs); len(principals) > 0 { + if principals := resolvedAddrsToPrincipals(srcIPs); len(principals) > 0 { rules = append(rules, &tailcfg.SSHRule{ Principals: principals, SSHUsers: baseUserMap, @@ -506,7 +514,7 @@ func (pol *Policy) compileSSHPolicy( }) } } - } else if hasLocalpart && node.InIPSet(srcIPs) { + } else if hasLocalpart && slices.ContainsFunc(node.IPs(), srcIPs.Contains) { // Self-access: source node not in destination set // receives rules scoped to its own user. if node.IsTagged() { @@ -552,6 +560,23 @@ func (pol *Policy) compileSSHPolicy( }, nil } +// resolvedAddrsToPrincipals converts ResolvedAddresses into SSH principals, one per address. +func resolvedAddrsToPrincipals(addrs ResolvedAddresses) []*tailcfg.SSHPrincipal { + if addrs == nil { + return nil + } + + var principals []*tailcfg.SSHPrincipal + + for addr := range addrs.Iter() { + principals = append(principals, &tailcfg.SSHPrincipal{ + NodeIP: addr.String(), + }) + } + + return principals +} + // ipSetToPrincipals converts an IPSet into SSH principals, one per address. func ipSetToPrincipals(ipSet *netipx.IPSet) []*tailcfg.SSHPrincipal { if ipSet == nil { @@ -619,7 +644,7 @@ func resolveLocalparts( // Only includes nodes whose IPs are in the srcIPs set. func groupSourcesByUser( nodes views.Slice[types.NodeView], - srcIPs *netipx.IPSet, + srcIPs ResolvedAddresses, ) ([]uint, map[uint][]*tailcfg.SSHPrincipal, []*tailcfg.SSHPrincipal) { userIPSets := make(map[uint]*netipx.IPSetBuilder) diff --git a/hscontrol/policy/v2/policy.go b/hscontrol/policy/v2/policy.go index 77de20eb..5852c410 100644 --- a/hscontrol/policy/v2/policy.go +++ b/hscontrol/policy/v2/policy.go @@ -1198,7 +1198,9 @@ func resolveTagOwners(p *Policy, users types.Users, nodes views.Slice[types.Node case Alias: // If it does not resolve, that means the tag is not associated with any IP addresses. resolved, _ := o.Resolve(p, users, nodes) - ips.AddSet(resolved) + for _, pref := range resolved.Prefixes() { + ips.AddPrefix(pref) + } default: // Should never happen - after flattening, all owners should be Alias types diff --git a/hscontrol/policy/v2/tailscale_grants_compat_test.go b/hscontrol/policy/v2/tailscale_grants_compat_test.go index a25f4df6..6930f660 100644 --- a/hscontrol/policy/v2/tailscale_grants_compat_test.go +++ b/hscontrol/policy/v2/tailscale_grants_compat_test.go @@ -238,14 +238,6 @@ var grantSkipReasons = map[string]string{ // tests in this category. // ======================================================================== - // J-series: Protocol-specific IP grants - "GRANT-J1": "SRCIPS_FORMAT", - "GRANT-J2": "SRCIPS_FORMAT", - "GRANT-J3": "SRCIPS_FORMAT", - "GRANT-J4": "SRCIPS_FORMAT", - "GRANT-J5": "SRCIPS_FORMAT", - "GRANT-J6": "SRCIPS_FORMAT", - // K-series: Various IP grant patterns "GRANT-K1": "SRCIPS_FORMAT", "GRANT-K2": "SRCIPS_FORMAT", @@ -265,122 +257,45 @@ var grantSkipReasons = map[string]string{ "GRANT-P01_3": "SRCIPS_FORMAT", "GRANT-P01_4": "SRCIPS_FORMAT", - // P02-series: Source targeting (user, group, tag) - "GRANT-P02_1": "SRCIPS_FORMAT", - "GRANT-P02_2": "SRCIPS_FORMAT", - "GRANT-P02_3": "SRCIPS_FORMAT", - "GRANT-P02_4": "SRCIPS_FORMAT", - "GRANT-P02_5_CORRECT": "SRCIPS_FORMAT", - "GRANT-P02_5_NAIVE": "SRCIPS_FORMAT", - - // P03-series: Destination targeting - "GRANT-P03_1": "SRCIPS_FORMAT", - "GRANT-P03_2": "SRCIPS_FORMAT", - "GRANT-P03_3": "SRCIPS_FORMAT", - "GRANT-P03_4": "SRCIPS_FORMAT", - - // P04-series: autogroup:member grants - "GRANT-P04_1": "SRCIPS_FORMAT", - "GRANT-P04_2": "SRCIPS_FORMAT", - // P05-series: Tag-to-tag grants "GRANT-P05_1": "SRCIPS_FORMAT", "GRANT-P05_2": "SRCIPS_FORMAT", "GRANT-P05_3": "SRCIPS_FORMAT", - // P06-series: IP protocol grants - "GRANT-P06_1": "SRCIPS_FORMAT", - "GRANT-P06_2": "SRCIPS_FORMAT", - "GRANT-P06_3": "SRCIPS_FORMAT", - "GRANT-P06_4": "SRCIPS_FORMAT", - "GRANT-P06_5": "SRCIPS_FORMAT", - "GRANT-P06_6": "SRCIPS_FORMAT", - "GRANT-P06_7": "SRCIPS_FORMAT", - // P08-series: Multiple grants / rule merging - "GRANT-P08_1": "SRCIPS_FORMAT", - "GRANT-P08_2": "SRCIPS_FORMAT", - "GRANT-P08_4": "SRCIPS_FORMAT", - "GRANT-P08_5": "SRCIPS_FORMAT", - "GRANT-P08_6": "SRCIPS_FORMAT", - "GRANT-P08_7": "SRCIPS_FORMAT", "GRANT-P08_8": "SRCIPS_FORMAT", // P09-series: ACL-to-grant conversion equivalence tests - "GRANT-P09_1A": "SRCIPS_FORMAT", - "GRANT-P09_1B": "SRCIPS_FORMAT", - "GRANT-P09_1C": "SRCIPS_FORMAT", - "GRANT-P09_1D": "SRCIPS_FORMAT", "GRANT-P09_1E": "SRCIPS_FORMAT", - "GRANT-P09_2A_CORRECT": "SRCIPS_FORMAT", - "GRANT-P09_2A_NAIVE": "SRCIPS_FORMAT", "GRANT-P09_2B_CORRECT": "SRCIPS_FORMAT", "GRANT-P09_2B_NAIVE": "SRCIPS_FORMAT", "GRANT-P09_2C": "SRCIPS_FORMAT", - "GRANT-P09_3A": "SRCIPS_FORMAT", - "GRANT-P09_3B": "SRCIPS_FORMAT", "GRANT-P09_3C": "SRCIPS_FORMAT", - "GRANT-P09_4A": "SRCIPS_FORMAT", - "GRANT-P09_4B": "SRCIPS_FORMAT", "GRANT-P09_4C": "SRCIPS_FORMAT", "GRANT-P09_4D": "SRCIPS_FORMAT", "GRANT-P09_4E": "SRCIPS_FORMAT", "GRANT-P09_4F": "SRCIPS_FORMAT", "GRANT-P09_4G": "SRCIPS_FORMAT", - "GRANT-P09_5A": "SRCIPS_FORMAT", - "GRANT-P09_5B": "SRCIPS_FORMAT", - "GRANT-P09_5C_NAIVE": "SRCIPS_FORMAT", "GRANT-P09_6A": "SRCIPS_FORMAT", - "GRANT-P09_6C": "SRCIPS_FORMAT", "GRANT-P09_6D": "SRCIPS_FORMAT", "GRANT-P09_7A": "SRCIPS_FORMAT", "GRANT-P09_7B_NAIVE": "SRCIPS_FORMAT", - "GRANT-P09_7C": "SRCIPS_FORMAT", "GRANT-P09_7D_NAIVE": "SRCIPS_FORMAT", - "GRANT-P09_8A": "SRCIPS_FORMAT", - "GRANT-P09_8B": "SRCIPS_FORMAT", "GRANT-P09_8C": "SRCIPS_FORMAT", - "GRANT-P09_9A": "SRCIPS_FORMAT", - "GRANT-P09_9B": "SRCIPS_FORMAT", - "GRANT-P09_9C": "SRCIPS_FORMAT", - "GRANT-P09_10A": "SRCIPS_FORMAT", - "GRANT-P09_10B": "SRCIPS_FORMAT", - "GRANT-P09_10C": "SRCIPS_FORMAT", - "GRANT-P09_10D": "SRCIPS_FORMAT", - "GRANT-P09_11A": "SRCIPS_FORMAT", "GRANT-P09_11B": "SRCIPS_FORMAT", - "GRANT-P09_11C_NAIVE": "SRCIPS_FORMAT", - "GRANT-P09_11D": "SRCIPS_FORMAT", - "GRANT-P09_12A": "SRCIPS_FORMAT", "GRANT-P09_12B": "SRCIPS_FORMAT", "GRANT-P09_13E": "SRCIPS_FORMAT", "GRANT-P09_13F": "SRCIPS_FORMAT", "GRANT-P09_13G": "SRCIPS_FORMAT", - "GRANT-P09_14A": "SRCIPS_FORMAT", - "GRANT-P09_14B": "SRCIPS_FORMAT", - "GRANT-P09_14C": "SRCIPS_FORMAT", - "GRANT-P09_14D": "SRCIPS_FORMAT", - "GRANT-P09_14E": "SRCIPS_FORMAT", - "GRANT-P09_14F": "SRCIPS_FORMAT", - "GRANT-P09_14G": "SRCIPS_FORMAT", - "GRANT-P09_14H": "SRCIPS_FORMAT", - "GRANT-P09_14I": "SRCIPS_FORMAT", // P10-series: Host alias grants - "GRANT-P10_1": "SRCIPS_FORMAT", - "GRANT-P10_2": "SRCIPS_FORMAT", "GRANT-P10_3": "SRCIPS_FORMAT", "GRANT-P10_4": "SRCIPS_FORMAT", - // P11-series: autogroup:tagged grants - "GRANT-P11_1": "SRCIPS_FORMAT", - "GRANT-P11_2": "SRCIPS_FORMAT", - // P13-series: CIDR destination grants "GRANT-P13_1": "SRCIPS_FORMAT", "GRANT-P13_2": "SRCIPS_FORMAT", "GRANT-P13_3": "SRCIPS_FORMAT", - "GRANT-P13_4": "SRCIPS_FORMAT", // P15-series: Empty/no-match grants "GRANT-P15_1": "SRCIPS_FORMAT", diff --git a/hscontrol/policy/v2/types.go b/hscontrol/policy/v2/types.go index 090546f0..9375785f 100644 --- a/hscontrol/policy/v2/types.go +++ b/hscontrol/policy/v2/types.go @@ -3,10 +3,12 @@ package v2 import ( "errors" "fmt" + "iter" "net/netip" "slices" "strconv" "strings" + "sync" "time" "github.com/go-json-experiment/json" @@ -124,6 +126,89 @@ var ( ErrProtocolNoSpecificPorts = errors.New("protocol does not support specific ports") ) +type resolved struct { + ips netipx.IPSet +} + +func newResolved(ipb *netipx.IPSetBuilder) (resolved, error) { + ips, err := ipb.IPSet() + if err != nil { + return resolved{}, err + } + return resolved{ips: *ips}, nil +} + +func newResolvedAddresses(ips *netipx.IPSet, err error) (ResolvedAddresses, error) { + if err != nil { + return nil, err + } + if ips == nil { + return nil, nil + } + return resolved{ips: *ips}, nil +} + +func ipSetToStrings(ips *netipx.IPSet) []string { + var result []string + + for _, r := range ips.Ranges() { + if r.From() == r.To() { + result = append(result, r.From().String()) + continue + } + + if p, ok := r.Prefix(); ok { + result = append(result, p.String()) + continue + } + + result = append(result, r.String()) + } + + return result +} + +func (res resolved) Strings() []string { + return ipSetToStrings(&res.ips) +} + +func (res resolved) Prefixes() []netip.Prefix { + ret := res.ips.Prefixes() + + return ret +} + +func (res resolved) Empty() bool { + return len(res.ips.Prefixes()) == 0 +} + +func (res resolved) Iter() iter.Seq[netip.Addr] { + return util.IPSetAddrIter(&res.ips) +} + +func (res resolved) Contains(ip netip.Addr) bool { + return res.ips.Contains(ip) +} + +type ResolvedAddresses interface { + // Strings returns a slice of string representations of IP addresses, + // it will return the appropriate representation for the underlying Alias. + // Some should be returned as Prefixes and some as IP ranges. + Strings() []string + + // Prefixes returns a slice of netip.Prefix representations of IP addresses. + Prefixes() []netip.Prefix + + // Empty reports if there are no addresses in the ResolvedAddresses. + Empty() bool + + // Iter returns an iterator over netip.Addr representations of IP addresses. + Iter() iter.Seq[netip.Addr] + + // Contains reports if the given IP address is contained in the ResolvedAddresses. + Contains(ip netip.Addr) bool +} + type Asterix int func (a Asterix) Validate() error { @@ -194,16 +279,36 @@ func (a Asterix) UnmarshalJSON(b []byte) error { return nil } -func (a Asterix) Resolve(_ *Policy, _ types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { - var ips netipx.IPSetBuilder +var asterixResolved = sync.OnceValue(func() *netipx.IPSet { + var ipb netipx.IPSetBuilder + ipb.AddPrefix(tsaddr.TailscaleULARange()) + ipb.AddPrefix(tsaddr.CGNATRange()) + ipb.RemovePrefix(tsaddr.ChromeOSVMRange()) - // Use Tailscale's CGNAT range for IPv4 and ULA range for IPv6. - // This matches Tailscale's behavior where wildcard (*) refers to - // "any node in the tailnet" which uses these address ranges. - ips.AddPrefix(tsaddr.CGNATRange()) - ips.AddPrefix(tsaddr.TailscaleULARange()) + ips, err := ipb.IPSet() + if err != nil { + panic(fmt.Sprintf("failed to build IPSet for wildcard: %v", err)) + } - return ips.IPSet() + return ips +}) + +func (a Asterix) Resolve(p *Policy, u types.Users, n views.Slice[types.NodeView]) (ResolvedAddresses, error) { + return newResolvedAddresses(a.resolve(p, u, n)) +} + +func (a Asterix) resolve(p *Policy, _ types.Users, _ views.Slice[types.NodeView]) (*netipx.IPSet, error) { + if pfxs := p.AutoApprovers.prefixes(); len(pfxs) > 0 { + var ipb netipx.IPSetBuilder + ipb.AddSet(asterixResolved()) + for _, pfx := range p.AutoApprovers.prefixes() { + ipb.AddPrefix(pfx) + } + + return ipb.IPSet() + } + + return asterixResolved(), nil } // Username is a string that represents a username, it must contain an @. @@ -286,7 +391,11 @@ func (u *Username) resolveUser(users types.Users) (types.User, error) { return potentialUsers[0], nil } -func (u *Username) Resolve(_ *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { +func (u *Username) Resolve(_ *Policy, users types.Users, nodes views.Slice[types.NodeView]) (ResolvedAddresses, error) { + return newResolvedAddresses(u.resolve(nil, users, nodes)) +} + +func (u *Username) resolve(_ *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { var ( ips netipx.IPSetBuilder errs []error @@ -365,14 +474,18 @@ func (g *Group) MarshalJSON() ([]byte, error) { return json.Marshal(string(*g)) } -func (g *Group) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { +func (g *Group) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (ResolvedAddresses, error) { + return newResolvedAddresses(g.resolve(p, users, nodes)) +} + +func (g *Group) resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { var ( ips netipx.IPSetBuilder errs []error ) for _, user := range p.Groups[*g] { - uips, err := user.Resolve(nil, users, nodes) + uips, err := user.resolve(nil, users, nodes) if err != nil { errs = append(errs, err) } @@ -405,7 +518,11 @@ func (t *Tag) UnmarshalJSON(b []byte) error { return nil } -func (t *Tag) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { +func (t *Tag) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (ResolvedAddresses, error) { + return newResolvedAddresses(t.resolve(p, users, nodes)) +} + +func (t *Tag) resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { var ips netipx.IPSetBuilder for _, node := range nodes.All() { @@ -457,7 +574,11 @@ func (h *Host) UnmarshalJSON(b []byte) error { return nil } -func (h *Host) Resolve(p *Policy, _ types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { +func (h *Host) Resolve(p *Policy, _ types.Users, nodes views.Slice[types.NodeView]) (ResolvedAddresses, error) { + return newResolvedAddresses(h.resolve(p, nil, nodes)) +} + +func (h *Host) resolve(p *Policy, _ types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { var ( ips netipx.IPSetBuilder errs []error @@ -554,7 +675,11 @@ func (p *Prefix) UnmarshalJSON(b []byte) error { // of the Prefix and the Policy, Users, and Nodes. // // See [Policy], [types.Users], and [types.Nodes] for more details. -func (p *Prefix) Resolve(_ *Policy, _ types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { +func (p *Prefix) Resolve(_ *Policy, _ types.Users, nodes views.Slice[types.NodeView]) (ResolvedAddresses, error) { + return newResolvedAddresses(p.resolve(nil, nil, nodes)) +} + +func (p *Prefix) resolve(_ *Policy, _ types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { var ( ips netipx.IPSetBuilder errs []error @@ -629,7 +754,11 @@ func (ag *AutoGroup) MarshalJSON() ([]byte, error) { return json.Marshal(string(*ag)) } -func (ag *AutoGroup) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { +func (ag *AutoGroup) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (ResolvedAddresses, error) { + return newResolvedAddresses(ag.resolve(p, users, nodes)) +} + +func (ag *AutoGroup) resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { var build netipx.IPSetBuilder switch *ag { @@ -694,7 +823,9 @@ type Alias interface { // of the Alias and the Policy, Users and Nodes. // This is an interface definition and the implementation is independent of // the Alias type. - Resolve(pol *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) + Resolve(pol *Policy, users types.Users, nodes views.Slice[types.NodeView]) (ResolvedAddresses, error) + + resolve(pol *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) } type AliasWithPorts struct { @@ -960,14 +1091,14 @@ func (a *Aliases) MarshalJSON() ([]byte, error) { return json.Marshal(aliases) } -func (a *Aliases) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { +func (a *Aliases) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (ResolvedAddresses, error) { var ( ips netipx.IPSetBuilder errs []error ) for _, alias := range *a { - aips, err := alias.Resolve(p, users, nodes) + aips, err := alias.resolve(p, users, nodes) if err != nil { errs = append(errs, err) } @@ -975,7 +1106,7 @@ func (a *Aliases) Resolve(p *Policy, users types.Users, nodes views.Slice[types. ips.AddSet(aips) } - return buildIPSetMultiErr(&ips, errs) + return newResolvedAddresses(buildIPSetMultiErr(&ips, errs)) } func buildIPSetMultiErr(ipBuilder *netipx.IPSetBuilder, errs []error) (*netipx.IPSet, error) { @@ -1378,6 +1509,21 @@ func (ap AutoApproverPolicy) MarshalJSON() ([]byte, error) { return json.Marshal(&obj) } +// prefixes returns the prefixes that have auto-approvers defined in the policy. +// It filters out exit routes since they are not associated with a specific prefix and are handled separately. +func (ap AutoApproverPolicy) prefixes() []netip.Prefix { + prefixes := make([]netip.Prefix, 0, len(ap.Routes)) + + for prefix := range ap.Routes { + if tsaddr.IsExitRoute(prefix) { + continue + } + prefixes = append(prefixes, prefix) + } + + return prefixes +} + // resolveAutoApprovers resolves the AutoApprovers to a map of netip.Prefix to netipx.IPSet. // The resulting map can be used to quickly look up if a node can self-approve a route. // It is intended for internal use in a PolicyManager. @@ -1402,7 +1548,7 @@ func resolveAutoApprovers(p *Policy, users types.Users, nodes views.Slice[types. return nil, nil, fmt.Errorf("%w: %v", ErrAutoApproverNotAlias, autoApprover) } // If it does not resolve, that means the autoApprover is not associated with any IP addresses. - ips, _ := aa.Resolve(p, users, nodes) + ips, _ := aa.resolve(p, users, nodes) routes[prefix].AddSet(ips) } } @@ -1417,7 +1563,7 @@ func resolveAutoApprovers(p *Policy, users types.Users, nodes views.Slice[types. return nil, nil, fmt.Errorf("%w: %v", ErrAutoApproverNotAlias, autoApprover) } // If it does not resolve, that means the autoApprover is not associated with any IP addresses. - ips, _ := aa.Resolve(p, users, nodes) + ips, _ := aa.resolve(p, users, nodes) exitNodeSetBuilder.AddSet(ips) } } @@ -2601,14 +2747,14 @@ func (a *SSHSrcAliases) MarshalJSON() ([]byte, error) { return json.Marshal(aliases) } -func (a *SSHSrcAliases) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { +func (a *SSHSrcAliases) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (ResolvedAddresses, error) { var ( ips netipx.IPSetBuilder errs []error ) for _, alias := range *a { - aips, err := alias.Resolve(p, users, nodes) + aips, err := alias.resolve(p, users, nodes) if err != nil { errs = append(errs, err) } @@ -2616,7 +2762,7 @@ func (a *SSHSrcAliases) Resolve(p *Policy, users types.Users, nodes views.Slice[ ips.AddSet(aips) } - return buildIPSetMultiErr(&ips, errs) + return newResolvedAddresses(buildIPSetMultiErr(&ips, errs)) } // SSHDstAliases is a list of aliases that can be used as destinations in an SSH rule.