mirror of
https://github.com/juanfont/headscale.git
synced 2026-04-10 11:14:21 +02:00
state, policy, noise: implement SSH check period auto-approval
Add SSH check period tracking so that recently authenticated users are auto-approved without requiring manual intervention each time. Introduce SSHCheckPeriod type with validation (min 1m, max 168h, "always" for every request) and encode the compiled check period as URL query parameters in the HoldAndDelegate URL. The SSHActionHandler checks recorded auth times before creating a new HoldAndDelegate flow. Auth timestamps are stored in-memory: - Default period (no explicit checkPeriod): auth covers any destination, keyed by source node with Dst=0 sentinel - Explicit period: auth covers only that specific destination, keyed by (source, destination) pair Auth times are cleared on policy changes. Updates #1850
This commit is contained in:
103
hscontrol/state/ssh_check_test.go
Normal file
103
hscontrol/state/ssh_check_test.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package state
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func newTestStateForSSHCheck() *State {
|
||||
return &State{
|
||||
sshCheckAuth: make(map[sshCheckPair]time.Time),
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHCheckAuth(t *testing.T) {
|
||||
s := newTestStateForSSHCheck()
|
||||
|
||||
src := types.NodeID(1)
|
||||
dst := types.NodeID(2)
|
||||
otherDst := types.NodeID(3)
|
||||
otherSrc := types.NodeID(4)
|
||||
|
||||
// No record initially
|
||||
_, ok := s.GetLastSSHAuth(src, dst)
|
||||
require.False(t, ok)
|
||||
|
||||
// Record auth for (src, dst)
|
||||
s.SetLastSSHAuth(src, dst)
|
||||
|
||||
// Same src+dst: found
|
||||
authTime, ok := s.GetLastSSHAuth(src, dst)
|
||||
require.True(t, ok)
|
||||
assert.WithinDuration(t, time.Now(), authTime, time.Second)
|
||||
|
||||
// Same src, different dst: not found (auth is per-pair)
|
||||
_, ok = s.GetLastSSHAuth(src, otherDst)
|
||||
require.False(t, ok)
|
||||
|
||||
// Different src: not found
|
||||
_, ok = s.GetLastSSHAuth(otherSrc, dst)
|
||||
require.False(t, ok)
|
||||
}
|
||||
|
||||
func TestSSHCheckAuthClear(t *testing.T) {
|
||||
s := newTestStateForSSHCheck()
|
||||
|
||||
s.SetLastSSHAuth(types.NodeID(1), types.NodeID(2))
|
||||
s.SetLastSSHAuth(types.NodeID(1), types.NodeID(3))
|
||||
|
||||
_, ok := s.GetLastSSHAuth(types.NodeID(1), types.NodeID(2))
|
||||
require.True(t, ok)
|
||||
|
||||
_, ok = s.GetLastSSHAuth(types.NodeID(1), types.NodeID(3))
|
||||
require.True(t, ok)
|
||||
|
||||
// Clear
|
||||
s.ClearSSHCheckAuth()
|
||||
|
||||
_, ok = s.GetLastSSHAuth(types.NodeID(1), types.NodeID(2))
|
||||
require.False(t, ok)
|
||||
|
||||
_, ok = s.GetLastSSHAuth(types.NodeID(1), types.NodeID(3))
|
||||
require.False(t, ok)
|
||||
}
|
||||
|
||||
func TestSSHCheckAuthConcurrent(t *testing.T) {
|
||||
s := newTestStateForSSHCheck()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for i := range 100 {
|
||||
wg.Go(func() {
|
||||
src := types.NodeID(uint64(i % 10)) //nolint:gosec
|
||||
dst := types.NodeID(uint64(i%5 + 10)) //nolint:gosec
|
||||
|
||||
s.SetLastSSHAuth(src, dst)
|
||||
s.GetLastSSHAuth(src, dst)
|
||||
})
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Clear concurrently with reads
|
||||
wg.Add(2)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
s.ClearSSHCheckAuth()
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
s.GetLastSSHAuth(types.NodeID(1), types.NodeID(2))
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
@@ -67,6 +67,13 @@ var ErrNodeNameNotUnique = errors.New("node name is not unique")
|
||||
// ErrRegistrationExpired is returned when a registration has expired.
|
||||
var ErrRegistrationExpired = errors.New("registration expired")
|
||||
|
||||
// sshCheckPair identifies a (source, destination) node pair for
|
||||
// SSH check auth tracking.
|
||||
type sshCheckPair struct {
|
||||
Src types.NodeID
|
||||
Dst types.NodeID
|
||||
}
|
||||
|
||||
// State manages Headscale's core state, coordinating between database, policy management,
|
||||
// IP allocation, and DERP routing. All methods are thread-safe.
|
||||
type State struct {
|
||||
@@ -91,6 +98,25 @@ type State struct {
|
||||
|
||||
// primaryRoutes tracks primary route assignments for nodes
|
||||
primaryRoutes *routes.PrimaryRoutes
|
||||
|
||||
// sshCheckAuth tracks when source nodes last completed SSH check auth.
|
||||
//
|
||||
// For rules without explicit checkPeriod (default 12h), auth covers any
|
||||
// destination — keyed by (src, Dst=0) where 0 is a sentinel meaning "any".
|
||||
// Ref: "Once re-authenticated to a destination, the user can access the
|
||||
// device and any other device in the tailnet without re-verification
|
||||
// for the next 12 hours." — https://tailscale.com/kb/1193/tailscale-ssh
|
||||
//
|
||||
// For rules with explicit checkPeriod, auth covers only that specific
|
||||
// destination — keyed by (src, dst).
|
||||
// Ref: "If a different check period is specified for the connection,
|
||||
// then the user can access specifically this device without
|
||||
// re-verification for the duration of the check period."
|
||||
//
|
||||
// Ref: https://github.com/tailscale/tailscale/issues/10480
|
||||
// Ref: https://github.com/tailscale/tailscale/issues/7125
|
||||
sshCheckAuth map[sshCheckPair]time.Time
|
||||
sshCheckMu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewState creates and initializes a new State instance, setting up the database,
|
||||
@@ -189,6 +215,8 @@ func NewState(cfg *types.Config) (*State, error) {
|
||||
authCache: authCache,
|
||||
primaryRoutes: routes.New(),
|
||||
nodeStore: nodeStore,
|
||||
|
||||
sshCheckAuth: make(map[sshCheckPair]time.Time),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -227,6 +255,10 @@ func (s *State) ReloadPolicy() ([]change.Change, error) {
|
||||
return nil, fmt.Errorf("setting policy: %w", err)
|
||||
}
|
||||
|
||||
// Clear SSH check auth times when policy changes to ensure stale
|
||||
// approvals don't persist if checkPeriod rules are modified or removed.
|
||||
s.ClearSSHCheckAuth()
|
||||
|
||||
// Rebuild peer maps after policy changes because the peersFunc in NodeStore
|
||||
// uses the PolicyManager's filters. Without this, nodes won't see newly allowed
|
||||
// peers until a node is added/removed, causing autogroup:self policies to not
|
||||
@@ -874,6 +906,14 @@ func (s *State) SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, error) {
|
||||
return s.polMan.SSHPolicy(s.cfg.ServerURL, node)
|
||||
}
|
||||
|
||||
// SSHCheckParams resolves the SSH check period for a source-destination
|
||||
// node pair from the current policy.
|
||||
func (s *State) SSHCheckParams(
|
||||
srcNodeID, dstNodeID types.NodeID,
|
||||
) (time.Duration, bool) {
|
||||
return s.polMan.SSHCheckParams(srcNodeID, dstNodeID)
|
||||
}
|
||||
|
||||
// Filter returns the current network filter rules and matches.
|
||||
func (s *State) Filter() ([]tailcfg.FilterRule, []matcher.Match) {
|
||||
return s.polMan.Filter()
|
||||
@@ -896,7 +936,15 @@ func (s *State) NodeCanHaveTag(node types.NodeView, tag string) bool {
|
||||
|
||||
// SetPolicy updates the policy configuration.
|
||||
func (s *State) SetPolicy(pol []byte) (bool, error) {
|
||||
return s.polMan.SetPolicy(pol)
|
||||
changed, err := s.polMan.SetPolicy(pol)
|
||||
if err != nil {
|
||||
return changed, err
|
||||
}
|
||||
|
||||
// Clear SSH check auth times when policy changes.
|
||||
s.ClearSSHCheckAuth()
|
||||
|
||||
return changed, nil
|
||||
}
|
||||
|
||||
// AutoApproveRoutes checks if a node's routes should be auto-approved.
|
||||
@@ -1077,6 +1125,35 @@ func (s *State) SetAuthCacheEntry(id types.AuthID, entry types.AuthRequest) {
|
||||
s.authCache.Set(id, entry)
|
||||
}
|
||||
|
||||
// SetLastSSHAuth records a successful SSH check authentication
|
||||
// for the given (src, dst) node pair.
|
||||
func (s *State) SetLastSSHAuth(src, dst types.NodeID) {
|
||||
s.sshCheckMu.Lock()
|
||||
defer s.sshCheckMu.Unlock()
|
||||
|
||||
s.sshCheckAuth[sshCheckPair{Src: src, Dst: dst}] = time.Now()
|
||||
}
|
||||
|
||||
// GetLastSSHAuth returns when src last authenticated for SSH check
|
||||
// to dst.
|
||||
func (s *State) GetLastSSHAuth(src, dst types.NodeID) (time.Time, bool) {
|
||||
s.sshCheckMu.RLock()
|
||||
defer s.sshCheckMu.RUnlock()
|
||||
|
||||
t, ok := s.sshCheckAuth[sshCheckPair{Src: src, Dst: dst}]
|
||||
|
||||
return t, ok
|
||||
}
|
||||
|
||||
// ClearSSHCheckAuth clears all recorded SSH check auth times.
|
||||
// Called when the policy changes to ensure stale auth times don't grant access.
|
||||
func (s *State) ClearSSHCheckAuth() {
|
||||
s.sshCheckMu.Lock()
|
||||
defer s.sshCheckMu.Unlock()
|
||||
|
||||
s.sshCheckAuth = make(map[sshCheckPair]time.Time)
|
||||
}
|
||||
|
||||
// logHostinfoValidation logs warnings when hostinfo is nil or has empty hostname.
|
||||
func logHostinfoValidation(nv types.NodeView, username, hostname string) {
|
||||
if !nv.Hostinfo().Valid() {
|
||||
|
||||
Reference in New Issue
Block a user