mirror of
https://github.com/juanfont/headscale.git
synced 2026-03-17 23:14:01 +01:00
Add SSH check period tracking so that recently authenticated users are auto-approved without requiring manual intervention each time.

Introduce the SSHCheckPeriod type with validation (min 1m, max 168h, "always" for every request) and encode the compiled check period as URL query parameters in the HoldAndDelegate URL. The SSHActionHandler checks recorded auth times before creating a new HoldAndDelegate flow.

Auth timestamps are stored in-memory:
- Default period (no explicit checkPeriod): auth covers any destination, keyed by source node with a Dst=0 sentinel
- Explicit period: auth covers only that specific destination, keyed by (source, destination) pair

Auth times are cleared on policy changes.

Updates #1850
104 lines
2.0 KiB
Go
104 lines
2.0 KiB
Go
package state
|
|
|
|
import (
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/juanfont/headscale/hscontrol/types"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func newTestStateForSSHCheck() *State {
|
|
return &State{
|
|
sshCheckAuth: make(map[sshCheckPair]time.Time),
|
|
}
|
|
}
|
|
|
|
func TestSSHCheckAuth(t *testing.T) {
|
|
s := newTestStateForSSHCheck()
|
|
|
|
src := types.NodeID(1)
|
|
dst := types.NodeID(2)
|
|
otherDst := types.NodeID(3)
|
|
otherSrc := types.NodeID(4)
|
|
|
|
// No record initially
|
|
_, ok := s.GetLastSSHAuth(src, dst)
|
|
require.False(t, ok)
|
|
|
|
// Record auth for (src, dst)
|
|
s.SetLastSSHAuth(src, dst)
|
|
|
|
// Same src+dst: found
|
|
authTime, ok := s.GetLastSSHAuth(src, dst)
|
|
require.True(t, ok)
|
|
assert.WithinDuration(t, time.Now(), authTime, time.Second)
|
|
|
|
// Same src, different dst: not found (auth is per-pair)
|
|
_, ok = s.GetLastSSHAuth(src, otherDst)
|
|
require.False(t, ok)
|
|
|
|
// Different src: not found
|
|
_, ok = s.GetLastSSHAuth(otherSrc, dst)
|
|
require.False(t, ok)
|
|
}
|
|
|
|
func TestSSHCheckAuthClear(t *testing.T) {
|
|
s := newTestStateForSSHCheck()
|
|
|
|
s.SetLastSSHAuth(types.NodeID(1), types.NodeID(2))
|
|
s.SetLastSSHAuth(types.NodeID(1), types.NodeID(3))
|
|
|
|
_, ok := s.GetLastSSHAuth(types.NodeID(1), types.NodeID(2))
|
|
require.True(t, ok)
|
|
|
|
_, ok = s.GetLastSSHAuth(types.NodeID(1), types.NodeID(3))
|
|
require.True(t, ok)
|
|
|
|
// Clear
|
|
s.ClearSSHCheckAuth()
|
|
|
|
_, ok = s.GetLastSSHAuth(types.NodeID(1), types.NodeID(2))
|
|
require.False(t, ok)
|
|
|
|
_, ok = s.GetLastSSHAuth(types.NodeID(1), types.NodeID(3))
|
|
require.False(t, ok)
|
|
}
|
|
|
|
func TestSSHCheckAuthConcurrent(t *testing.T) {
|
|
s := newTestStateForSSHCheck()
|
|
|
|
var wg sync.WaitGroup
|
|
|
|
for i := range 100 {
|
|
wg.Go(func() {
|
|
src := types.NodeID(uint64(i % 10)) //nolint:gosec
|
|
dst := types.NodeID(uint64(i%5 + 10)) //nolint:gosec
|
|
|
|
s.SetLastSSHAuth(src, dst)
|
|
s.GetLastSSHAuth(src, dst)
|
|
})
|
|
}
|
|
|
|
wg.Wait()
|
|
|
|
// Clear concurrently with reads
|
|
wg.Add(2)
|
|
|
|
go func() {
|
|
defer wg.Done()
|
|
|
|
s.ClearSSHCheckAuth()
|
|
}()
|
|
|
|
go func() {
|
|
defer wg.Done()
|
|
|
|
s.GetLastSSHAuth(types.NodeID(1), types.NodeID(2))
|
|
}()
|
|
|
|
wg.Wait()
|
|
}
|