From 99767cf805d13fd6a6a687ca966365fe0e98b747 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Thu, 9 Apr 2026 17:41:01 +0000 Subject: [PATCH] hscontrol: validate machine key and bind src/dst in SSH check handler SSHActionHandler now verifies that the Noise session's machine key matches the dst node before proceeding. The (src, dst) pair is captured at hold-and-delegate time via a new SSHCheckBinding on AuthRequest so sshActionFollowUp can verify the follow-up URL matches. The OIDC non-registration callback requires the authenticated user to own the src node before approving. --- hscontrol/noise.go | 92 +++++++++++++++++++++- hscontrol/noise_test.go | 138 +++++++++++++++++++++++++++++++++ hscontrol/oidc.go | 60 +++++++++++++- hscontrol/state/state.go | 9 +++ hscontrol/types/common.go | 69 ++++++++++++++--- hscontrol/types/common_test.go | 53 +++++++++++++ 6 files changed, 404 insertions(+), 17 deletions(-) diff --git a/hscontrol/noise.go b/hscontrol/noise.go index 51020f9f..366cc9f1 100644 --- a/hscontrol/noise.go +++ b/hscontrol/noise.go @@ -36,6 +36,28 @@ var ErrUnsupportedURLParameterType = errors.New("unsupported URL parameter type" // ErrNoAuthSession is returned when an auth_id does not match any active auth session. var ErrNoAuthSession = errors.New("no auth session found") +// ErrSSHDstNodeNotFound is returned when the dst node id on a Noise SSH +// action request does not match any registered node. +var ErrSSHDstNodeNotFound = errors.New("ssh action: unknown dst node id") + +// ErrSSHMachineKeyMismatch is returned when the Noise session's machine +// key does not match the dst node referenced in the SSH action URL. +var ErrSSHMachineKeyMismatch = errors.New( + "ssh action: noise session machine key does not match dst node", +) + +// ErrSSHAuthSessionNotBound is returned when an SSH action follow-up +// references an auth session that is not bound to an SSH check pair. +var ErrSSHAuthSessionNotBound = errors.New( + "ssh action: cached auth session is not an SSH-check binding", +) + +// ErrSSHBindingMismatch is returned when an SSH action follow-up's +// (src, dst) pair does not match the cached binding for its auth_id. +var ErrSSHBindingMismatch = errors.New( + "ssh action: cached binding does not match request src/dst", +) + const ( // ts2021UpgradePath is the path that the server listens on for the WebSockets upgrade. ts2021UpgradePath = "/ts2021" @@ -337,6 +359,37 @@ func (ns *noiseServer) SSHActionHandler( return } + // Authenticate the Noise session: the destination node is the + // tailscaled instance asking us whether to permit an incoming SSH + // connection, so its Noise session must belong to dst. Without this + // check any unauthenticated client could open a Noise tunnel with a + // throwaway machine key and pollute lastSSHAuth for arbitrary + // (src, dst) pairs, defeating SSH check-mode's stolen-key + // protections. + dstNode, ok := ns.headscale.state.GetNodeByID(dstNodeID) + if !ok { + httpError(writer, NewHTTPError( + http.StatusNotFound, + "dst node not found", + fmt.Errorf("%w: %d", ErrSSHDstNodeNotFound, dstNodeID), + )) + + return + } + + if dstNode.MachineKey() != ns.machineKey { + httpError(writer, NewHTTPError( + http.StatusUnauthorized, + "machine key does not match dst node", + fmt.Errorf( + "%w: machine key %s, dst node %d", + ErrSSHMachineKeyMismatch, ns.machineKey.ShortString(), dstNodeID, + ), + )) + + return + } + reqLog := log.With(). Uint64("src_node_id", srcNodeID.Uint64()). Uint64("dst_node_id", dstNodeID.Uint64()). @@ -426,14 +479,16 @@ func (ns *noiseServer) sshAction( } // No auto-approval — create an auth session and hold. - return ns.sshActionHoldAndDelegate(reqLog, &action) + return ns.sshActionHoldAndDelegate(reqLog, &action, srcNodeID, dstNodeID) } -// sshActionHoldAndDelegate creates a new auth session and returns a -// HoldAndDelegate action that directs the client to authenticate. +// sshActionHoldAndDelegate creates a new auth session bound to the +// (src, dst) pair and returns a HoldAndDelegate action that directs the +// client to authenticate. func (ns *noiseServer) sshActionHoldAndDelegate( reqLog zerolog.Logger, action *tailcfg.SSHAction, + srcNodeID, dstNodeID types.NodeID, ) (*tailcfg.SSHAction, error) { holdURL, err := url.Parse( ns.headscale.cfg.ServerURL + @@ -457,7 +512,10 @@ func (ns *noiseServer) sshActionHoldAndDelegate( ) } - ns.headscale.state.SetAuthCacheEntry(authID, types.NewAuthRequest()) + ns.headscale.state.SetAuthCacheEntry( + authID, + types.NewSSHCheckAuthRequest(srcNodeID, dstNodeID), + ) authURL := ns.headscale.authProvider.AuthURL(authID) @@ -512,6 +570,32 @@ func (ns *noiseServer) sshActionFollowUp( ) } + // Verify the cached binding matches the (src, dst) pair the + // follow-up URL claims. Without this check an attacker who knew an + // auth_id could submit a follow-up for any other (src, dst) pair + // and have its verdict recorded against that pair instead. + if !auth.IsSSHCheck() { + return nil, NewHTTPError( + http.StatusBadRequest, + "auth session is not for SSH check", + fmt.Errorf("%w: %s", ErrSSHAuthSessionNotBound, authID), + ) + } + + binding := auth.SSHCheckBinding() + if binding.SrcNodeID != srcNodeID || binding.DstNodeID != dstNodeID { + return nil, NewHTTPError( + http.StatusUnauthorized, + "src/dst pair does not match auth session", + fmt.Errorf( + "%w: cached %d->%d, request %d->%d", + ErrSSHBindingMismatch, + binding.SrcNodeID, binding.DstNodeID, + srcNodeID, dstNodeID, + ), + ) + } + reqLog.Trace().Caller().Msg("SSH action follow-up") verdict := <-auth.WaitForAuth() diff --git a/hscontrol/noise_test.go b/hscontrol/noise_test.go index 594521f5..320869ca 100644 --- a/hscontrol/noise_test.go +++ b/hscontrol/noise_test.go @@ -4,15 +4,19 @@ import ( "bytes" "context" "encoding/json" + "fmt" "io" "net/http" "net/http/httptest" + "strconv" "testing" "github.com/go-chi/chi/v5" + "github.com/juanfont/headscale/hscontrol/types" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "tailscale.com/tailcfg" + "tailscale.com/types/key" ) // newNoiseRouterWithBodyLimit builds a chi router with the same body-limit @@ -193,3 +197,137 @@ func TestRegistrationHandler_OversizedBody(t *testing.T) { // for version 0 → returns 400. assert.Equal(t, http.StatusBadRequest, rec.Code) } + +// newSSHActionRequest builds an httptest request with the chi URL params +// SSHActionHandler reads (src_node_id and dst_node_id), so the handler +// can be exercised directly without going through the chi router. +func newSSHActionRequest(t *testing.T, src, dst types.NodeID) *http.Request { + t.Helper() + + url := fmt.Sprintf("/machine/ssh/action/from/%d/to/%d", src.Uint64(), dst.Uint64()) + req := httptest.NewRequestWithContext(context.Background(), http.MethodGet, url, nil) + + rctx := chi.NewRouteContext() + rctx.URLParams.Add("src_node_id", strconv.FormatUint(src.Uint64(), 10)) + rctx.URLParams.Add("dst_node_id", strconv.FormatUint(dst.Uint64(), 10)) + req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) + + return req +} + +// putTestNodeInStore creates a node via the database test helper and +// also stages it into the in-memory NodeStore so handlers that read +// NodeStore-backed APIs (e.g. State.GetNodeByID) can see it. +func putTestNodeInStore(t *testing.T, app *Headscale, user *types.User, hostname string) *types.Node { + t.Helper() + + node := app.state.CreateNodeForTest(user, hostname) + app.state.PutNodeInStoreForTest(*node) + + return node +} + +// TestSSHActionHandler_RejectsRogueMachineKey verifies that the SSH +// check action endpoint rejects a Noise session whose machine key does +// not match the dst node. +func TestSSHActionHandler_RejectsRogueMachineKey(t *testing.T) { + t.Parallel() + + app := createTestApp(t) + user := app.state.CreateUserForTest("ssh-handler-user") + + src := putTestNodeInStore(t, app, user, "src-node") + dst := putTestNodeInStore(t, app, user, "dst-node") + + // noiseServer carries the wrong machine key — a fresh throwaway key, + // not dst.MachineKey. + rogue := key.NewMachine().Public() + require.NotEqual(t, dst.MachineKey, rogue, "test sanity: rogue key must differ from dst") + + ns := &noiseServer{ + headscale: app, + machineKey: rogue, + } + + rec := httptest.NewRecorder() + ns.SSHActionHandler(rec, newSSHActionRequest(t, src.ID, dst.ID)) + + assert.Equal(t, http.StatusUnauthorized, rec.Code, + "rogue machine key must be rejected with 401") + + // And the auth cache must not have been mutated by the rejected request. + if last, ok := app.state.GetLastSSHAuth(src.ID, dst.ID); ok { + t.Fatalf("rejected SSH action must not record lastSSHAuth, got %v", last) + } +} + +// TestSSHActionHandler_RejectsUnknownDst verifies that the handler +// rejects a request for a dst_node_id that does not exist with 404. +func TestSSHActionHandler_RejectsUnknownDst(t *testing.T) { + t.Parallel() + + app := createTestApp(t) + user := app.state.CreateUserForTest("ssh-handler-unknown-user") + src := putTestNodeInStore(t, app, user, "src-node") + + ns := &noiseServer{ + headscale: app, + machineKey: key.NewMachine().Public(), + } + + rec := httptest.NewRecorder() + ns.SSHActionHandler(rec, newSSHActionRequest(t, src.ID, 9999)) + + assert.Equal(t, http.StatusNotFound, rec.Code, + "unknown dst node id must be rejected with 404") +} + +// TestSSHActionFollowUp_RejectsBindingMismatch verifies that the +// follow-up handler refuses to honour an auth_id whose cached binding +// does not match the (src, dst) pair on the request URL. Without this +// check an attacker holding any auth_id could route its verdict to a +// different node pair. +func TestSSHActionFollowUp_RejectsBindingMismatch(t *testing.T) { + t.Parallel() + + app := createTestApp(t) + user := app.state.CreateUserForTest("ssh-binding-user") + + srcCached := putTestNodeInStore(t, app, user, "src-cached") + dstCached := putTestNodeInStore(t, app, user, "dst-cached") + srcOther := putTestNodeInStore(t, app, user, "src-other") + dstOther := putTestNodeInStore(t, app, user, "dst-other") + + // Mint an SSH-check auth request bound to (srcCached, dstCached). + authID := types.MustAuthID() + app.state.SetAuthCacheEntry( + authID, + types.NewSSHCheckAuthRequest(srcCached.ID, dstCached.ID), + ) + + // Build a follow-up that claims to be for (srcOther, dstOther) but + // reuses the bound auth_id. The Noise machineKey matches dstOther so + // the outer machine-key check passes — only the binding check + // should reject it. + ns := &noiseServer{ + headscale: app, + machineKey: dstOther.MachineKey, + } + + url := fmt.Sprintf( + "/machine/ssh/action/from/%d/to/%d?auth_id=%s", + srcOther.ID.Uint64(), dstOther.ID.Uint64(), authID.String(), + ) + req := httptest.NewRequestWithContext(context.Background(), http.MethodGet, url, nil) + + rctx := chi.NewRouteContext() + rctx.URLParams.Add("src_node_id", strconv.FormatUint(srcOther.ID.Uint64(), 10)) + rctx.URLParams.Add("dst_node_id", strconv.FormatUint(dstOther.ID.Uint64(), 10)) + req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) + + rec := httptest.NewRecorder() + ns.SSHActionHandler(rec, req) + + assert.Equal(t, http.StatusUnauthorized, rec.Code, + "binding mismatch must be rejected with 401") +} diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index 010995bc..cf5014fe 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -365,8 +365,12 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( return } - // If this is not a registration callback, then its a regular authentication callback - // and we need to send a response and confirm that the access was allowed. + // If this is not a registration callback, then it is an SSH + // check-mode auth callback. Confirm the OIDC identity is the owner + // of the SSH source node before recording approval; without this + // check any tailnet user could approve a check-mode prompt for any + // other user's node, defeating the stolen-key protection that + // check-mode is meant to provide. authReq, ok := a.h.state.GetAuthCacheEntry(authInfo.AuthID) if !ok { @@ -376,7 +380,57 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( return } - // Send a finish auth verdict with no errors to let the CLI know that the authentication was successful. + if !authReq.IsSSHCheck() { + log.Warn().Caller(). + Str("auth_id", authInfo.AuthID.String()). + Msg("OIDC callback hit non-registration path with auth request that is not an SSH check binding") + httpError(writer, NewHTTPError(http.StatusBadRequest, "auth session is not for SSH check", nil)) + + return + } + + binding := authReq.SSHCheckBinding() + + srcNode, ok := a.h.state.GetNodeByID(binding.SrcNodeID) + if !ok { + log.Warn().Caller(). + Str("auth_id", authInfo.AuthID.String()). + Uint64("src_node_id", binding.SrcNodeID.Uint64()). + Msg("SSH check src node no longer exists") + httpError(writer, NewHTTPError(http.StatusGone, "src node no longer exists", nil)) + + return + } + + // Strict identity binding: only the user that owns the src node + // may approve an SSH check for that node. Tagged source nodes are + // rejected because they have no user owner to compare against. + if srcNode.IsTagged() || !srcNode.UserID().Valid() { + log.Warn().Caller(). + Str("auth_id", authInfo.AuthID.String()). + Uint64("src_node_id", binding.SrcNodeID.Uint64()). + Bool("src_is_tagged", srcNode.IsTagged()). + Str("oidc_user", user.Username()). + Msg("SSH check rejected: src node has no user owner") + httpError(writer, NewHTTPError(http.StatusForbidden, "src node has no user owner", nil)) + + return + } + + if srcNode.UserID().Get() != user.ID { + log.Warn().Caller(). + Str("auth_id", authInfo.AuthID.String()). + Uint64("src_node_id", binding.SrcNodeID.Uint64()). + Uint("src_owner_id", srcNode.UserID().Get()). + Uint("oidc_user_id", user.ID). + Str("oidc_user", user.Username()). + Msg("SSH check rejected: OIDC user is not the owner of src node") + httpError(writer, NewHTTPError(http.StatusForbidden, "OIDC user is not the owner of the SSH source node", nil)) + + return + } + + // Identity verified — record the verdict for the waiting follow-up. authReq.FinishAuth(types.AuthVerdict{}) content := renderAuthSuccessTemplate(user) diff --git a/hscontrol/state/state.go b/hscontrol/state/state.go index d96223f5..39d7e36a 100644 --- a/hscontrol/state/state.go +++ b/hscontrol/state/state.go @@ -1214,6 +1214,15 @@ func (s *State) CreateNodeForTest(user *types.User, hostname ...string) *types.N return s.db.CreateNodeForTest(user, hostname...) } +// PutNodeInStoreForTest writes a test node into the in-memory NodeStore +// so handlers backed by NodeStore lookups (e.g. GetNodeByID) can see it. +// CreateNodeForTest only saves to the database, which is fine for tests +// that exercise the DB layer directly but insufficient for handler tests +// that go through State. +func (s *State) PutNodeInStoreForTest(node types.Node) types.NodeView { + return s.nodeStore.PutNode(node) +} + // CreateRegisteredNodeForTest creates a test node with allocated IPs. This is a convenience wrapper around the database layer. func (s *State) CreateRegisteredNodeForTest(user *types.User, hostname ...string) *types.Node { return s.db.CreateRegisteredNodeForTest(user, hostname...) diff --git a/hscontrol/types/common.go b/hscontrol/types/common.go index 1cc1b08a..0e17c5a8 100644 --- a/hscontrol/types/common.go +++ b/hscontrol/types/common.go @@ -221,26 +221,42 @@ func (r AuthID) Validate() error { return nil } +// SSHCheckBinding identifies the (source, destination) node pair an SSH +// check-mode auth request is bound to. It is captured at HoldAndDelegate +// time so the follow-up request and OIDC callback can verify that no +// other (src, dst) pair has been substituted via tampered URL parameters. +type SSHCheckBinding struct { + SrcNodeID NodeID + DstNodeID NodeID +} + // AuthRequest represents a pending authentication request from a user or a // node. It carries the minimum data needed to either complete a node -// registration (regData populated) or signal the verdict of an interactive -// auth flow (no payload). Verdict delivery is via the finished channel; the -// closed flag guards FinishAuth against double-close. +// registration (regData populated) or an SSH check-mode auth (sshBinding +// populated), and signals the verdict via the finished channel. The closed +// flag guards FinishAuth against double-close. // // AuthRequest is always handled by pointer so the channel and atomic flag // have a single canonical instance even when stored in caches that // internally copy values. type AuthRequest struct { // regData is populated for node-registration flows (interactive web - // or OIDC). It carries only the minimal subset of registration data - // the auth callback needs to promote this request into a real node; - // see RegistrationData for the rationale behind keeping the payload - // small. + // or OIDC). It carries the cached registration payload that the + // auth callback uses to promote this request into a real node. // - // nil for non-registration flows (e.g. SSH check). Use - // RegistrationData() to read it safely. + // nil for non-registration flows. Use RegistrationData() to read it + // safely. regData *RegistrationData + // sshBinding is populated for SSH check-mode flows. It captures the + // (src, dst) node pair the request was minted for so the follow-up + // and OIDC callback can refuse to record a verdict for any other + // pair. + // + // nil for non-SSH-check flows. Use SSHCheckBinding() to read it + // safely. + sshBinding *SSHCheckBinding + finished chan AuthVerdict closed *atomic.Bool } @@ -265,9 +281,24 @@ func NewRegisterAuthRequest(data *RegistrationData) *AuthRequest { } } +// NewSSHCheckAuthRequest creates a pending auth request bound to a +// specific (src, dst) SSH check-mode pair. The follow-up handler and +// OIDC callback must verify their incoming request matches this binding +// before recording any verdict. +func NewSSHCheckAuthRequest(src, dst NodeID) *AuthRequest { + return &AuthRequest{ + sshBinding: &SSHCheckBinding{ + SrcNodeID: src, + DstNodeID: dst, + }, + finished: make(chan AuthVerdict, 1), + closed: &atomic.Bool{}, + } +} + // RegistrationData returns the cached registration payload. It panics if // called on an AuthRequest that was not created via -// NewRegisterAuthRequest, mirroring the previous Node() contract. +// NewRegisterAuthRequest. func (rn *AuthRequest) RegistrationData() *RegistrationData { if rn.regData == nil { panic("RegistrationData can only be used in registration requests") @@ -276,12 +307,30 @@ func (rn *AuthRequest) RegistrationData() *RegistrationData { return rn.regData } +// SSHCheckBinding returns the (src, dst) node pair an SSH check-mode +// auth request is bound to. It panics if called on an AuthRequest that +// was not created via NewSSHCheckAuthRequest. +func (rn *AuthRequest) SSHCheckBinding() *SSHCheckBinding { + if rn.sshBinding == nil { + panic("SSHCheckBinding can only be used in SSH check-mode requests") + } + + return rn.sshBinding +} + // IsRegistration reports whether this auth request carries registration // data (i.e. it was created via NewRegisterAuthRequest). func (rn *AuthRequest) IsRegistration() bool { return rn.regData != nil } +// IsSSHCheck reports whether this auth request is bound to an SSH +// check-mode (src, dst) pair (i.e. it was created via +// NewSSHCheckAuthRequest). +func (rn *AuthRequest) IsSSHCheck() bool { + return rn.sshBinding != nil +} + func (rn *AuthRequest) FinishAuth(verdict AuthVerdict) { if rn.closed.Swap(true) { return diff --git a/hscontrol/types/common_test.go b/hscontrol/types/common_test.go index a443918b..9ccb0145 100644 --- a/hscontrol/types/common_test.go +++ b/hscontrol/types/common_test.go @@ -2,8 +2,61 @@ package types import ( "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) +// TestNewSSHCheckAuthRequestBinding verifies that an SSH-check AuthRequest +// captures the (src, dst) node pair at construction time and rejects +// callers that try to read RegistrationData from it. +func TestNewSSHCheckAuthRequestBinding(t *testing.T) { + const src, dst NodeID = 7, 11 + + req := NewSSHCheckAuthRequest(src, dst) + + require.True(t, req.IsSSHCheck(), "SSH-check request must report IsSSHCheck=true") + require.False(t, req.IsRegistration(), "SSH-check request must not report IsRegistration") + + binding := req.SSHCheckBinding() + assert.Equal(t, src, binding.SrcNodeID, "SrcNodeID must match") + assert.Equal(t, dst, binding.DstNodeID, "DstNodeID must match") + + assert.Panics(t, func() { + _ = req.RegistrationData() + }, "RegistrationData() must panic on an SSH-check AuthRequest") +} + +// TestNewRegisterAuthRequestPayload verifies that a registration +// AuthRequest carries the supplied RegistrationData and rejects callers +// that try to read SSH-check binding from it. +func TestNewRegisterAuthRequestPayload(t *testing.T) { + data := &RegistrationData{Hostname: "node-a"} + + req := NewRegisterAuthRequest(data) + + require.True(t, req.IsRegistration(), "registration request must report IsRegistration=true") + require.False(t, req.IsSSHCheck(), "registration request must not report IsSSHCheck") + assert.Same(t, data, req.RegistrationData(), "RegistrationData() must return the supplied pointer") + + assert.Panics(t, func() { + _ = req.SSHCheckBinding() + }, "SSHCheckBinding() must panic on a registration AuthRequest") +} + +// TestNewAuthRequestEmptyPayload verifies that a payload-less +// AuthRequest reports both Is* helpers as false and panics on either +// payload accessor. +func TestNewAuthRequestEmptyPayload(t *testing.T) { + req := NewAuthRequest() + + assert.False(t, req.IsRegistration()) + assert.False(t, req.IsSSHCheck()) + + assert.Panics(t, func() { _ = req.RegistrationData() }) + assert.Panics(t, func() { _ = req.SSHCheckBinding() }) +} + func TestDefaultBatcherWorkersFor(t *testing.T) { tests := []struct { cpuCount int