hscontrol: add servertest harness for in-process control plane testing

Add a new hscontrol/servertest package that provides a test harness
for exercising the full Headscale control protocol in-process, using
Tailscale's controlclient.Direct as the client.

The harness consists of:
- TestServer: wraps a Headscale instance with an httptest.Server
- TestClient: wraps controlclient.Direct with NetworkMap tracking
- TestHarness: orchestrates N clients against a single server
- Assertion helpers for mesh completeness, visibility, and consistency

Export minimal accessor methods on Headscale (HTTPHandler, NoisePublicKey,
GetState, SetServerURLForTest, StartBatcherForTest, StartEphemeralGCForTest)
so the servertest package can construct a working server from outside the
hscontrol package.

This enables fast, deterministic tests of connection lifecycle, update
propagation, and network weather scenarios without Docker.
This commit is contained in:
Kristoffer Dalby
2026-03-16 09:16:43 +00:00
parent 82c7efccf8
commit 0288614bdf
5 changed files with 1039 additions and 0 deletions

View File

@@ -16,6 +16,7 @@ import (
"strings"
"sync"
"syscall"
"testing"
"time"
"github.com/cenkalti/backoff/v5"
@@ -1069,6 +1070,56 @@ func (h *Headscale) Change(cs ...change.Change) {
h.mapBatcher.AddWork(cs...)
}
// HTTPHandler returns an http.Handler for the Headscale control server.
// The handler serves the Tailscale control protocol including the /key
// endpoint and /ts2021 Noise upgrade path.
func (h *Headscale) HTTPHandler() http.Handler {
	mux := grpcRuntime.NewServeMux()
	return h.createRouter(mux)
}
// NoisePublicKey returns the server's Noise protocol public key,
// derived from the private key loaded at startup. Clients need this
// key to establish the ts2021 Noise channel.
func (h *Headscale) NoisePublicKey() key.MachinePublic {
	return h.noisePrivateKey.Public()
}
// GetState returns the server's state manager for programmatic access
// to users, nodes, policies, and other server state. Intended for the
// servertest package; it exposes internal state, so use with care.
func (h *Headscale) GetState() *state.State {
	return h.state
}
// SetServerURLForTest updates the server URL in the configuration.
// This is needed for test servers where the URL is not known until
// the HTTP test server starts.
// The testing.TB parameter restricts callers to test code; the method
// itself does not panic.
func (h *Headscale) SetServerURLForTest(tb testing.TB, url string) {
	tb.Helper()
	h.cfg.ServerURL = url
}
// StartBatcherForTest initialises and starts the map response batcher.
// It registers a cleanup function on tb to stop the batcher when the
// test finishes.
// The testing.TB parameter restricts callers to test code; the method
// itself does not panic.
func (h *Headscale) StartBatcherForTest(tb testing.TB) {
	tb.Helper()
	h.mapBatcher = mapper.NewBatcherAndMapper(h.cfg, h.state)
	h.mapBatcher.Start()
	tb.Cleanup(func() { h.mapBatcher.Close() })
}
// StartEphemeralGCForTest starts the ephemeral node garbage collector
// in a background goroutine. It registers a cleanup function on tb to
// stop the collector when the test finishes.
// The testing.TB parameter restricts callers to test code; the method
// itself does not panic.
func (h *Headscale) StartEphemeralGCForTest(tb testing.TB) {
	tb.Helper()
	go h.ephemeralGC.Start()
	tb.Cleanup(func() { h.ephemeralGC.Close() })
}
// Provide some middleware that can inspect the ACME/autocert https calls
// and log when things are failing.
type acmeLogger struct {

View File

@@ -0,0 +1,219 @@
package servertest
import (
"net/netip"
"testing"
"time"
)
// AssertMeshComplete verifies that every client's latest netmap contains
// exactly len(clients)-1 peers, i.e. that the mesh is fully connected.
// Clients with no netmap at all are reported as failures too.
func AssertMeshComplete(tb testing.TB, clients []*TestClient) {
	tb.Helper()
	want := len(clients) - 1
	for _, client := range clients {
		netMap := client.Netmap()
		if netMap == nil {
			tb.Errorf("AssertMeshComplete: %s has no netmap", client.Name)
			continue
		}
		got := len(netMap.Peers)
		if got != want {
			tb.Errorf("AssertMeshComplete: %s has %d peers, want %d (peers: %v)",
				client.Name, got, want, client.PeerNames())
		}
	}
}
// AssertSymmetricVisibility checks that peer visibility is symmetric:
// whenever client A sees client B, client B must also see client A.
// Each ordered pair is checked, so an asymmetry is reported from both
// directions.
func AssertSymmetricVisibility(tb testing.TB, clients []*TestClient) {
	tb.Helper()
	for _, left := range clients {
		for _, right := range clients {
			if left == right {
				continue
			}
			_, leftSees := left.PeerByName(right.Name)
			_, rightSees := right.PeerByName(left.Name)
			if leftSees == rightSees {
				continue
			}
			tb.Errorf("AssertSymmetricVisibility: %s sees %s = %v, but %s sees %s = %v",
				left.Name, right.Name, leftSees, right.Name, left.Name, rightSees)
		}
	}
}
// AssertPeerOnline checks that the observer sees peerName and that the
// peer's Online field is known and true.
func AssertPeerOnline(tb testing.TB, observer *TestClient, peerName string) {
	tb.Helper()
	peer, found := observer.PeerByName(peerName)
	if !found {
		tb.Errorf("AssertPeerOnline: %s does not see peer %s", observer.Name, peerName)
		return
	}
	online, known := peer.Online().GetOk()
	if known && online {
		return
	}
	tb.Errorf("AssertPeerOnline: %s sees peer %s but Online=%v (known=%v), want true",
		observer.Name, peerName, online, known)
}
// AssertPeerOffline checks that the observer does not see peerName as
// online. A peer that is absent entirely, or whose Online field is
// unknown, counts as offline for this assertion.
func AssertPeerOffline(tb testing.TB, observer *TestClient, peerName string) {
	tb.Helper()
	peer, found := observer.PeerByName(peerName)
	if !found {
		// Peer gone entirely counts as "offline" for this assertion.
		return
	}
	online, known := peer.Online().GetOk()
	if !known || !online {
		return
	}
	tb.Errorf("AssertPeerOffline: %s sees peer %s as online, want offline",
		observer.Name, peerName)
}
// AssertPeerGone checks that the observer does NOT have peerName in
// its peer list at all.
func AssertPeerGone(tb testing.TB, observer *TestClient, peerName string) {
	tb.Helper()
	if _, present := observer.PeerByName(peerName); present {
		tb.Errorf("AssertPeerGone: %s still sees peer %s", observer.Name, peerName)
	}
}
// AssertPeerHasAllowedIPs checks that a peer advertises exactly the
// expected AllowedIPs prefixes, order-insensitively.
//
// The comparison is a true multiset comparison: duplicated prefixes in
// got must be matched by duplicates in want. (A plain membership-set
// check would wrongly accept got=[a,a] against want=[a,b], since both
// have length 2 and every got element is in the want set.)
func AssertPeerHasAllowedIPs(tb testing.TB, observer *TestClient, peerName string, want []netip.Prefix) {
	tb.Helper()
	peer, ok := observer.PeerByName(peerName)
	if !ok {
		tb.Errorf("AssertPeerHasAllowedIPs: %s does not see peer %s", observer.Name, peerName)
		return
	}
	got := make([]netip.Prefix, 0, peer.AllowedIPs().Len())
	for i := range peer.AllowedIPs().Len() {
		got = append(got, peer.AllowedIPs().At(i))
	}
	if len(got) != len(want) {
		tb.Errorf("AssertPeerHasAllowedIPs: %s sees %s with AllowedIPs %v, want %v",
			observer.Name, peerName, got, want)
		return
	}
	// Multiset comparison: count each wanted prefix, then consume the
	// counts as we walk got. Any prefix with no remaining count is
	// either unexpected or duplicated beyond what want allows.
	remaining := make(map[netip.Prefix]int, len(want))
	for _, p := range want {
		remaining[p]++
	}
	for _, p := range got {
		if remaining[p] == 0 {
			tb.Errorf("AssertPeerHasAllowedIPs: %s sees %s with unexpected AllowedIP %v (want %v)",
				observer.Name, peerName, p, want)
			continue
		}
		remaining[p]--
	}
}
// AssertConsistentState checks that all clients agree on peer
// membership: every client that has received a netmap must see the
// hostname of every other client that has also received one.
func AssertConsistentState(tb testing.TB, clients []*TestClient) {
	tb.Helper()
	for _, client := range clients {
		netMap := client.Netmap()
		if netMap == nil {
			continue
		}
		// Index this client's peers by hostname for quick lookup.
		seen := make(map[string]bool, len(netMap.Peers))
		for _, peer := range netMap.Peers {
			if hi := peer.Hostinfo(); hi.Valid() {
				seen[hi.Hostname()] = true
			}
		}
		// Every other connected client must appear in the index.
		for _, other := range clients {
			if other == client || other.Netmap() == nil {
				continue
			}
			if seen[other.Name] {
				continue
			}
			tb.Errorf("AssertConsistentState: %s does not see %s (peers: %v)",
				client.Name, other.Name, client.PeerNames())
		}
	}
}
// EventuallyAssertMeshComplete polls until every client sees at least
// len(clients)-1 peers or until timeout elapses, then runs a final
// strict AssertMeshComplete. On timeout it reports each client whose
// peer count is wrong.
//
// Uses a single reusable ticker and deadline timer instead of calling
// time.After on every poll iteration, which would allocate a fresh
// timer per loop.
func EventuallyAssertMeshComplete(tb testing.TB, clients []*TestClient, timeout time.Duration) {
	tb.Helper()
	expected := len(clients) - 1
	deadline := time.NewTimer(timeout)
	defer deadline.Stop()
	ticker := time.NewTicker(100 * time.Millisecond)
	defer ticker.Stop()
	for {
		allGood := true
		for _, c := range clients {
			nm := c.Netmap()
			if nm == nil || len(nm.Peers) < expected {
				allGood = false
				break
			}
		}
		if allGood {
			// Final strict check (catches clients with too many peers).
			AssertMeshComplete(tb, clients)
			return
		}
		select {
		case <-deadline.C:
			// Report the failure with per-client details.
			for _, c := range clients {
				got := 0
				if nm := c.Netmap(); nm != nil {
					got = len(nm.Peers)
				}
				if got != expected {
					tb.Errorf("EventuallyAssertMeshComplete: %s has %d peers, want %d (timeout %v)",
						c.Name, got, expected, timeout)
				}
			}
			return
		case <-ticker.C:
			// Poll again.
		}
	}
}

View File

@@ -0,0 +1,430 @@
package servertest
import (
"context"
"fmt"
"sync"
"testing"
"time"
"github.com/juanfont/headscale/hscontrol/types"
"tailscale.com/control/controlclient"
"tailscale.com/health"
"tailscale.com/net/netmon"
"tailscale.com/net/tsdial"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"tailscale.com/types/netmap"
"tailscale.com/types/persist"
"tailscale.com/util/eventbus"
)
// TestClient wraps a Tailscale controlclient.Direct connected to a
// TestServer. It tracks all received NetworkMap updates, providing
// helpers to wait for convergence and inspect the client's view of
// the network.
type TestClient struct {
	// Name is a human-readable identifier for this client.
	Name string
	// server is the TestServer this client registered against.
	server *TestServer
	// direct is the underlying Tailscale control client.
	direct *controlclient.Direct
	// authKey is the pre-auth key used at registration.
	authKey string
	// user is the Headscale user this client belongs to.
	user *types.User
	// Connection lifecycle. pollCtx/pollCancel control the long-poll
	// goroutine; pollDone is closed when PollNetMap returns.
	pollCtx    context.Context //nolint:containedctx // test-only; context stored for cancel control
	pollCancel context.CancelFunc
	pollDone   chan struct{}
	// Accumulated state from MapResponse callbacks, guarded by mu.
	mu      sync.RWMutex
	netmap  *netmap.NetworkMap   // latest netmap, nil until first update
	history []*netmap.NetworkMap // every netmap received, in order
	// updates is a buffered channel that receives a signal
	// each time a new NetworkMap arrives; sends are non-blocking,
	// so signals may be dropped when the buffer is full.
	updates chan *netmap.NetworkMap
	// Supporting Tailscale infrastructure, closed in cleanup.
	bus     *eventbus.Bus
	dialer  *tsdial.Dialer
	tracker *health.Tracker
}
// ClientOption configures a TestClient at construction time.
type ClientOption func(*clientConfig)

// clientConfig holds the options accumulated by ClientOption funcs
// before NewClient applies them.
type clientConfig struct {
	ephemeral bool         // register via an ephemeral pre-auth key
	hostname  string       // Hostinfo hostname; defaults to the client name
	tags      []string     // ACL tags for the pre-auth key
	user      *types.User  // owning user; nil means create one per client
}
// WithEphemeral makes the client register as an ephemeral node.
func WithEphemeral() ClientOption {
	return func(conf *clientConfig) {
		conf.ephemeral = true
	}
}
// WithHostname sets the client's hostname in Hostinfo.
func WithHostname(name string) ClientOption {
	return func(conf *clientConfig) {
		conf.hostname = name
	}
}
// WithTags sets ACL tags on the pre-auth key.
func WithTags(tags ...string) ClientOption {
	return func(conf *clientConfig) {
		conf.tags = tags
	}
}
// WithUser sets the user for the client. If not set, the harness
// creates a default user.
func WithUser(user *types.User) ClientOption {
	return func(conf *clientConfig) {
		conf.user = user
	}
}
// NewClient creates a TestClient, registers it with the TestServer
// using a pre-auth key, and starts long-polling for map updates.
//
// The client's user is taken from WithUser if supplied, otherwise a
// per-client user named "user-<name>" is created. Cleanup is registered
// on tb before registration, so resources are released even if
// registration fails the test.
func NewClient(tb testing.TB, server *TestServer, name string, opts ...ClientOption) *TestClient {
	tb.Helper()
	cc := &clientConfig{
		hostname: name,
	}
	for _, o := range opts {
		o(cc)
	}
	// Resolve user.
	user := cc.user
	if user == nil {
		// Create a per-client user if none specified.
		user = server.CreateUser(tb, "user-"+name)
	}
	// Create pre-auth key. Ephemeral clients get an ephemeral key,
	// everyone else a reusable one.
	uid := types.UserID(user.ID)
	var authKey string
	if cc.ephemeral {
		authKey = server.CreateEphemeralPreAuthKey(tb, uid)
	} else {
		authKey = server.CreatePreAuthKey(tb, uid)
	}
	// Set up Tailscale client infrastructure: event bus, health
	// tracker, and dialer that Direct requires.
	bus := eventbus.New()
	tracker := health.NewTracker(bus)
	dialer := tsdial.NewDialer(netmon.NewStatic())
	dialer.SetBus(bus)
	machineKey := key.NewMachine()
	direct, err := controlclient.NewDirect(controlclient.Options{
		Persist:              persist.Persist{},
		GetMachinePrivateKey: func() (key.MachinePrivate, error) { return machineKey, nil },
		ServerURL:            server.URL,
		AuthKey:              authKey,
		Hostinfo: &tailcfg.Hostinfo{
			BackendLogID: "servertest-" + name,
			Hostname:     cc.hostname,
		},
		DiscoPublicKey: key.NewDisco().Public(),
		Logf:           tb.Logf,
		HealthTracker:  tracker,
		Dialer:         dialer,
		Bus:            bus,
	})
	if err != nil {
		tb.Fatalf("servertest: NewDirect(%s): %v", name, err)
	}
	tc := &TestClient{
		Name:    name,
		server:  server,
		direct:  direct,
		authKey: authKey,
		user:    user,
		updates: make(chan *netmap.NetworkMap, 64),
		bus:     bus,
		dialer:  dialer,
		tracker: tracker,
	}
	// Register cleanup before register/startPoll so partially
	// constructed clients are still torn down on failure.
	tb.Cleanup(func() {
		tc.cleanup()
	})
	// Register with the server.
	tc.register(tb)
	// Start long-polling in the background.
	tc.startPoll(tb)
	return tc
}
// register performs the initial TryLogin to register the client with
// the control server. With a pre-auth key the login must complete
// automatically; any interactive auth URL is a test failure.
func (c *TestClient) register(tb testing.TB) {
	tb.Helper()
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	authURL, err := c.direct.TryLogin(ctx, controlclient.LoginDefault)
	switch {
	case err != nil:
		tb.Fatalf("servertest: TryLogin(%s): %v", c.Name, err)
	case authURL != "":
		tb.Fatalf("servertest: TryLogin(%s): unexpected auth URL: %s (expected auto-auth with preauth key)", c.Name, authURL)
	}
}
// startPoll begins the long-poll MapRequest loop in a background
// goroutine. pollDone is closed when PollNetMap returns, which happens
// when pollCtx is cancelled or the server closes the connection.
func (c *TestClient) startPoll(tb testing.TB) {
	tb.Helper()
	ctx, cancel := context.WithCancel(context.Background())
	done := make(chan struct{})
	c.pollCtx, c.pollCancel, c.pollDone = ctx, cancel, done
	go func() {
		defer close(done)
		_ = c.direct.PollNetMap(ctx, c)
	}()
}
// UpdateFullNetmap implements controlclient.NetmapUpdater.
// Called by controlclient.Direct when a new NetworkMap is received;
// records the map under the lock, then signals waiters without
// blocking (the signal is dropped if the buffer is full — the state
// itself is already recorded).
func (c *TestClient) UpdateFullNetmap(nm *netmap.NetworkMap) {
	func() {
		c.mu.Lock()
		defer c.mu.Unlock()
		c.netmap = nm
		c.history = append(c.history, nm)
	}()
	select {
	case c.updates <- nm:
	default:
	}
}
// cleanup releases all resources: cancels the poll, waits (bounded)
// for the poll goroutine, then closes the control client, dialer,
// and event bus.
func (c *TestClient) cleanup() {
	if cancel := c.pollCancel; cancel != nil {
		cancel()
	}
	if done := c.pollDone; done != nil {
		// Wait for PollNetMap to exit, but never hang teardown.
		select {
		case <-done:
		case <-time.After(5 * time.Second):
		}
	}
	if c.direct != nil {
		c.direct.Close()
	}
	if c.dialer != nil {
		c.dialer.Close()
	}
	if c.bus != nil {
		c.bus.Close()
	}
}
// --- Lifecycle methods ---
// Disconnect cancels the long-poll context, simulating a clean
// client disconnect.
func (c *TestClient) Disconnect(tb testing.TB) {
tb.Helper()
if c.pollCancel != nil {
c.pollCancel()
<-c.pollDone
}
}
// Reconnect registers and starts a new long-poll session.
// Call Disconnect first, or this will disconnect automatically.
func (c *TestClient) Reconnect(tb testing.TB) {
	tb.Helper()
	// Tear down any existing poll session first, with a bounded wait.
	if c.pollCancel != nil {
		c.pollCancel()
		timer := time.NewTimer(5 * time.Second)
		defer timer.Stop()
		select {
		case <-c.pollDone:
		case <-timer.C:
			tb.Fatalf("servertest: Reconnect(%s): old poll did not exit", c.Name)
		}
	}
	// Re-register and start polling again.
	c.register(tb)
	c.startPoll(tb)
}
// ReconnectAfter disconnects, waits for d, then reconnects.
// The explicit timer works correctly with testing/synctest for
// time-controlled tests.
func (c *TestClient) ReconnectAfter(tb testing.TB, d time.Duration) {
	tb.Helper()
	c.Disconnect(tb)
	wait := time.NewTimer(d)
	defer wait.Stop()
	<-wait.C
	c.Reconnect(tb)
}
// --- State accessors ---
// Netmap returns the latest NetworkMap, or nil if none received yet.
func (c *TestClient) Netmap() *netmap.NetworkMap {
c.mu.RLock()
defer c.mu.RUnlock()
return c.netmap
}
// WaitForPeers blocks until the client sees at least n peers, or fails
// the test when timeout expires first. Each netmap update re-checks
// the peer count.
func (c *TestClient) WaitForPeers(tb testing.TB, n int, timeout time.Duration) {
	tb.Helper()
	expired := time.After(timeout)
	for {
		current := c.Netmap()
		if current != nil && len(current.Peers) >= n {
			return
		}
		select {
		case <-c.updates:
			// A new netmap arrived; loop and re-check.
		case <-expired:
			got := 0
			if nm := c.Netmap(); nm != nil {
				got = len(nm.Peers)
			}
			tb.Fatalf("servertest: WaitForPeers(%s, %d): timeout after %v (got %d peers)", c.Name, n, timeout, got)
		}
	}
}
// WaitForUpdate blocks until the next netmap update arrives, failing
// the test if timeout elapses first.
func (c *TestClient) WaitForUpdate(tb testing.TB, timeout time.Duration) *netmap.NetworkMap {
	tb.Helper()
	timer := time.NewTimer(timeout)
	defer timer.Stop()
	select {
	case nm := <-c.updates:
		return nm
	case <-timer.C:
		tb.Fatalf("servertest: WaitForUpdate(%s): timeout after %v", c.Name, timeout)
		return nil
	}
}
// Peers returns the current peer list from the latest netmap, or nil
// if no netmap has been received. The slice is the netmap's own, not
// a copy.
func (c *TestClient) Peers() []tailcfg.NodeView {
	c.mu.RLock()
	defer c.mu.RUnlock()
	if c.netmap == nil {
		return nil
	}
	return c.netmap.Peers
}
// PeerByName finds a peer by hostname. Returns the peer and true
// if found, zero value and false otherwise. Peers without valid
// Hostinfo are skipped.
func (c *TestClient) PeerByName(hostname string) (tailcfg.NodeView, bool) {
	c.mu.RLock()
	defer c.mu.RUnlock()
	var zero tailcfg.NodeView
	if c.netmap == nil {
		return zero, false
	}
	for _, peer := range c.netmap.Peers {
		hi := peer.Hostinfo()
		if !hi.Valid() || hi.Hostname() != hostname {
			continue
		}
		return peer, true
	}
	return zero, false
}
// PeerNames returns the hostnames of all current peers with valid
// Hostinfo, or nil when no netmap has been received.
func (c *TestClient) PeerNames() []string {
	c.mu.RLock()
	defer c.mu.RUnlock()
	if c.netmap == nil {
		return nil
	}
	names := make([]string, 0, len(c.netmap.Peers))
	for _, peer := range c.netmap.Peers {
		if hi := peer.Hostinfo(); hi.Valid() {
			names = append(names, hi.Hostname())
		}
	}
	return names
}
// UpdateCount returns the total number of full netmap updates received
// since the client was created (including across reconnects).
func (c *TestClient) UpdateCount() int {
	c.mu.RLock()
	defer c.mu.RUnlock()
	return len(c.history)
}
// History returns a copy of all NetworkMap snapshots in arrival order.
// The slice is copied so callers cannot disturb the internal record;
// the NetworkMap pointers themselves are shared.
func (c *TestClient) History() []*netmap.NetworkMap {
	c.mu.RLock()
	defer c.mu.RUnlock()
	return append([]*netmap.NetworkMap(nil), c.history...)
}
// SelfName returns the self node's hostname from the latest netmap,
// or "" when there is no netmap or no valid self node.
func (c *TestClient) SelfName() string {
	nm := c.Netmap()
	if nm != nil && nm.SelfNode.Valid() {
		return nm.SelfNode.Hostinfo().Hostname()
	}
	return ""
}
// String implements fmt.Stringer for debug output.
func (c *TestClient) String() string {
	if nm := c.Netmap(); nm != nil {
		return fmt.Sprintf("TestClient(%s, %d peers)", c.Name, len(nm.Peers))
	}
	return fmt.Sprintf("TestClient(%s, no netmap)", c.Name)
}

View File

@@ -0,0 +1,157 @@
package servertest
import (
"fmt"
"testing"
"time"
"github.com/juanfont/headscale/hscontrol/types"
)
// TestHarness orchestrates a TestServer with multiple TestClients,
// providing a convenient setup for multi-node control plane tests.
type TestHarness struct {
	// Server is the in-process Headscale server all clients talk to.
	Server  *TestServer
	// clients holds every client created via NewHarness or AddClient.
	clients []*TestClient
	// defaultUser is shared by all clients unless overridden with
	// WithUser.
	defaultUser *types.User
}
// HarnessOption configures a TestHarness at construction time.
type HarnessOption func(*harnessConfig)

// harnessConfig accumulates harness options before NewHarness applies
// them.
type harnessConfig struct {
	serverOpts     []ServerOption // forwarded to NewServer
	clientOpts     []ClientOption // applied to every client
	convergenceMax time.Duration  // how long to wait for the mesh to form
}
// defaultHarnessConfig returns the baseline harness settings:
// a 30-second convergence budget and no extra options.
func defaultHarnessConfig() *harnessConfig {
	var hc harnessConfig
	hc.convergenceMax = 30 * time.Second
	return &hc
}
// WithServerOptions passes ServerOptions through to the underlying
// TestServer.
func WithServerOptions(opts ...ServerOption) HarnessOption {
	return func(conf *harnessConfig) {
		conf.serverOpts = append(conf.serverOpts, opts...)
	}
}
// WithDefaultClientOptions applies ClientOptions to every client
// created by NewHarness.
func WithDefaultClientOptions(opts ...ClientOption) HarnessOption {
	return func(conf *harnessConfig) {
		conf.clientOpts = append(conf.clientOpts, opts...)
	}
}
// WithConvergenceTimeout sets how long WaitForMeshComplete waits.
func WithConvergenceTimeout(d time.Duration) HarnessOption {
	return func(conf *harnessConfig) {
		conf.convergenceMax = d
	}
}
// NewHarness creates a TestServer and numClients connected clients.
// All clients share a default user and are registered with reusable
// pre-auth keys. The harness waits for all clients to form a
// complete mesh (or, for a single client, its first netmap) before
// returning.
func NewHarness(tb testing.TB, numClients int, opts ...HarnessOption) *TestHarness {
	tb.Helper()
	hc := defaultHarnessConfig()
	for _, o := range opts {
		o(hc)
	}
	server := NewServer(tb, hc.serverOpts...)
	// Create a shared default user.
	user := server.CreateUser(tb, "harness-default")
	h := &TestHarness{
		Server:      server,
		defaultUser: user,
	}
	// Create and connect clients. WithUser comes first so a caller's
	// WithUser in clientOpts can still override it.
	for i := range numClients {
		name := clientName(i)
		copts := append([]ClientOption{WithUser(user)}, hc.clientOpts...)
		c := NewClient(tb, server, name, copts...)
		h.clients = append(h.clients, c)
	}
	// Wait for the mesh to converge.
	if numClients > 1 {
		h.WaitForMeshComplete(tb, hc.convergenceMax)
	} else if numClients == 1 {
		// Single node: just wait for the first netmap.
		h.clients[0].WaitForUpdate(tb, hc.convergenceMax)
	}
	return h
}
// Client returns the i-th client (0-indexed). Panics on an
// out-of-range index, which in a test surfaces as a failure.
func (h *TestHarness) Client(i int) *TestClient {
	return h.clients[i]
}
// Clients returns all clients in creation order.
//
// The slice is copied at the API boundary so callers that append to or
// reorder the result cannot corrupt the harness's internal list (the
// original returned the internal slice directly).
func (h *TestHarness) Clients() []*TestClient {
	out := make([]*TestClient, len(h.clients))
	copy(out, h.clients)
	return out
}
// ConnectedClients returns clients that currently have an active
// long-poll session, detected by whether the pollDone channel is
// still open.
func (h *TestHarness) ConnectedClients() []*TestClient {
	var connected []*TestClient
	for _, client := range h.clients {
		select {
		case <-client.pollDone:
			// Channel closed: the poll loop has exited, so skip.
		default:
			connected = append(connected, client)
		}
	}
	return connected
}
// AddClient creates and connects a new client to the existing mesh.
// The harness default user is applied first, so a WithUser in opts
// still overrides it.
func (h *TestHarness) AddClient(tb testing.TB, opts ...ClientOption) *TestClient {
	tb.Helper()
	options := make([]ClientOption, 0, len(opts)+1)
	options = append(options, WithUser(h.defaultUser))
	options = append(options, opts...)
	client := NewClient(tb, h.Server, clientName(len(h.clients)), options...)
	h.clients = append(h.clients, client)
	return client
}
// WaitForMeshComplete blocks until every connected client sees
// (connectedCount - 1) peers, failing the test on timeout.
func (h *TestHarness) WaitForMeshComplete(tb testing.TB, timeout time.Duration) {
	tb.Helper()
	connected := h.ConnectedClients()
	want := len(connected) - 1
	if want < 0 {
		want = 0
	}
	for _, client := range connected {
		client.WaitForPeers(tb, want, timeout)
	}
}
// WaitForConvergence waits until all connected clients see the full
// mesh. It currently just delegates to WaitForMeshComplete; it does
// not separately verify that peer counts have stopped changing.
func (h *TestHarness) WaitForConvergence(tb testing.TB, timeout time.Duration) {
	tb.Helper()
	h.WaitForMeshComplete(tb, timeout)
}
// clientName builds the deterministic hostname ("node-<index>") used
// for harness-created clients.
func clientName(index int) string {
	name := fmt.Sprintf("node-%d", index)
	return name
}

View File

@@ -0,0 +1,182 @@
// Package servertest provides an in-process test harness for Headscale's
// control plane. It wires a real Headscale server to real Tailscale
// controlclient.Direct instances, enabling fast, deterministic tests
// of the full control protocol without Docker or separate processes.
package servertest
import (
"net/http/httptest"
"testing"
"time"
hscontrol "github.com/juanfont/headscale/hscontrol"
"github.com/juanfont/headscale/hscontrol/state"
"github.com/juanfont/headscale/hscontrol/types"
"tailscale.com/tailcfg"
)
// TestServer is an in-process Headscale control server suitable for
// use with Tailscale's controlclient.Direct.
type TestServer struct {
	// App is the underlying Headscale instance.
	App *hscontrol.Headscale
	// HTTPServer serves App's handler over a local listener.
	HTTPServer *httptest.Server
	// URL is the base URL clients should dial (HTTPServer.URL).
	URL string
	// st caches App's state manager for the helper methods below.
	st *state.State
}
// ServerOption configures a TestServer at construction time.
type ServerOption func(*serverConfig)

// serverConfig accumulates server options before NewServer applies
// them.
type serverConfig struct {
	batchDelay       time.Duration // batcher change-coalescing delay
	bufferedChanSize int           // per-node map session channel buffer
	ephemeralTimeout time.Duration // ephemeral node inactivity timeout
	batcherWorkers   int           // number of batcher worker goroutines
}
// defaultServerConfig returns the baseline server settings: a short
// 50ms batch delay, a single batcher worker, and a 30s ephemeral
// timeout.
func defaultServerConfig() *serverConfig {
	return &serverConfig{
		batcherWorkers:   1,
		batchDelay:       50 * time.Millisecond,
		ephemeralTimeout: 30 * time.Second,
	}
}
// WithBatchDelay sets the batcher's change coalescing delay.
func WithBatchDelay(d time.Duration) ServerOption {
	return func(conf *serverConfig) {
		conf.batchDelay = d
	}
}
// WithBufferedChanSize sets the per-node map session channel buffer.
func WithBufferedChanSize(n int) ServerOption {
	return func(conf *serverConfig) {
		conf.bufferedChanSize = n
	}
}
// WithEphemeralTimeout sets the ephemeral node inactivity timeout.
func WithEphemeralTimeout(d time.Duration) ServerOption {
	return func(conf *serverConfig) {
		conf.ephemeralTimeout = d
	}
}
// NewServer creates and starts a Headscale test server backed by a
// temporary SQLite database. The server is fully functional and
// accepts real Tailscale control protocol connections over Noise.
//
// Startup order matters: the batcher and ephemeral GC are started
// before the HTTP listener, and the server URL is written back into
// the config only after httptest assigns a port.
func NewServer(tb testing.TB, opts ...ServerOption) *TestServer {
	tb.Helper()
	sc := defaultServerConfig()
	for _, o := range opts {
		o(sc)
	}
	tmpDir := tb.TempDir()
	cfg := types.Config{
		// Placeholder; updated below once httptest server starts.
		ServerURL:                      "http://localhost:0",
		NoisePrivateKeyPath:            tmpDir + "/noise_private.key",
		EphemeralNodeInactivityTimeout: sc.ephemeralTimeout,
		Database: types.DatabaseConfig{
			Type: "sqlite3",
			Sqlite: types.SqliteConfig{
				Path: tmpDir + "/headscale_test.db",
			},
		},
		Policy: types.PolicyConfig{
			Mode: types.PolicyModeDB,
		},
		Tuning: types.Tuning{
			BatchChangeDelay:               sc.batchDelay,
			BatcherWorkers:                 sc.batcherWorkers,
			NodeMapSessionBufferedChanSize: sc.bufferedChanSize,
		},
	}
	app, err := hscontrol.NewHeadscale(&cfg)
	if err != nil {
		tb.Fatalf("servertest: NewHeadscale: %v", err)
	}
	// Set a minimal DERP map so MapResponse generation works.
	app.GetState().SetDERPMap(&tailcfg.DERPMap{
		Regions: map[int]*tailcfg.DERPRegion{
			900: {
				RegionID:   900,
				RegionCode: "test",
				RegionName: "Test Region",
				Nodes: []*tailcfg.DERPNode{{
					Name:     "test0",
					RegionID: 900,
					HostName: "127.0.0.1",
					IPv4:     "127.0.0.1",
					DERPPort: -1, // not a real DERP, just needed for MapResponse
				}},
			},
		},
	})
	// Start subsystems (cleanup is registered on tb by these helpers).
	app.StartBatcherForTest(tb)
	app.StartEphemeralGCForTest(tb)
	// Start the HTTP server with Headscale's full handler (including
	// /key and /ts2021 Noise upgrade).
	ts := httptest.NewServer(app.HTTPHandler())
	tb.Cleanup(ts.Close)
	// Now update the config to point at the real URL so that
	// MapResponse.ControlURL etc. are correct.
	app.SetServerURLForTest(tb, ts.URL)
	return &TestServer{
		App:        app,
		HTTPServer: ts,
		URL:        ts.URL,
		st:         app.GetState(),
	}
}
// State returns the server's state manager for creating users,
// nodes, and pre-auth keys directly, bypassing the control protocol.
func (s *TestServer) State() *state.State {
	return s.st
}
// CreateUser creates a test user with the given name and returns it,
// failing the test on error.
func (s *TestServer) CreateUser(tb testing.TB, name string) *types.User {
	tb.Helper()
	user, _, err := s.st.CreateUser(types.User{Name: name})
	if err != nil {
		tb.Fatalf("servertest: CreateUser(%q): %v", name, err)
	}
	return user
}
// CreatePreAuthKey creates a reusable (non-ephemeral) pre-auth key for
// the given user and returns the key string, failing the test on error.
func (s *TestServer) CreatePreAuthKey(tb testing.TB, userID types.UserID) string {
	tb.Helper()
	// userID is already a local copy of the caller's value, so its
	// address can be passed directly.
	pak, err := s.st.CreatePreAuthKey(&userID, true, false, nil, nil)
	if err != nil {
		tb.Fatalf("servertest: CreatePreAuthKey: %v", err)
	}
	return pak.Key
}
// CreateEphemeralPreAuthKey creates a single-use ephemeral pre-auth
// key for the given user and returns the key string, failing the test
// on error.
func (s *TestServer) CreateEphemeralPreAuthKey(tb testing.TB, userID types.UserID) string {
	tb.Helper()
	// userID is already a local copy of the caller's value, so its
	// address can be passed directly.
	pak, err := s.st.CreatePreAuthKey(&userID, false, true, nil, nil)
	if err != nil {
		tb.Fatalf("servertest: CreateEphemeralPreAuthKey: %v", err)
	}
	return pak.Key
}