diff --git a/hscontrol/policy/v2/filter_test.go b/hscontrol/policy/v2/filter_test.go index 301fb65a..6805d9da 100644 --- a/hscontrol/policy/v2/filter_test.go +++ b/hscontrol/policy/v2/filter_test.go @@ -9,6 +9,7 @@ import ( "time" "github.com/google/go-cmp/cmp" + "github.com/juanfont/headscale/hscontrol/policy/policyutil" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/stretchr/testify/assert" @@ -1495,12 +1496,16 @@ func TestTagUserMutualExclusivity(t *testing.T) { require.NoError(t, err) // User1's user-owned node should have no rules reaching tagged nodes - // since there is no explicit user→tag ACL rule. + // since there is no explicit user→tag ACL rule. ReduceFilterRules + // filters compiled rules to only those where the node is a destination, + // matching the production pipeline in filterForNodeLocked. userNode := nodes[0].View() - userRules, err := pol.compileFilterRulesForNode(users, userNode, nodes.ViewSlice()) + compiled, err := pol.compileFilterRulesForNode(users, userNode, nodes.ViewSlice()) require.NoError(t, err) + userRules := policyutil.ReduceFilterRules(userNode, compiled) + for _, rule := range userRules { for _, dst := range rule.DstPorts { ipSet, parseErr := util.ParseIPSet(dst.IP, nil) @@ -1516,27 +1521,29 @@ func TestTagUserMutualExclusivity(t *testing.T) { } } - // Tag:server should be able to reach tag:database via the tag-to-tag rule. - taggedNode := nodes[2].View() + // Tag:database should receive the tag:server → tag:database rule after reduction. + dbNode := nodes[3].View() - taggedRules, err := pol.compileFilterRulesForNode(users, taggedNode, nodes.ViewSlice()) + compiled, err = pol.compileFilterRulesForNode(users, dbNode, nodes.ViewSlice()) require.NoError(t, err) - foundDatabaseDest := false + dbRules := policyutil.ReduceFilterRules(dbNode, compiled) - for _, rule := range taggedRules { - for _, dst := range rule.DstPorts { - ipSet, parseErr := util.ParseIPSet(dst.IP, nil) + foundServerSrc := false + + for _, rule := range dbRules { + for _, srcEntry := range rule.SrcIPs { + ipSet, parseErr := util.ParseIPSet(srcEntry, nil) require.NoError(t, parseErr) - if ipSet.Contains(netip.MustParseAddr("100.64.0.11")) { - foundDatabaseDest = true + if ipSet.Contains(netip.MustParseAddr("100.64.0.10")) { + foundServerSrc = true break } } } - assert.True(t, foundDatabaseDest, "tag:server should reach tag:database") + assert.True(t, foundServerSrc, "tag:database should accept traffic from tag:server") } // TestUserToTagCrossIdentityGrant tests that an explicit ACL rule granting @@ -1584,13 +1591,16 @@ func TestUserToTagCrossIdentityGrant(t *testing.T) { err := pol.validate() require.NoError(t, err) - // Compile rules for the tag:server node — it is the destination, - // so the filter should include user1's IP as source. + // Compile and reduce rules for the tag:server node — it is the + // destination, so after ReduceFilterRules, the filter should include + // user1's IP as source. taggedNode := nodes[2].View() - rules, err := pol.compileFilterRulesForNode(users, taggedNode, nodes.ViewSlice()) + compiled, err := pol.compileFilterRulesForNode(users, taggedNode, nodes.ViewSlice()) require.NoError(t, err) + rules := policyutil.ReduceFilterRules(taggedNode, compiled) + // user1's IP should appear as a source that can reach tag:server. foundUser1Src := false @@ -2320,23 +2330,12 @@ func TestAutogroupSelfWithNonExistentUserInGroup(t *testing.T) { for _, rule := range rules { for _, dp := range rule.DstPorts { - // DstPort IPs may be bare addresses or CIDR prefixes - pref, err := netip.ParsePrefix(dp.IP) + ipSet, err := util.ParseIPSet(dp.IP, nil) if err != nil { - // Try as bare address - a, err2 := netip.ParseAddr(dp.IP) - if err2 != nil { - continue - } - - if a == addr { - return true - } - continue } - if pref.Contains(addr) { + if ipSet.Contains(addr) { return true } } @@ -2350,21 +2349,12 @@ func TestAutogroupSelfWithNonExistentUserInGroup(t *testing.T) { for _, rule := range rules { for _, srcIP := range rule.SrcIPs { - pref, err := netip.ParsePrefix(srcIP) + ipSet, err := util.ParseIPSet(srcIP, nil) if err != nil { - a, err2 := netip.ParseAddr(srcIP) - if err2 != nil { - continue - } - - if a == addr { - return true - } - continue } - if pref.Contains(addr) { + if ipSet.Contains(addr) { return true } }