From 0288614bdfd3c7acd264060f65924880ca8cfa47 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Mon, 16 Mar 2026 09:16:43 +0000 Subject: [PATCH] hscontrol: add servertest harness for in-process control plane testing Add a new hscontrol/servertest package that provides a test harness for exercising the full Headscale control protocol in-process, using Tailscale's controlclient.Direct as the client. The harness consists of: - TestServer: wraps a Headscale instance with an httptest.Server - TestClient: wraps controlclient.Direct with NetworkMap tracking - TestHarness: orchestrates N clients against a single server - Assertion helpers for mesh completeness, visibility, and consistency Export minimal accessor methods on Headscale (HTTPHandler, NoisePublicKey, GetState, SetServerURL, StartBatcher, StartEphemeralGC) so the servertest package can construct a working server from outside the hscontrol package. This enables fast, deterministic tests of connection lifecycle, update propagation, and network weather scenarios without Docker. --- hscontrol/app.go | 51 ++++ hscontrol/servertest/assertions.go | 219 +++++++++++++++ hscontrol/servertest/client.go | 430 +++++++++++++++++++++++++++++ hscontrol/servertest/harness.go | 157 +++++++++++ hscontrol/servertest/server.go | 182 ++++++++++++ 5 files changed, 1039 insertions(+) create mode 100644 hscontrol/servertest/assertions.go create mode 100644 hscontrol/servertest/client.go create mode 100644 hscontrol/servertest/harness.go create mode 100644 hscontrol/servertest/server.go diff --git a/hscontrol/app.go b/hscontrol/app.go index c57c7be0..ab017c47 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -16,6 +16,7 @@ import ( "strings" "sync" "syscall" + "testing" "time" "github.com/cenkalti/backoff/v5" @@ -1069,6 +1070,56 @@ func (h *Headscale) Change(cs ...change.Change) { h.mapBatcher.AddWork(cs...) } +// HTTPHandler returns an http.Handler for the Headscale control server. +// The handler serves the Tailscale control protocol including the /key +// endpoint and /ts2021 Noise upgrade path. +func (h *Headscale) HTTPHandler() http.Handler { + return h.createRouter(grpcRuntime.NewServeMux()) +} + +// NoisePublicKey returns the server's Noise protocol public key. +func (h *Headscale) NoisePublicKey() key.MachinePublic { + return h.noisePrivateKey.Public() +} + +// GetState returns the server's state manager for programmatic access +// to users, nodes, policies, and other server state. +func (h *Headscale) GetState() *state.State { + return h.state +} + +// SetServerURLForTest updates the server URL in the configuration. +// This is needed for test servers where the URL is not known until +// the HTTP test server starts. +// It panics when called outside of tests. +func (h *Headscale) SetServerURLForTest(tb testing.TB, url string) { + tb.Helper() + + h.cfg.ServerURL = url +} + +// StartBatcherForTest initialises and starts the map response batcher. +// It registers a cleanup function on tb to stop the batcher. +// It panics when called outside of tests. +func (h *Headscale) StartBatcherForTest(tb testing.TB) { + tb.Helper() + + h.mapBatcher = mapper.NewBatcherAndMapper(h.cfg, h.state) + h.mapBatcher.Start() + tb.Cleanup(func() { h.mapBatcher.Close() }) +} + +// StartEphemeralGCForTest starts the ephemeral node garbage collector. +// It registers a cleanup function on tb to stop the collector. +// It panics when called outside of tests. +func (h *Headscale) StartEphemeralGCForTest(tb testing.TB) { + tb.Helper() + + go h.ephemeralGC.Start() + + tb.Cleanup(func() { h.ephemeralGC.Close() }) +} + // Provide some middleware that can inspect the ACME/autocert https calls // and log when things are failing. type acmeLogger struct { diff --git a/hscontrol/servertest/assertions.go b/hscontrol/servertest/assertions.go new file mode 100644 index 00000000..7933fe43 --- /dev/null +++ b/hscontrol/servertest/assertions.go @@ -0,0 +1,219 @@ +package servertest + +import ( + "net/netip" + "testing" + "time" +) + +// AssertMeshComplete verifies that every client in the slice sees +// exactly (len(clients) - 1) peers, i.e. a fully connected mesh. +func AssertMeshComplete(tb testing.TB, clients []*TestClient) { + tb.Helper() + + expected := len(clients) - 1 + for _, c := range clients { + nm := c.Netmap() + if nm == nil { + tb.Errorf("AssertMeshComplete: %s has no netmap", c.Name) + + continue + } + + if got := len(nm.Peers); got != expected { + tb.Errorf("AssertMeshComplete: %s has %d peers, want %d (peers: %v)", + c.Name, got, expected, c.PeerNames()) + } + } +} + +// AssertSymmetricVisibility checks that peer visibility is symmetric: +// if client A sees client B, then client B must also see client A. +func AssertSymmetricVisibility(tb testing.TB, clients []*TestClient) { + tb.Helper() + + for _, a := range clients { + for _, b := range clients { + if a == b { + continue + } + + _, aSeesB := a.PeerByName(b.Name) + + _, bSeesA := b.PeerByName(a.Name) + if aSeesB != bSeesA { + tb.Errorf("AssertSymmetricVisibility: %s sees %s = %v, but %s sees %s = %v", + a.Name, b.Name, aSeesB, b.Name, a.Name, bSeesA) + } + } + } +} + +// AssertPeerOnline checks that the observer sees peerName as online. +func AssertPeerOnline(tb testing.TB, observer *TestClient, peerName string) { + tb.Helper() + + peer, ok := observer.PeerByName(peerName) + if !ok { + tb.Errorf("AssertPeerOnline: %s does not see peer %s", observer.Name, peerName) + + return + } + + isOnline, known := peer.Online().GetOk() + if !known || !isOnline { + tb.Errorf("AssertPeerOnline: %s sees peer %s but Online=%v (known=%v), want true", + observer.Name, peerName, isOnline, known) + } +} + +// AssertPeerOffline checks that the observer sees peerName as offline. +func AssertPeerOffline(tb testing.TB, observer *TestClient, peerName string) { + tb.Helper() + + peer, ok := observer.PeerByName(peerName) + if !ok { + // Peer gone entirely counts as "offline" for this assertion. + return + } + + isOnline, known := peer.Online().GetOk() + if known && isOnline { + tb.Errorf("AssertPeerOffline: %s sees peer %s as online, want offline", + observer.Name, peerName) + } +} + +// AssertPeerGone checks that the observer does NOT have peerName in +// its peer list at all. +func AssertPeerGone(tb testing.TB, observer *TestClient, peerName string) { + tb.Helper() + + _, ok := observer.PeerByName(peerName) + if ok { + tb.Errorf("AssertPeerGone: %s still sees peer %s", observer.Name, peerName) + } +} + +// AssertPeerHasAllowedIPs checks that a peer has the expected +// AllowedIPs prefixes. +func AssertPeerHasAllowedIPs(tb testing.TB, observer *TestClient, peerName string, want []netip.Prefix) { + tb.Helper() + + peer, ok := observer.PeerByName(peerName) + if !ok { + tb.Errorf("AssertPeerHasAllowedIPs: %s does not see peer %s", observer.Name, peerName) + + return + } + + got := make([]netip.Prefix, 0, peer.AllowedIPs().Len()) + for i := range peer.AllowedIPs().Len() { + got = append(got, peer.AllowedIPs().At(i)) + } + + if len(got) != len(want) { + tb.Errorf("AssertPeerHasAllowedIPs: %s sees %s with AllowedIPs %v, want %v", + observer.Name, peerName, got, want) + + return + } + + // Build a set for comparison. + wantSet := make(map[netip.Prefix]bool, len(want)) + for _, p := range want { + wantSet[p] = true + } + + for _, p := range got { + if !wantSet[p] { + tb.Errorf("AssertPeerHasAllowedIPs: %s sees %s with unexpected AllowedIP %v (want %v)", + observer.Name, peerName, p, want) + } + } +} + +// AssertConsistentState checks that all clients agree on peer +// properties: every connected client should see the same set of +// peer hostnames. +func AssertConsistentState(tb testing.TB, clients []*TestClient) { + tb.Helper() + + for _, c := range clients { + nm := c.Netmap() + if nm == nil { + continue + } + + peerNames := make(map[string]bool, len(nm.Peers)) + for _, p := range nm.Peers { + hi := p.Hostinfo() + if hi.Valid() { + peerNames[hi.Hostname()] = true + } + } + + // Check that c sees all other connected clients. + for _, other := range clients { + if other == c || other.Netmap() == nil { + continue + } + + if !peerNames[other.Name] { + tb.Errorf("AssertConsistentState: %s does not see %s (peers: %v)", + c.Name, other.Name, c.PeerNames()) + } + } + } +} + +// EventuallyAssertMeshComplete retries AssertMeshComplete up to +// timeout, useful when waiting for state to propagate. +func EventuallyAssertMeshComplete(tb testing.TB, clients []*TestClient, timeout time.Duration) { + tb.Helper() + + expected := len(clients) - 1 + deadline := time.After(timeout) + + for { + allGood := true + + for _, c := range clients { + nm := c.Netmap() + if nm == nil || len(nm.Peers) < expected { + allGood = false + + break + } + } + + if allGood { + // Final strict check. + AssertMeshComplete(tb, clients) + + return + } + + select { + case <-deadline: + // Report the failure with details. + for _, c := range clients { + nm := c.Netmap() + + got := 0 + if nm != nil { + got = len(nm.Peers) + } + + if got != expected { + tb.Errorf("EventuallyAssertMeshComplete: %s has %d peers, want %d (timeout %v)", + c.Name, got, expected, timeout) + } + } + + return + case <-time.After(100 * time.Millisecond): + // Poll again. + } + } +} diff --git a/hscontrol/servertest/client.go b/hscontrol/servertest/client.go new file mode 100644 index 00000000..9ae17a3d --- /dev/null +++ b/hscontrol/servertest/client.go @@ -0,0 +1,430 @@ +package servertest + +import ( + "context" + "fmt" + "sync" + "testing" + "time" + + "github.com/juanfont/headscale/hscontrol/types" + "tailscale.com/control/controlclient" + "tailscale.com/health" + "tailscale.com/net/netmon" + "tailscale.com/net/tsdial" + "tailscale.com/tailcfg" + "tailscale.com/types/key" + "tailscale.com/types/netmap" + "tailscale.com/types/persist" + "tailscale.com/util/eventbus" +) + +// TestClient wraps a Tailscale controlclient.Direct connected to a +// TestServer. It tracks all received NetworkMap updates, providing +// helpers to wait for convergence and inspect the client's view of +// the network. +type TestClient struct { + // Name is a human-readable identifier for this client. + Name string + + server *TestServer + direct *controlclient.Direct + authKey string + user *types.User + + // Connection lifecycle. + pollCtx context.Context //nolint:containedctx // test-only; context stored for cancel control + pollCancel context.CancelFunc + pollDone chan struct{} + + // Accumulated state from MapResponse callbacks. + mu sync.RWMutex + netmap *netmap.NetworkMap + history []*netmap.NetworkMap + + // updates is a buffered channel that receives a signal + // each time a new NetworkMap arrives. + updates chan *netmap.NetworkMap + + bus *eventbus.Bus + dialer *tsdial.Dialer + tracker *health.Tracker +} + +// ClientOption configures a TestClient. +type ClientOption func(*clientConfig) + +type clientConfig struct { + ephemeral bool + hostname string + tags []string + user *types.User +} + +// WithEphemeral makes the client register as an ephemeral node. +func WithEphemeral() ClientOption { + return func(c *clientConfig) { c.ephemeral = true } +} + +// WithHostname sets the client's hostname in Hostinfo. +func WithHostname(name string) ClientOption { + return func(c *clientConfig) { c.hostname = name } +} + +// WithTags sets ACL tags on the pre-auth key. +func WithTags(tags ...string) ClientOption { + return func(c *clientConfig) { c.tags = tags } +} + +// WithUser sets the user for the client. If not set, the harness +// creates a default user. +func WithUser(user *types.User) ClientOption { + return func(c *clientConfig) { c.user = user } +} + +// NewClient creates a TestClient, registers it with the TestServer +// using a pre-auth key, and starts long-polling for map updates. +func NewClient(tb testing.TB, server *TestServer, name string, opts ...ClientOption) *TestClient { + tb.Helper() + + cc := &clientConfig{ + hostname: name, + } + for _, o := range opts { + o(cc) + } + + // Resolve user. + user := cc.user + if user == nil { + // Create a per-client user if none specified. + user = server.CreateUser(tb, "user-"+name) + } + + // Create pre-auth key. + uid := types.UserID(user.ID) + + var authKey string + if cc.ephemeral { + authKey = server.CreateEphemeralPreAuthKey(tb, uid) + } else { + authKey = server.CreatePreAuthKey(tb, uid) + } + + // Set up Tailscale client infrastructure. + bus := eventbus.New() + tracker := health.NewTracker(bus) + dialer := tsdial.NewDialer(netmon.NewStatic()) + dialer.SetBus(bus) + + machineKey := key.NewMachine() + + direct, err := controlclient.NewDirect(controlclient.Options{ + Persist: persist.Persist{}, + GetMachinePrivateKey: func() (key.MachinePrivate, error) { return machineKey, nil }, + ServerURL: server.URL, + AuthKey: authKey, + Hostinfo: &tailcfg.Hostinfo{ + BackendLogID: "servertest-" + name, + Hostname: cc.hostname, + }, + DiscoPublicKey: key.NewDisco().Public(), + Logf: tb.Logf, + HealthTracker: tracker, + Dialer: dialer, + Bus: bus, + }) + if err != nil { + tb.Fatalf("servertest: NewDirect(%s): %v", name, err) + } + + tc := &TestClient{ + Name: name, + server: server, + direct: direct, + authKey: authKey, + user: user, + updates: make(chan *netmap.NetworkMap, 64), + bus: bus, + dialer: dialer, + tracker: tracker, + } + + tb.Cleanup(func() { + tc.cleanup() + }) + + // Register with the server. + tc.register(tb) + + // Start long-polling in the background. + tc.startPoll(tb) + + return tc +} + +// register performs the initial TryLogin to register the client. +func (c *TestClient) register(tb testing.TB) { + tb.Helper() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + url, err := c.direct.TryLogin(ctx, controlclient.LoginDefault) + if err != nil { + tb.Fatalf("servertest: TryLogin(%s): %v", c.Name, err) + } + + if url != "" { + tb.Fatalf("servertest: TryLogin(%s): unexpected auth URL: %s (expected auto-auth with preauth key)", c.Name, url) + } +} + +// startPoll begins the long-poll MapRequest loop. +func (c *TestClient) startPoll(tb testing.TB) { + tb.Helper() + + c.pollCtx, c.pollCancel = context.WithCancel(context.Background()) + c.pollDone = make(chan struct{}) + + go func() { + defer close(c.pollDone) + // PollNetMap blocks until ctx is cancelled or the server closes + // the connection. + _ = c.direct.PollNetMap(c.pollCtx, c) + }() +} + +// UpdateFullNetmap implements controlclient.NetmapUpdater. +// Called by controlclient.Direct when a new NetworkMap is received. +func (c *TestClient) UpdateFullNetmap(nm *netmap.NetworkMap) { + c.mu.Lock() + c.netmap = nm + c.history = append(c.history, nm) + c.mu.Unlock() + + // Non-blocking send to the updates channel. + select { + case c.updates <- nm: + default: + } +} + +// cleanup releases all resources. +func (c *TestClient) cleanup() { + if c.pollCancel != nil { + c.pollCancel() + } + + if c.pollDone != nil { + // Wait for PollNetMap to exit, but don't hang. + select { + case <-c.pollDone: + case <-time.After(5 * time.Second): + } + } + + if c.direct != nil { + c.direct.Close() + } + + if c.dialer != nil { + c.dialer.Close() + } + + if c.bus != nil { + c.bus.Close() + } +} + +// --- Lifecycle methods --- + +// Disconnect cancels the long-poll context, simulating a clean +// client disconnect. +func (c *TestClient) Disconnect(tb testing.TB) { + tb.Helper() + + if c.pollCancel != nil { + c.pollCancel() + <-c.pollDone + } +} + +// Reconnect registers and starts a new long-poll session. +// Call Disconnect first, or this will disconnect automatically. +func (c *TestClient) Reconnect(tb testing.TB) { + tb.Helper() + + // Cancel any existing poll. + if c.pollCancel != nil { + c.pollCancel() + + select { + case <-c.pollDone: + case <-time.After(5 * time.Second): + tb.Fatalf("servertest: Reconnect(%s): old poll did not exit", c.Name) + } + } + + // Re-register and start polling again. + c.register(tb) + + c.startPoll(tb) +} + +// ReconnectAfter disconnects, waits for d, then reconnects. +// The timer works correctly with testing/synctest for +// time-controlled tests. +func (c *TestClient) ReconnectAfter(tb testing.TB, d time.Duration) { + tb.Helper() + c.Disconnect(tb) + + timer := time.NewTimer(d) + defer timer.Stop() + + <-timer.C + c.Reconnect(tb) +} + +// --- State accessors --- + +// Netmap returns the latest NetworkMap, or nil if none received yet. +func (c *TestClient) Netmap() *netmap.NetworkMap { + c.mu.RLock() + defer c.mu.RUnlock() + + return c.netmap +} + +// WaitForPeers blocks until the client sees at least n peers, +// or until timeout expires. +func (c *TestClient) WaitForPeers(tb testing.TB, n int, timeout time.Duration) { + tb.Helper() + + deadline := time.After(timeout) + + for { + if nm := c.Netmap(); nm != nil && len(nm.Peers) >= n { + return + } + + select { + case <-c.updates: + // Check again. + case <-deadline: + nm := c.Netmap() + + got := 0 + if nm != nil { + got = len(nm.Peers) + } + + tb.Fatalf("servertest: WaitForPeers(%s, %d): timeout after %v (got %d peers)", c.Name, n, timeout, got) + } + } +} + +// WaitForUpdate blocks until the next netmap update arrives or timeout. +func (c *TestClient) WaitForUpdate(tb testing.TB, timeout time.Duration) *netmap.NetworkMap { + tb.Helper() + + select { + case nm := <-c.updates: + return nm + case <-time.After(timeout): + tb.Fatalf("servertest: WaitForUpdate(%s): timeout after %v", c.Name, timeout) + + return nil + } +} + +// Peers returns the current peer list, or nil. +func (c *TestClient) Peers() []tailcfg.NodeView { + c.mu.RLock() + defer c.mu.RUnlock() + + if c.netmap == nil { + return nil + } + + return c.netmap.Peers +} + +// PeerByName finds a peer by hostname. Returns the peer and true +// if found, zero value and false otherwise. +func (c *TestClient) PeerByName(hostname string) (tailcfg.NodeView, bool) { + c.mu.RLock() + defer c.mu.RUnlock() + + if c.netmap == nil { + return tailcfg.NodeView{}, false + } + + for _, p := range c.netmap.Peers { + hi := p.Hostinfo() + if hi.Valid() && hi.Hostname() == hostname { + return p, true + } + } + + return tailcfg.NodeView{}, false +} + +// PeerNames returns the hostnames of all current peers. +func (c *TestClient) PeerNames() []string { + c.mu.RLock() + defer c.mu.RUnlock() + + if c.netmap == nil { + return nil + } + + names := make([]string, 0, len(c.netmap.Peers)) + for _, p := range c.netmap.Peers { + hi := p.Hostinfo() + if hi.Valid() { + names = append(names, hi.Hostname()) + } + } + + return names +} + +// UpdateCount returns the total number of full netmap updates received. +func (c *TestClient) UpdateCount() int { + c.mu.RLock() + defer c.mu.RUnlock() + + return len(c.history) +} + +// History returns a copy of all NetworkMap snapshots in order. +func (c *TestClient) History() []*netmap.NetworkMap { + c.mu.RLock() + defer c.mu.RUnlock() + + out := make([]*netmap.NetworkMap, len(c.history)) + copy(out, c.history) + + return out +} + +// SelfName returns the self node's hostname from the latest netmap. +func (c *TestClient) SelfName() string { + nm := c.Netmap() + if nm == nil || !nm.SelfNode.Valid() { + return "" + } + + return nm.SelfNode.Hostinfo().Hostname() +} + +// String implements fmt.Stringer for debug output. +func (c *TestClient) String() string { + nm := c.Netmap() + if nm == nil { + return fmt.Sprintf("TestClient(%s, no netmap)", c.Name) + } + + return fmt.Sprintf("TestClient(%s, %d peers)", c.Name, len(nm.Peers)) +} diff --git a/hscontrol/servertest/harness.go b/hscontrol/servertest/harness.go new file mode 100644 index 00000000..312d8290 --- /dev/null +++ b/hscontrol/servertest/harness.go @@ -0,0 +1,157 @@ +package servertest + +import ( + "fmt" + "testing" + "time" + + "github.com/juanfont/headscale/hscontrol/types" +) + +// TestHarness orchestrates a TestServer with multiple TestClients, +// providing a convenient setup for multi-node control plane tests. +type TestHarness struct { + Server *TestServer + clients []*TestClient + + // Default user shared by all clients unless overridden. + defaultUser *types.User +} + +// HarnessOption configures a TestHarness. +type HarnessOption func(*harnessConfig) + +type harnessConfig struct { + serverOpts []ServerOption + clientOpts []ClientOption + convergenceMax time.Duration +} + +func defaultHarnessConfig() *harnessConfig { + return &harnessConfig{ + convergenceMax: 30 * time.Second, + } +} + +// WithServerOptions passes ServerOptions through to the underlying +// TestServer. +func WithServerOptions(opts ...ServerOption) HarnessOption { + return func(c *harnessConfig) { c.serverOpts = append(c.serverOpts, opts...) } +} + +// WithDefaultClientOptions applies ClientOptions to every client +// created by NewHarness. +func WithDefaultClientOptions(opts ...ClientOption) HarnessOption { + return func(c *harnessConfig) { c.clientOpts = append(c.clientOpts, opts...) } +} + +// WithConvergenceTimeout sets how long WaitForMeshComplete waits. +func WithConvergenceTimeout(d time.Duration) HarnessOption { + return func(c *harnessConfig) { c.convergenceMax = d } +} + +// NewHarness creates a TestServer and numClients connected clients. +// All clients share a default user and are registered with reusable +// pre-auth keys. The harness waits for all clients to form a +// complete mesh before returning. +func NewHarness(tb testing.TB, numClients int, opts ...HarnessOption) *TestHarness { + tb.Helper() + + hc := defaultHarnessConfig() + for _, o := range opts { + o(hc) + } + + server := NewServer(tb, hc.serverOpts...) + + // Create a shared default user. + user := server.CreateUser(tb, "harness-default") + + h := &TestHarness{ + Server: server, + defaultUser: user, + } + + // Create and connect clients. + for i := range numClients { + name := clientName(i) + + copts := append([]ClientOption{WithUser(user)}, hc.clientOpts...) + c := NewClient(tb, server, name, copts...) + h.clients = append(h.clients, c) + } + + // Wait for the mesh to converge. + if numClients > 1 { + h.WaitForMeshComplete(tb, hc.convergenceMax) + } else if numClients == 1 { + // Single node: just wait for the first netmap. + h.clients[0].WaitForUpdate(tb, hc.convergenceMax) + } + + return h +} + +// Client returns the i-th client (0-indexed). +func (h *TestHarness) Client(i int) *TestClient { + return h.clients[i] +} + +// Clients returns all clients. +func (h *TestHarness) Clients() []*TestClient { + return h.clients +} + +// ConnectedClients returns clients that currently have an active +// long-poll session (pollDone channel is still open). +func (h *TestHarness) ConnectedClients() []*TestClient { + var out []*TestClient + + for _, c := range h.clients { + select { + case <-c.pollDone: + // Poll has ended, client is disconnected. + default: + out = append(out, c) + } + } + + return out +} + +// AddClient creates and connects a new client to the existing mesh. +func (h *TestHarness) AddClient(tb testing.TB, opts ...ClientOption) *TestClient { + tb.Helper() + + name := clientName(len(h.clients)) + copts := append([]ClientOption{WithUser(h.defaultUser)}, opts...) + c := NewClient(tb, h.Server, name, copts...) + h.clients = append(h.clients, c) + + return c +} + +// WaitForMeshComplete blocks until every connected client sees +// (connectedCount - 1) peers. +func (h *TestHarness) WaitForMeshComplete(tb testing.TB, timeout time.Duration) { + tb.Helper() + + connected := h.ConnectedClients() + + expectedPeers := max(len(connected)-1, 0) + + for _, c := range connected { + c.WaitForPeers(tb, expectedPeers, timeout) + } +} + +// WaitForConvergence waits until all connected clients have a +// non-nil NetworkMap and their peer counts have stabilised. +func (h *TestHarness) WaitForConvergence(tb testing.TB, timeout time.Duration) { + tb.Helper() + h.WaitForMeshComplete(tb, timeout) +} + +func clientName(index int) string { + return fmt.Sprintf("node-%d", index) +} diff --git a/hscontrol/servertest/server.go b/hscontrol/servertest/server.go new file mode 100644 index 00000000..d9c0b85a --- /dev/null +++ b/hscontrol/servertest/server.go @@ -0,0 +1,182 @@ +// Package servertest provides an in-process test harness for Headscale's +// control plane. It wires a real Headscale server to real Tailscale +// controlclient.Direct instances, enabling fast, deterministic tests +// of the full control protocol without Docker or separate processes. +package servertest + +import ( + "net/http/httptest" + "testing" + "time" + + hscontrol "github.com/juanfont/headscale/hscontrol" + "github.com/juanfont/headscale/hscontrol/state" + "github.com/juanfont/headscale/hscontrol/types" + "tailscale.com/tailcfg" +) + +// TestServer is an in-process Headscale control server suitable for +// use with Tailscale's controlclient.Direct. +type TestServer struct { + App *hscontrol.Headscale + HTTPServer *httptest.Server + URL string + st *state.State +} + +// ServerOption configures a TestServer. +type ServerOption func(*serverConfig) + +type serverConfig struct { + batchDelay time.Duration + bufferedChanSize int + ephemeralTimeout time.Duration + batcherWorkers int +} + +func defaultServerConfig() *serverConfig { + return &serverConfig{ + batchDelay: 50 * time.Millisecond, + batcherWorkers: 1, + ephemeralTimeout: 30 * time.Second, + } +} + +// WithBatchDelay sets the batcher's change coalescing delay. +func WithBatchDelay(d time.Duration) ServerOption { + return func(c *serverConfig) { c.batchDelay = d } +} + +// WithBufferedChanSize sets the per-node map session channel buffer. +func WithBufferedChanSize(n int) ServerOption { + return func(c *serverConfig) { c.bufferedChanSize = n } +} + +// WithEphemeralTimeout sets the ephemeral node inactivity timeout. +func WithEphemeralTimeout(d time.Duration) ServerOption { + return func(c *serverConfig) { c.ephemeralTimeout = d } +} + +// NewServer creates and starts a Headscale test server. +// The server is fully functional and accepts real Tailscale control +// protocol connections over Noise. +func NewServer(tb testing.TB, opts ...ServerOption) *TestServer { + tb.Helper() + + sc := defaultServerConfig() + for _, o := range opts { + o(sc) + } + + tmpDir := tb.TempDir() + + cfg := types.Config{ + // Placeholder; updated below once httptest server starts. + ServerURL: "http://localhost:0", + NoisePrivateKeyPath: tmpDir + "/noise_private.key", + EphemeralNodeInactivityTimeout: sc.ephemeralTimeout, + Database: types.DatabaseConfig{ + Type: "sqlite3", + Sqlite: types.SqliteConfig{ + Path: tmpDir + "/headscale_test.db", + }, + }, + Policy: types.PolicyConfig{ + Mode: types.PolicyModeDB, + }, + Tuning: types.Tuning{ + BatchChangeDelay: sc.batchDelay, + BatcherWorkers: sc.batcherWorkers, + NodeMapSessionBufferedChanSize: sc.bufferedChanSize, + }, + } + + app, err := hscontrol.NewHeadscale(&cfg) + if err != nil { + tb.Fatalf("servertest: NewHeadscale: %v", err) + } + + // Set a minimal DERP map so MapResponse generation works. + app.GetState().SetDERPMap(&tailcfg.DERPMap{ + Regions: map[int]*tailcfg.DERPRegion{ + 900: { + RegionID: 900, + RegionCode: "test", + RegionName: "Test Region", + Nodes: []*tailcfg.DERPNode{{ + Name: "test0", + RegionID: 900, + HostName: "127.0.0.1", + IPv4: "127.0.0.1", + DERPPort: -1, // not a real DERP, just needed for MapResponse + }}, + }, + }, + }) + + // Start subsystems. + app.StartBatcherForTest(tb) + app.StartEphemeralGCForTest(tb) + + // Start the HTTP server with Headscale's full handler (including + // /key and /ts2021 Noise upgrade). + ts := httptest.NewServer(app.HTTPHandler()) + tb.Cleanup(ts.Close) + + // Now update the config to point at the real URL so that + // MapResponse.ControlURL etc. are correct. + app.SetServerURLForTest(tb, ts.URL) + + return &TestServer{ + App: app, + HTTPServer: ts, + URL: ts.URL, + st: app.GetState(), + } +} + +// State returns the server's state manager for creating users, +// nodes, and pre-auth keys. +func (s *TestServer) State() *state.State { + return s.st +} + +// CreateUser creates a test user and returns it. +func (s *TestServer) CreateUser(tb testing.TB, name string) *types.User { + tb.Helper() + + u, _, err := s.st.CreateUser(types.User{Name: name}) + if err != nil { + tb.Fatalf("servertest: CreateUser(%q): %v", name, err) + } + + return u +} + +// CreatePreAuthKey creates a reusable pre-auth key for the given user. +func (s *TestServer) CreatePreAuthKey(tb testing.TB, userID types.UserID) string { + tb.Helper() + + uid := userID + + pak, err := s.st.CreatePreAuthKey(&uid, true, false, nil, nil) + if err != nil { + tb.Fatalf("servertest: CreatePreAuthKey: %v", err) + } + + return pak.Key +} + +// CreateEphemeralPreAuthKey creates an ephemeral pre-auth key. +func (s *TestServer) CreateEphemeralPreAuthKey(tb testing.TB, userID types.UserID) string { + tb.Helper() + + uid := userID + + pak, err := s.st.CreatePreAuthKey(&uid, false, true, nil, nil) + if err != nil { + tb.Fatalf("servertest: CreateEphemeralPreAuthKey: %v", err) + } + + return pak.Key +}