diff --git a/hscontrol/noise.go b/hscontrol/noise.go index b5d41b5b..cd2f2036 100644 --- a/hscontrol/noise.go +++ b/hscontrol/noise.go @@ -8,6 +8,7 @@ import ( "io" "net/http" "net/url" + "time" "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" @@ -331,6 +332,7 @@ func (ns *noiseServer) SSHActionHandler( action, err := ns.sshAction( reqLog, + srcNodeID, dstNodeID, req.URL.Query().Get("auth_id"), ) if err != nil { @@ -356,16 +358,18 @@ func (ns *noiseServer) SSHActionHandler( } // sshAction resolves the SSH action for the given request parameters. -// It returns the action to send to the client, or an HTTPError on -// failure. +// It returns the action to send to the client, or an HTTPError on failure. // -// Two cases: -// 1. Initial request — build a HoldAndDelegate URL and wait for the -// user to authenticate. -// 2. Follow-up request — an auth_id is present, wait for the auth +// Three cases: +// 1. Initial request, auto-approved — source recently authenticated +// within the check period, accept immediately. +// 2. Initial request, needs auth — build a HoldAndDelegate URL and +// wait for the user to authenticate. +// 3. Follow-up request — an auth_id is present, wait for the auth // verdict and accept or reject. func (ns *noiseServer) sshAction( reqLog zerolog.Logger, + srcNodeID, dstNodeID types.NodeID, authIDStr string, ) (*tailcfg.SSHAction, error) { action := tailcfg.SSHAction{ @@ -374,14 +378,38 @@ func (ns *noiseServer) sshAction( AllowRemotePortForwarding: true, } + // Look up check params from the server's own policy rather than + // trusting URL parameters, which the client could tamper with. + checkPeriod, checkFound := ns.headscale.state.SSHCheckParams( + srcNodeID, dstNodeID, + ) + // Follow-up request with auth_id — wait for the auth verdict. if authIDStr != "" { return ns.sshActionFollowUp( reqLog, &action, authIDStr, + srcNodeID, dstNodeID, + checkFound, ) } - // Initial request — create an auth session and hold. 
+ // Initial request — check if auto-approval applies. + if checkFound && checkPeriod > 0 { + if lastAuth, ok := ns.headscale.state.GetLastSSHAuth( + srcNodeID, dstNodeID, + ); ok && time.Since(lastAuth) < checkPeriod { + reqLog.Trace().Caller(). + Dur("check_period", checkPeriod). + Time("last_auth", lastAuth). + Msg("auto-approved within check period") + + action.Accept = true + + return &action, nil + } + } + + // No auto-approval — create an auth session and hold. return ns.sshActionHoldAndDelegate(reqLog, &action) } @@ -445,6 +473,8 @@ func (ns *noiseServer) sshActionFollowUp( reqLog zerolog.Logger, action *tailcfg.SSHAction, authIDStr string, + srcNodeID, dstNodeID types.NodeID, + checkFound bool, ) (*tailcfg.SSHAction, error) { authID, err := types.AuthIDFromString(authIDStr) if err != nil { @@ -481,6 +511,14 @@ func (ns *noiseServer) sshActionFollowUp( action.Accept = true + // Record the successful auth for future auto-approval. + if checkFound { + ns.headscale.state.SetLastSSHAuth(srcNodeID, dstNodeID) + + reqLog.Trace().Caller(). + Msg("auth recorded for auto-approval") + } + return action, nil } diff --git a/hscontrol/policy/pm.go b/hscontrol/policy/pm.go index 2de2e8dd..0c69160f 100644 --- a/hscontrol/policy/pm.go +++ b/hscontrol/policy/pm.go @@ -2,6 +2,7 @@ package policy import ( "net/netip" + "time" "github.com/juanfont/headscale/hscontrol/policy/matcher" policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2" @@ -20,6 +21,9 @@ type PolicyManager interface { // BuildPeerMap constructs peer relationship maps for the given nodes BuildPeerMap(nodes views.Slice[types.NodeView]) map[types.NodeID][]types.NodeView SSHPolicy(baseURL string, node types.NodeView) (*tailcfg.SSHPolicy, error) + // SSHCheckParams resolves the SSH check period for a (src, dst) pair + // from the current policy, avoiding trust of client-provided URL params. 
+ SSHCheckParams(srcNodeID, dstNodeID types.NodeID) (time.Duration, bool) SetPolicy(pol []byte) (bool, error) SetUsers(users []types.User) (bool, error) SetNodes(nodes views.Slice[types.NodeView]) (bool, error) diff --git a/hscontrol/policy/v2/filter.go b/hscontrol/policy/v2/filter.go index c8c515cd..9df62525 100644 --- a/hscontrol/policy/v2/filter.go +++ b/hscontrol/policy/v2/filter.go @@ -327,7 +327,23 @@ var sshAccept = tailcfg.SSHAction{ AllowRemotePortForwarding: true, } +// checkPeriodFromRule extracts the check period duration from an SSH rule. +// Returns SSHCheckPeriodDefault if no checkPeriod is configured, +// 0 if checkPeriod is "always", or the configured duration otherwise. +func checkPeriodFromRule(rule SSH) time.Duration { + switch { + case rule.CheckPeriod == nil: + return SSHCheckPeriodDefault + case rule.CheckPeriod.Always: + return 0 + default: + return rule.CheckPeriod.Duration + } +} + func sshCheck(baseURL string, duration time.Duration) tailcfg.SSHAction { + holdURL := baseURL + "/machine/ssh/action/from/$SRC_NODE_ID/to/$DST_NODE_ID?ssh_user=$SSH_USER&local_user=$LOCAL_USER" + return tailcfg.SSHAction{ Reject: false, Accept: false, @@ -339,7 +355,7 @@ func sshCheck(baseURL string, duration time.Duration) tailcfg.SSHAction { // * $DST_NODE_ID (Node.ID as int64 string) // * $SSH_USER (URL escaped, ssh user requested) // * $LOCAL_USER (URL escaped, local user mapped) - HoldAndDelegate: baseURL + "/machine/ssh/action/from/$SRC_NODE_ID/to/$DST_NODE_ID?ssh_user=$SSH_USER&local_user=$LOCAL_USER", + HoldAndDelegate: holdURL, AllowAgentForwarding: true, AllowLocalPortForwarding: true, AllowRemotePortForwarding: true, @@ -396,7 +412,7 @@ func (pol *Policy) compileSSHPolicy( case SSHActionAccept: action = sshAccept case SSHActionCheck: - action = sshCheck(baseURL, time.Duration(rule.CheckPeriod)) + action = sshCheck(baseURL, checkPeriodFromRule(rule)) default: return nil, fmt.Errorf("parsing SSH policy, unknown action %q, index: %d: %w", rule.Action, 
index, err) } diff --git a/hscontrol/policy/v2/filter_test.go b/hscontrol/policy/v2/filter_test.go index 01d3d71d..da76e0f8 100644 --- a/hscontrol/policy/v2/filter_test.go +++ b/hscontrol/policy/v2/filter_test.go @@ -10,7 +10,6 @@ import ( "github.com/google/go-cmp/cmp" "github.com/juanfont/headscale/hscontrol/types" - "github.com/prometheus/common/model" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "gorm.io/gorm" @@ -680,7 +679,7 @@ func TestCompileSSHPolicy_CheckAction(t *testing.T) { SSHs: []SSH{ { Action: "check", - CheckPeriod: model.Duration(24 * time.Hour), + CheckPeriod: &SSHCheckPeriod{Duration: 24 * time.Hour}, Sources: SSHSrcAliases{gp("group:admins")}, Destinations: SSHDstAliases{tp("tag:server")}, Users: []SSHUser{"ssh-it-user"}, @@ -710,6 +709,10 @@ func TestCompileSSHPolicy_CheckAction(t *testing.T) { assert.NotEmpty(t, rule.Action.HoldAndDelegate) assert.Contains(t, rule.Action.HoldAndDelegate, "/machine/ssh/action/") assert.Equal(t, 24*time.Hour, rule.Action.SessionDuration) + + // Verify check params are NOT encoded in the URL (looked up server-side). 
+ assert.NotContains(t, rule.Action.HoldAndDelegate, "check_explicit") + assert.NotContains(t, rule.Action.HoldAndDelegate, "check_period") } // TestCompileSSHPolicy_CheckBeforeAcceptOrdering verifies that check @@ -754,7 +757,7 @@ func TestCompileSSHPolicy_CheckBeforeAcceptOrdering(t *testing.T) { }, { Action: "check", - CheckPeriod: model.Duration(24 * time.Hour), + CheckPeriod: &SSHCheckPeriod{Duration: 24 * time.Hour}, Sources: SSHSrcAliases{gp("group:admins")}, Destinations: SSHDstAliases{tp("tag:server")}, Users: []SSHUser{"ssh-it-user"}, @@ -2167,3 +2170,257 @@ func TestMergeFilterRules(t *testing.T) { }) } } + +func TestCompileSSHPolicy_CheckPeriodVariants(t *testing.T) { + users := types.Users{ + {Name: "user1", Model: gorm.Model{ID: 1}}, + } + + node := types.Node{ + Hostname: "device", + IPv4: createAddr("100.64.0.1"), + UserID: new(users[0].ID), + User: new(users[0]), + } + + nodes := types.Nodes{&node} + + tests := []struct { + name string + checkPeriod *SSHCheckPeriod + wantDuration time.Duration + }{ + { + name: "nil period defaults to 12h", + checkPeriod: nil, + wantDuration: SSHCheckPeriodDefault, + }, + { + name: "always period uses 0", + checkPeriod: &SSHCheckPeriod{Always: true}, + wantDuration: 0, + }, + { + name: "explicit 2h", + checkPeriod: &SSHCheckPeriod{Duration: 2 * time.Hour}, + wantDuration: 2 * time.Hour, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + policy := &Policy{ + SSHs: []SSH{ + { + Action: SSHActionCheck, + Sources: SSHSrcAliases{up("user1@")}, + Destinations: SSHDstAliases{agp("autogroup:member")}, + Users: SSHUsers{"root"}, + CheckPeriod: tt.checkPeriod, + }, + }, + } + + err := policy.validate() + require.NoError(t, err) + + sshPolicy, err := policy.compileSSHPolicy( + "http://test", + users, + node.View(), + nodes.ViewSlice(), + ) + require.NoError(t, err) + require.NotNil(t, sshPolicy) + require.Len(t, sshPolicy.Rules, 1) + + rule := sshPolicy.Rules[0] + assert.Equal(t, tt.wantDuration, 
rule.Action.SessionDuration) + // Check params must NOT be in the URL; they are + // resolved server-side via SSHCheckParams. + assert.NotContains(t, rule.Action.HoldAndDelegate, "check_explicit") + assert.NotContains(t, rule.Action.HoldAndDelegate, "check_period") + }) + } +} + +func TestSSHCheckParams(t *testing.T) { + users := types.Users{ + {Name: "user1", Model: gorm.Model{ID: 1}}, + {Name: "user2", Model: gorm.Model{ID: 2}}, + } + + nodeUser1 := types.Node{ + ID: 1, + Hostname: "user1-device", + IPv4: createAddr("100.64.0.1"), + UserID: new(users[0].ID), + User: new(users[0]), + } + nodeUser2 := types.Node{ + ID: 2, + Hostname: "user2-device", + IPv4: createAddr("100.64.0.2"), + UserID: new(users[1].ID), + User: new(users[1]), + } + nodeTaggedServer := types.Node{ + ID: 3, + Hostname: "tagged-server", + IPv4: createAddr("100.64.0.3"), + UserID: new(users[0].ID), + User: new(users[0]), + Tags: []string{"tag:server"}, + } + + nodes := types.Nodes{&nodeUser1, &nodeUser2, &nodeTaggedServer} + + tests := []struct { + name string + policy []byte + srcID types.NodeID + dstID types.NodeID + wantPeriod time.Duration + wantOK bool + }{ + { + name: "explicit check period for tagged destination", + policy: []byte(`{ + "tagOwners": {"tag:server": ["user1@"]}, + "ssh": [{ + "action": "check", + "checkPeriod": "2h", + "src": ["user2@"], + "dst": ["tag:server"], + "users": ["autogroup:nonroot"] + }] + }`), + srcID: types.NodeID(2), + dstID: types.NodeID(3), + wantPeriod: 2 * time.Hour, + wantOK: true, + }, + { + name: "default period when checkPeriod omitted", + policy: []byte(`{ + "tagOwners": {"tag:server": ["user1@"]}, + "ssh": [{ + "action": "check", + "src": ["user2@"], + "dst": ["tag:server"], + "users": ["autogroup:nonroot"] + }] + }`), + srcID: types.NodeID(2), + dstID: types.NodeID(3), + wantPeriod: SSHCheckPeriodDefault, + wantOK: true, + }, + { + name: "always check (checkPeriod always)", + policy: []byte(`{ + "tagOwners": {"tag:server": ["user1@"]}, + "ssh": [{ + 
"action": "check", + "checkPeriod": "always", + "src": ["user2@"], + "dst": ["tag:server"], + "users": ["autogroup:nonroot"] + }] + }`), + srcID: types.NodeID(2), + dstID: types.NodeID(3), + wantPeriod: 0, + wantOK: true, + }, + { + name: "no match when src not in rule", + policy: []byte(`{ + "tagOwners": {"tag:server": ["user1@"]}, + "ssh": [{ + "action": "check", + "src": ["user1@"], + "dst": ["tag:server"], + "users": ["autogroup:nonroot"] + }] + }`), + srcID: types.NodeID(2), + dstID: types.NodeID(3), + wantOK: false, + }, + { + name: "no match when dst not in rule", + policy: []byte(`{ + "tagOwners": {"tag:server": ["user1@"]}, + "ssh": [{ + "action": "check", + "src": ["user2@"], + "dst": ["tag:server"], + "users": ["autogroup:nonroot"] + }] + }`), + srcID: types.NodeID(2), + dstID: types.NodeID(1), + wantOK: false, + }, + { + name: "accept rule is not returned", + policy: []byte(`{ + "tagOwners": {"tag:server": ["user1@"]}, + "ssh": [{ + "action": "accept", + "src": ["user2@"], + "dst": ["tag:server"], + "users": ["autogroup:nonroot"] + }] + }`), + srcID: types.NodeID(2), + dstID: types.NodeID(3), + wantOK: false, + }, + { + name: "autogroup:self matches same-user pair", + policy: []byte(`{ + "ssh": [{ + "action": "check", + "checkPeriod": "6h", + "src": ["user1@"], + "dst": ["autogroup:self"], + "users": ["autogroup:nonroot"] + }] + }`), + srcID: types.NodeID(1), + dstID: types.NodeID(1), + wantPeriod: 6 * time.Hour, + wantOK: true, + }, + { + name: "autogroup:self rejects cross-user pair", + policy: []byte(`{ + "ssh": [{ + "action": "check", + "src": ["user1@"], + "dst": ["autogroup:self"], + "users": ["autogroup:nonroot"] + }] + }`), + srcID: types.NodeID(1), + dstID: types.NodeID(2), + wantOK: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pm, err := NewPolicyManager(tt.policy, users, nodes.ViewSlice()) + require.NoError(t, err) + + period, ok := pm.SSHCheckParams(tt.srcID, tt.dstID) + assert.Equal(t, tt.wantOK, 
ok, "ok mismatch") + + if tt.wantOK { + assert.Equal(t, tt.wantPeriod, period, "period mismatch") + } + }) + } +} diff --git a/hscontrol/policy/v2/policy.go b/hscontrol/policy/v2/policy.go index 744f52c7..77de20eb 100644 --- a/hscontrol/policy/v2/policy.go +++ b/hscontrol/policy/v2/policy.go @@ -9,6 +9,7 @@ import ( "slices" "strings" "sync" + "time" "github.com/juanfont/headscale/hscontrol/policy/matcher" "github.com/juanfont/headscale/hscontrol/policy/policyutil" @@ -240,6 +241,84 @@ func (pm *PolicyManager) SSHPolicy(baseURL string, node types.NodeView) (*tailcf return sshPol, nil } +// SSHCheckParams resolves the SSH check period for a source-destination +// node pair by looking up the current policy. This avoids trusting URL +// parameters that a client could tamper with. +// It returns the check period duration and whether a matching check +// rule was found. +func (pm *PolicyManager) SSHCheckParams( + srcNodeID, dstNodeID types.NodeID, +) (time.Duration, bool) { + pm.mu.Lock() + defer pm.mu.Unlock() + + if pm.pol == nil || len(pm.pol.SSHs) == 0 { + return 0, false + } + + // Find the source and destination node views. + var srcNode, dstNode types.NodeView + + for _, n := range pm.nodes.All() { + nid := n.ID() + if nid == srcNodeID { + srcNode = n + } + + if nid == dstNodeID { + dstNode = n + } + + if srcNode.Valid() && dstNode.Valid() { + break + } + } + + if !srcNode.Valid() || !dstNode.Valid() { + return 0, false + } + + // Iterate SSH rules to find the first matching check rule. + for _, rule := range pm.pol.SSHs { + if rule.Action != SSHActionCheck { + continue + } + + // Resolve sources and check if src node matches. + srcIPs, err := rule.Sources.Resolve(pm.pol, pm.users, pm.nodes) + if err != nil || srcIPs == nil { + continue + } + + if !slices.ContainsFunc(srcNode.IPs(), srcIPs.Contains) { + continue + } + + // Check if dst node matches any destination. 
+ for _, dst := range rule.Destinations { + if ag, isAG := dst.(*AutoGroup); isAG && ag.Is(AutoGroupSelf) { + if !srcNode.IsTagged() && !dstNode.IsTagged() && + srcNode.User().ID() == dstNode.User().ID() { + return checkPeriodFromRule(rule), true + } + + continue + } + + dstIPs, err := dst.Resolve(pm.pol, pm.users, pm.nodes) + if err != nil || dstIPs == nil { + continue + } + + if slices.ContainsFunc(dstNode.IPs(), dstIPs.Contains) { + return checkPeriodFromRule(rule), true + } + } + } + + return 0, false +} + func (pm *PolicyManager) SetPolicy(polB []byte) (bool, error) { if len(polB) == 0 { return false, nil diff --git a/hscontrol/policy/v2/types.go b/hscontrol/policy/v2/types.go index 8785bed0..1596e09c 100644 --- a/hscontrol/policy/v2/types.go +++ b/hscontrol/policy/v2/types.go @@ -7,6 +7,7 @@ import ( "slices" "strconv" "strings" + "time" "github.com/go-json-experiment/json" "github.com/juanfont/headscale/hscontrol/types" @@ -43,6 +44,17 @@ var ( ErrSSHAutogroupSelfRequiresUserSource = errors.New("autogroup:self destination requires source to contain only users or groups, not tags or autogroup:tagged") ErrSSHTagSourceToAutogroupMember = errors.New("tags in SSH source cannot access autogroup:member (user-owned devices)") ErrSSHWildcardDestination = errors.New("wildcard (*) is not supported as SSH destination") + ErrSSHCheckPeriodBelowMin = errors.New("checkPeriod below minimum of 1 minute") + ErrSSHCheckPeriodAboveMax = errors.New("checkPeriod above maximum of 168 hours (1 week)") + ErrSSHCheckPeriodOnNonCheck = errors.New("checkPeriod is only valid with action \"check\"") +) + +// SSH check period constants per Tailscale docs: +// https://tailscale.com/kb/1193/tailscale-ssh +const ( + SSHCheckPeriodDefault = 12 * time.Hour + SSHCheckPeriodMin = time.Minute + SSHCheckPeriodMax = 168 * time.Hour ) // ACL validation errors. 
@@ -2019,6 +2031,19 @@ func (p *Policy) validate() error { if err != nil { errs = append(errs, err) } + + // Validate checkPeriod + if ssh.CheckPeriod != nil { + switch { + case ssh.Action != SSHActionCheck: + errs = append(errs, ErrSSHCheckPeriodOnNonCheck) + default: + err := ssh.CheckPeriod.Validate() + if err != nil { + errs = append(errs, err) + } + } + } } for _, tagOwners := range p.TagOwners { @@ -2097,13 +2122,75 @@ func (p *Policy) validate() error { return nil } +// SSHCheckPeriod represents the check period for SSH "check" mode rules. +// nil means not specified (runtime default of 12h applies). +// Always=true means "always" (check on every request). +// Duration is an explicit period (min 1m, max 168h). +type SSHCheckPeriod struct { + Always bool + Duration time.Duration +} + +// UnmarshalJSON implements JSON unmarshaling for SSHCheckPeriod. +func (p *SSHCheckPeriod) UnmarshalJSON(b []byte) error { + str := strings.Trim(string(b), `"`) + if str == "always" { + p.Always = true + + return nil + } + + d, err := model.ParseDuration(str) + if err != nil { + return fmt.Errorf("parsing checkPeriod %q: %w", str, err) + } + + p.Duration = time.Duration(d) + + return nil +} + +// MarshalJSON implements JSON marshaling for SSHCheckPeriod. +func (p SSHCheckPeriod) MarshalJSON() ([]byte, error) { + if p.Always { + return []byte(`"always"`), nil + } + + return fmt.Appendf(nil, "%q", p.Duration.String()), nil +} + +// Validate checks that the SSHCheckPeriod is within allowed bounds. +func (p *SSHCheckPeriod) Validate() error { + if p.Always { + return nil + } + + if p.Duration < SSHCheckPeriodMin { + return fmt.Errorf( + "%w: got %s", + ErrSSHCheckPeriodBelowMin, + p.Duration, + ) + } + + if p.Duration > SSHCheckPeriodMax { + return fmt.Errorf( + "%w: got %s", + ErrSSHCheckPeriodAboveMax, + p.Duration, + ) + } + + return nil +} + // SSH controls who can ssh into which machines. 
type SSH struct { - Action SSHAction `json:"action"` - Sources SSHSrcAliases `json:"src"` - Destinations SSHDstAliases `json:"dst"` - Users SSHUsers `json:"users"` - CheckPeriod model.Duration `json:"checkPeriod,omitempty"` + Action SSHAction `json:"action"` + Sources SSHSrcAliases `json:"src"` + Destinations SSHDstAliases `json:"dst"` + Users SSHUsers `json:"users"` + CheckPeriod *SSHCheckPeriod `json:"checkPeriod,omitempty"` } // SSHSrcAliases is a list of aliases that can be used as sources in an SSH rule. diff --git a/hscontrol/policy/v2/types_test.go b/hscontrol/policy/v2/types_test.go index 6cbc7822..a68259a3 100644 --- a/hscontrol/policy/v2/types_test.go +++ b/hscontrol/policy/v2/types_test.go @@ -11,7 +11,6 @@ import ( "github.com/google/go-cmp/cmp/cmpopts" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" - "github.com/prometheus/common/model" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go4.org/netipx" @@ -711,7 +710,7 @@ func TestUnmarshalPolicy(t *testing.T) { }, "ssh": [ { - "action": "accept", + "action": "check", "src": [ "group:admins" ], @@ -730,7 +729,7 @@ func TestUnmarshalPolicy(t *testing.T) { }, SSHs: []SSH{ { - Action: "accept", + Action: "check", Sources: SSHSrcAliases{ gp("group:admins"), }, @@ -740,7 +739,7 @@ func TestUnmarshalPolicy(t *testing.T) { Users: []SSHUser{ SSHUser("root"), }, - CheckPeriod: model.Duration(24 * time.Hour), + CheckPeriod: &SSHCheckPeriod{Duration: 24 * time.Hour}, }, }, }, @@ -3827,3 +3826,218 @@ func TestFlattenTagOwners(t *testing.T) { }) } } + +func TestSSHCheckPeriodUnmarshal(t *testing.T) { + tests := []struct { + name string + input string + want *SSHCheckPeriod + wantErr bool + }{ + { + name: "always", + input: `"always"`, + want: &SSHCheckPeriod{Always: true}, + }, + { + name: "1h", + input: `"1h"`, + want: &SSHCheckPeriod{Duration: time.Hour}, + }, + { + name: "30m", + input: `"30m"`, + want: &SSHCheckPeriod{Duration: 30 * 
time.Minute}, + }, + { + name: "168h", + input: `"168h"`, + want: &SSHCheckPeriod{Duration: 168 * time.Hour}, + }, + { + name: "invalid", + input: `"notaduration"`, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got SSHCheckPeriod + + err := json.Unmarshal([]byte(tt.input), &got) + if tt.wantErr { + require.Error(t, err) + + return + } + + require.NoError(t, err) + assert.Equal(t, *tt.want, got) + }) + } +} + +func TestSSHCheckPeriodRoundTrip(t *testing.T) { + tests := []struct { + name string + input SSHCheckPeriod + }{ + { + name: "always", + input: SSHCheckPeriod{Always: true}, + }, + { + name: "2h", + input: SSHCheckPeriod{Duration: 2 * time.Hour}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data, err := json.Marshal(tt.input) + require.NoError(t, err) + + var got SSHCheckPeriod + + err = json.Unmarshal(data, &got) + require.NoError(t, err) + + assert.Equal(t, tt.input, got) + }) + } +} + +func TestSSHCheckPeriodNilInSSH(t *testing.T) { + input := `{ + "action": "check", + "src": ["user@"], + "dst": ["autogroup:member"], + "users": ["root"] + }` + + var ssh SSH + + err := json.Unmarshal([]byte(input), &ssh) + require.NoError(t, err) + assert.Nil(t, ssh.CheckPeriod) +} + +func TestSSHCheckPeriodValidate(t *testing.T) { + tests := []struct { + name string + period SSHCheckPeriod + wantErr error + }{ + { + name: "always is valid", + period: SSHCheckPeriod{Always: true}, + }, + { + name: "1m minimum valid", + period: SSHCheckPeriod{Duration: time.Minute}, + }, + { + name: "168h maximum valid", + period: SSHCheckPeriod{Duration: 168 * time.Hour}, + }, + { + name: "30s below minimum", + period: SSHCheckPeriod{Duration: 30 * time.Second}, + wantErr: ErrSSHCheckPeriodBelowMin, + }, + { + name: "169h above maximum", + period: SSHCheckPeriod{Duration: 169 * time.Hour}, + wantErr: ErrSSHCheckPeriodAboveMax, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { 
+ err := tt.period.Validate() + if tt.wantErr != nil { + require.ErrorIs(t, err, tt.wantErr) + + return + } + + require.NoError(t, err) + }) + } +} + +func TestSSHCheckPeriodPolicyValidation(t *testing.T) { + tests := []struct { + name string + ssh SSH + wantErr error + }{ + { + name: "check with nil period is valid", + ssh: SSH{ + Action: SSHActionCheck, + Sources: SSHSrcAliases{up("user@")}, + Destinations: SSHDstAliases{agp("autogroup:member")}, + Users: SSHUsers{"root"}, + }, + }, + { + name: "check with always is valid", + ssh: SSH{ + Action: SSHActionCheck, + Sources: SSHSrcAliases{up("user@")}, + Destinations: SSHDstAliases{agp("autogroup:member")}, + Users: SSHUsers{"root"}, + CheckPeriod: &SSHCheckPeriod{Always: true}, + }, + }, + { + name: "check with 1h is valid", + ssh: SSH{ + Action: SSHActionCheck, + Sources: SSHSrcAliases{up("user@")}, + Destinations: SSHDstAliases{agp("autogroup:member")}, + Users: SSHUsers{"root"}, + CheckPeriod: &SSHCheckPeriod{Duration: time.Hour}, + }, + }, + { + name: "accept with checkPeriod is invalid", + ssh: SSH{ + Action: SSHActionAccept, + Sources: SSHSrcAliases{up("user@")}, + Destinations: SSHDstAliases{agp("autogroup:member")}, + Users: SSHUsers{"root"}, + CheckPeriod: &SSHCheckPeriod{Duration: time.Hour}, + }, + wantErr: ErrSSHCheckPeriodOnNonCheck, + }, + { + name: "check with 30s is invalid", + ssh: SSH{ + Action: SSHActionCheck, + Sources: SSHSrcAliases{up("user@")}, + Destinations: SSHDstAliases{agp("autogroup:member")}, + Users: SSHUsers{"root"}, + CheckPeriod: &SSHCheckPeriod{Duration: 30 * time.Second}, + }, + wantErr: ErrSSHCheckPeriodBelowMin, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pol := &Policy{SSHs: []SSH{tt.ssh}} + err := pol.validate() + + if tt.wantErr != nil { + require.ErrorIs(t, err, tt.wantErr) + + return + } + + require.NoError(t, err) + }) + } +} diff --git a/hscontrol/state/ssh_check_test.go b/hscontrol/state/ssh_check_test.go new file mode 100644 index 
00000000..04b9d6b1 --- /dev/null +++ b/hscontrol/state/ssh_check_test.go @@ -0,0 +1,103 @@ +package state + +import ( + "sync" + "testing" + "time" + + "github.com/juanfont/headscale/hscontrol/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func newTestStateForSSHCheck() *State { + return &State{ + sshCheckAuth: make(map[sshCheckPair]time.Time), + } +} + +func TestSSHCheckAuth(t *testing.T) { + s := newTestStateForSSHCheck() + + src := types.NodeID(1) + dst := types.NodeID(2) + otherDst := types.NodeID(3) + otherSrc := types.NodeID(4) + + // No record initially + _, ok := s.GetLastSSHAuth(src, dst) + require.False(t, ok) + + // Record auth for (src, dst) + s.SetLastSSHAuth(src, dst) + + // Same src+dst: found + authTime, ok := s.GetLastSSHAuth(src, dst) + require.True(t, ok) + assert.WithinDuration(t, time.Now(), authTime, time.Second) + + // Same src, different dst: not found (auth is per-pair) + _, ok = s.GetLastSSHAuth(src, otherDst) + require.False(t, ok) + + // Different src: not found + _, ok = s.GetLastSSHAuth(otherSrc, dst) + require.False(t, ok) +} + +func TestSSHCheckAuthClear(t *testing.T) { + s := newTestStateForSSHCheck() + + s.SetLastSSHAuth(types.NodeID(1), types.NodeID(2)) + s.SetLastSSHAuth(types.NodeID(1), types.NodeID(3)) + + _, ok := s.GetLastSSHAuth(types.NodeID(1), types.NodeID(2)) + require.True(t, ok) + + _, ok = s.GetLastSSHAuth(types.NodeID(1), types.NodeID(3)) + require.True(t, ok) + + // Clear + s.ClearSSHCheckAuth() + + _, ok = s.GetLastSSHAuth(types.NodeID(1), types.NodeID(2)) + require.False(t, ok) + + _, ok = s.GetLastSSHAuth(types.NodeID(1), types.NodeID(3)) + require.False(t, ok) +} + +func TestSSHCheckAuthConcurrent(t *testing.T) { + s := newTestStateForSSHCheck() + + var wg sync.WaitGroup + + for i := range 100 { + wg.Go(func() { + src := types.NodeID(uint64(i % 10)) //nolint:gosec + dst := types.NodeID(uint64(i%5 + 10)) //nolint:gosec + + s.SetLastSSHAuth(src, dst) + 
s.GetLastSSHAuth(src, dst) + }) + } + + wg.Wait() + + // Clear concurrently with reads + wg.Add(2) + + go func() { + defer wg.Done() + + s.ClearSSHCheckAuth() + }() + + go func() { + defer wg.Done() + + s.GetLastSSHAuth(types.NodeID(1), types.NodeID(2)) + }() + + wg.Wait() +} diff --git a/hscontrol/state/state.go b/hscontrol/state/state.go index f2ae99a9..c5c917fa 100644 --- a/hscontrol/state/state.go +++ b/hscontrol/state/state.go @@ -67,6 +67,13 @@ var ErrNodeNameNotUnique = errors.New("node name is not unique") // ErrRegistrationExpired is returned when a registration has expired. var ErrRegistrationExpired = errors.New("registration expired") +// sshCheckPair identifies a (source, destination) node pair for +// SSH check auth tracking. +type sshCheckPair struct { + Src types.NodeID + Dst types.NodeID +} + // State manages Headscale's core state, coordinating between database, policy management, // IP allocation, and DERP routing. All methods are thread-safe. type State struct { @@ -91,6 +98,25 @@ type State struct { // primaryRoutes tracks primary route assignments for nodes primaryRoutes *routes.PrimaryRoutes + + // sshCheckAuth tracks when source nodes last completed SSH check auth. + // + // Entries are always keyed by the exact (src, dst) node pair, so an auth + // covers only that destination regardless of the rule's checkPeriod. + // Upstream documents a wider scope for the default 12h period, which is + // not implemented here: "Once re-authenticated to a destination, the user + // can access the device and any other device in the tailnet without + // re-verification for the next 12 hours." + // Ref: https://tailscale.com/kb/1193/tailscale-ssh + // + // Explicit checkPeriod: "If a different check period is specified for the + // connection, then the user can access specifically this device without + // re-verification for the duration of the check period." 
+ // + // Ref: https://github.com/tailscale/tailscale/issues/10480 + // Ref: https://github.com/tailscale/tailscale/issues/7125 + sshCheckAuth map[sshCheckPair]time.Time + sshCheckMu sync.RWMutex } // NewState creates and initializes a new State instance, setting up the database, @@ -189,6 +215,8 @@ func NewState(cfg *types.Config) (*State, error) { authCache: authCache, primaryRoutes: routes.New(), nodeStore: nodeStore, + + sshCheckAuth: make(map[sshCheckPair]time.Time), }, nil } @@ -227,6 +255,10 @@ func (s *State) ReloadPolicy() ([]change.Change, error) { return nil, fmt.Errorf("setting policy: %w", err) } + // Clear SSH check auth times when policy changes to ensure stale + // approvals don't persist if checkPeriod rules are modified or removed. + s.ClearSSHCheckAuth() + // Rebuild peer maps after policy changes because the peersFunc in NodeStore // uses the PolicyManager's filters. Without this, nodes won't see newly allowed // peers until a node is added/removed, causing autogroup:self policies to not @@ -874,6 +906,14 @@ func (s *State) SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, error) { return s.polMan.SSHPolicy(s.cfg.ServerURL, node) } +// SSHCheckParams resolves the SSH check period for a source-destination +// node pair from the current policy. +func (s *State) SSHCheckParams( + srcNodeID, dstNodeID types.NodeID, +) (time.Duration, bool) { + return s.polMan.SSHCheckParams(srcNodeID, dstNodeID) +} + // Filter returns the current network filter rules and matches. func (s *State) Filter() ([]tailcfg.FilterRule, []matcher.Match) { return s.polMan.Filter() @@ -896,7 +936,15 @@ func (s *State) NodeCanHaveTag(node types.NodeView, tag string) bool { // SetPolicy updates the policy configuration. func (s *State) SetPolicy(pol []byte) (bool, error) { - return s.polMan.SetPolicy(pol) + changed, err := s.polMan.SetPolicy(pol) + if err != nil { + return changed, err + } + + // Clear SSH check auth times after every successful call, even if the policy is unchanged. 
+ s.ClearSSHCheckAuth() + + return changed, nil } // AutoApproveRoutes checks if a node's routes should be auto-approved. @@ -1077,6 +1125,35 @@ func (s *State) SetAuthCacheEntry(id types.AuthID, entry types.AuthRequest) { s.authCache.Set(id, entry) } +// SetLastSSHAuth records a successful SSH check authentication +// for the given (src, dst) node pair. +func (s *State) SetLastSSHAuth(src, dst types.NodeID) { + s.sshCheckMu.Lock() + defer s.sshCheckMu.Unlock() + + s.sshCheckAuth[sshCheckPair{Src: src, Dst: dst}] = time.Now() +} + +// GetLastSSHAuth returns when src last authenticated for SSH check +// to dst. +func (s *State) GetLastSSHAuth(src, dst types.NodeID) (time.Time, bool) { + s.sshCheckMu.RLock() + defer s.sshCheckMu.RUnlock() + + t, ok := s.sshCheckAuth[sshCheckPair{Src: src, Dst: dst}] + + return t, ok +} + +// ClearSSHCheckAuth clears all recorded SSH check auth times. +// Called when the policy changes to ensure stale auth times don't grant access. +func (s *State) ClearSSHCheckAuth() { + s.sshCheckMu.Lock() + defer s.sshCheckMu.Unlock() + + s.sshCheckAuth = make(map[sshCheckPair]time.Time) +} + // logHostinfoValidation logs warnings when hostinfo is nil or has empty hostname. func logHostinfoValidation(nv types.NodeView, username, hostname string) { if !nv.Hostinfo().Valid() {