Files
headscale/hscontrol/state/ssh_check_test.go
Kristoffer Dalby 7bab8da366 state, policy, noise: implement SSH check period auto-approval
Add SSH check period tracking so that recently authenticated users
are auto-approved without requiring manual intervention each time.

Introduce an SSHCheckPeriod type with validation (minimum 1m, maximum
168h, or "always" to require a check on every request) and encode the
compiled check period as URL query parameters in the HoldAndDelegate URL.

The SSHActionHandler checks recorded auth times before creating a
new HoldAndDelegate flow. Auth timestamps are stored in-memory:
- Default period (no explicit checkPeriod): auth covers any
  destination, keyed by source node with Dst=0 sentinel
- Explicit period: auth covers only that specific destination,
  keyed by (source, destination) pair

Auth times are cleared on policy changes.

Updates #1850
2026-02-25 21:28:05 +01:00

104 lines
2.0 KiB
Go

package state
import (
"sync"
"testing"
"time"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// newTestStateForSSHCheck returns a State whose only initialised field is
// the SSH check auth map, which is all the SSH check auth tests touch.
func newTestStateForSSHCheck() *State {
	st := new(State)
	st.sshCheckAuth = map[sshCheckPair]time.Time{}

	return st
}
func TestSSHCheckAuth(t *testing.T) {
s := newTestStateForSSHCheck()
src := types.NodeID(1)
dst := types.NodeID(2)
otherDst := types.NodeID(3)
otherSrc := types.NodeID(4)
// No record initially
_, ok := s.GetLastSSHAuth(src, dst)
require.False(t, ok)
// Record auth for (src, dst)
s.SetLastSSHAuth(src, dst)
// Same src+dst: found
authTime, ok := s.GetLastSSHAuth(src, dst)
require.True(t, ok)
assert.WithinDuration(t, time.Now(), authTime, time.Second)
// Same src, different dst: not found (auth is per-pair)
_, ok = s.GetLastSSHAuth(src, otherDst)
require.False(t, ok)
// Different src: not found
_, ok = s.GetLastSSHAuth(otherSrc, dst)
require.False(t, ok)
}
// TestSSHCheckAuthClear verifies that ClearSSHCheckAuth drops every recorded
// entry, not only a single one.
func TestSSHCheckAuthClear(t *testing.T) {
	s := newTestStateForSSHCheck()

	src := types.NodeID(1)
	dsts := []types.NodeID{2, 3}

	for _, d := range dsts {
		s.SetLastSSHAuth(src, d)
	}

	// Both entries are present before the clear.
	for _, d := range dsts {
		_, present := s.GetLastSSHAuth(src, d)
		require.True(t, present)
	}

	s.ClearSSHCheckAuth()

	// And none survive it.
	for _, d := range dsts {
		_, present := s.GetLastSSHAuth(src, d)
		require.False(t, present)
	}
}
// TestSSHCheckAuthConcurrent exercises SetLastSSHAuth, GetLastSSHAuth and
// ClearSSHCheckAuth from many goroutines at once. It is mainly meaningful
// under `go test -race`, but it also checks two observable properties:
// no write is lost once all writers have finished, and a completed clear
// leaves no entries behind.
func TestSSHCheckAuthConcurrent(t *testing.T) {
	s := newTestStateForSSHCheck()

	const (
		workers = 100
		numSrc  = 10
		numDst  = 5
		dstBase = 10 // keeps destination IDs disjoint from source IDs
		readers = 10
	)

	// pairFor maps a worker index onto one of numSrc*numDst-ish pairs so
	// that several workers contend on the same keys.
	pairFor := func(i int) (types.NodeID, types.NodeID) {
		src := types.NodeID(uint64(i % numSrc))           //nolint:gosec
		dst := types.NodeID(uint64(i%numDst + dstBase)) //nolint:gosec

		return src, dst
	}

	var wg sync.WaitGroup

	// Phase 1: concurrent writes interleaved with reads.
	for i := range workers {
		wg.Go(func() {
			src, dst := pairFor(i)
			s.SetLastSSHAuth(src, dst)
			s.GetLastSSHAuth(src, dst)
		})
	}
	wg.Wait()

	// Every pair written above must be visible after all writers finished.
	for i := range workers {
		src, dst := pairFor(i)
		_, ok := s.GetLastSSHAuth(src, dst)
		require.True(t, ok, "missing auth for src=%v dst=%v", src, dst)
	}

	// Phase 2: clear concurrently with several readers.
	wg.Go(func() {
		s.ClearSSHCheckAuth()
	})
	for i := range readers {
		wg.Go(func() {
			src, dst := pairFor(i)
			s.GetLastSSHAuth(src, dst)
		})
	}
	wg.Wait()

	// No writer ran during phase 2, so once the clear has completed no
	// entry may remain.
	for i := range workers {
		src, dst := pairFor(i)
		_, ok := s.GetLastSSHAuth(src, dst)
		require.False(t, ok, "auth survived clear for src=%v dst=%v", src, dst)
	}
}