mirror of
https://github.com/juanfont/headscale.git
synced 2026-04-10 19:17:25 +02:00
hscontrol: validate machine key and bind src/dst in SSH check handler
SSHActionHandler now verifies that the Noise session's machine key matches the dst node before proceeding. The (src, dst) pair is captured at hold-and-delegate time via a new SSHCheckBinding on AuthRequest so sshActionFollowUp can verify the follow-up URL matches. The OIDC non-registration callback requires the authenticated user to own the src node before approving.
This commit is contained in:
@@ -36,6 +36,28 @@ var ErrUnsupportedURLParameterType = errors.New("unsupported URL parameter type"
|
||||
// ErrNoAuthSession is returned when an auth_id does not match any active auth session.
|
||||
var ErrNoAuthSession = errors.New("no auth session found")
|
||||
|
||||
// ErrSSHDstNodeNotFound is returned when the dst node id on a Noise SSH
|
||||
// action request does not match any registered node.
|
||||
var ErrSSHDstNodeNotFound = errors.New("ssh action: unknown dst node id")
|
||||
|
||||
// ErrSSHMachineKeyMismatch is returned when the Noise session's machine
|
||||
// key does not match the dst node referenced in the SSH action URL.
|
||||
var ErrSSHMachineKeyMismatch = errors.New(
|
||||
"ssh action: noise session machine key does not match dst node",
|
||||
)
|
||||
|
||||
// ErrSSHAuthSessionNotBound is returned when an SSH action follow-up
|
||||
// references an auth session that is not bound to an SSH check pair.
|
||||
var ErrSSHAuthSessionNotBound = errors.New(
|
||||
"ssh action: cached auth session is not an SSH-check binding",
|
||||
)
|
||||
|
||||
// ErrSSHBindingMismatch is returned when an SSH action follow-up's
|
||||
// (src, dst) pair does not match the cached binding for its auth_id.
|
||||
var ErrSSHBindingMismatch = errors.New(
|
||||
"ssh action: cached binding does not match request src/dst",
|
||||
)
|
||||
|
||||
const (
|
||||
// ts2021UpgradePath is the path that the server listens on for the WebSockets upgrade.
|
||||
ts2021UpgradePath = "/ts2021"
|
||||
@@ -337,6 +359,37 @@ func (ns *noiseServer) SSHActionHandler(
|
||||
return
|
||||
}
|
||||
|
||||
// Authenticate the Noise session: the destination node is the
|
||||
// tailscaled instance asking us whether to permit an incoming SSH
|
||||
// connection, so its Noise session must belong to dst. Without this
|
||||
// check any unauthenticated client could open a Noise tunnel with a
|
||||
// throwaway machine key and pollute lastSSHAuth for arbitrary
|
||||
// (src, dst) pairs, defeating SSH check-mode's stolen-key
|
||||
// protections.
|
||||
dstNode, ok := ns.headscale.state.GetNodeByID(dstNodeID)
|
||||
if !ok {
|
||||
httpError(writer, NewHTTPError(
|
||||
http.StatusNotFound,
|
||||
"dst node not found",
|
||||
fmt.Errorf("%w: %d", ErrSSHDstNodeNotFound, dstNodeID),
|
||||
))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if dstNode.MachineKey() != ns.machineKey {
|
||||
httpError(writer, NewHTTPError(
|
||||
http.StatusUnauthorized,
|
||||
"machine key does not match dst node",
|
||||
fmt.Errorf(
|
||||
"%w: machine key %s, dst node %d",
|
||||
ErrSSHMachineKeyMismatch, ns.machineKey.ShortString(), dstNodeID,
|
||||
),
|
||||
))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
reqLog := log.With().
|
||||
Uint64("src_node_id", srcNodeID.Uint64()).
|
||||
Uint64("dst_node_id", dstNodeID.Uint64()).
|
||||
@@ -426,14 +479,16 @@ func (ns *noiseServer) sshAction(
|
||||
}
|
||||
|
||||
// No auto-approval — create an auth session and hold.
|
||||
return ns.sshActionHoldAndDelegate(reqLog, &action)
|
||||
return ns.sshActionHoldAndDelegate(reqLog, &action, srcNodeID, dstNodeID)
|
||||
}
|
||||
|
||||
// sshActionHoldAndDelegate creates a new auth session and returns a
|
||||
// HoldAndDelegate action that directs the client to authenticate.
|
||||
// sshActionHoldAndDelegate creates a new auth session bound to the
|
||||
// (src, dst) pair and returns a HoldAndDelegate action that directs the
|
||||
// client to authenticate.
|
||||
func (ns *noiseServer) sshActionHoldAndDelegate(
|
||||
reqLog zerolog.Logger,
|
||||
action *tailcfg.SSHAction,
|
||||
srcNodeID, dstNodeID types.NodeID,
|
||||
) (*tailcfg.SSHAction, error) {
|
||||
holdURL, err := url.Parse(
|
||||
ns.headscale.cfg.ServerURL +
|
||||
@@ -457,7 +512,10 @@ func (ns *noiseServer) sshActionHoldAndDelegate(
|
||||
)
|
||||
}
|
||||
|
||||
ns.headscale.state.SetAuthCacheEntry(authID, types.NewAuthRequest())
|
||||
ns.headscale.state.SetAuthCacheEntry(
|
||||
authID,
|
||||
types.NewSSHCheckAuthRequest(srcNodeID, dstNodeID),
|
||||
)
|
||||
|
||||
authURL := ns.headscale.authProvider.AuthURL(authID)
|
||||
|
||||
@@ -512,6 +570,32 @@ func (ns *noiseServer) sshActionFollowUp(
|
||||
)
|
||||
}
|
||||
|
||||
// Verify the cached binding matches the (src, dst) pair the
|
||||
// follow-up URL claims. Without this check an attacker who knew an
|
||||
// auth_id could submit a follow-up for any other (src, dst) pair
|
||||
// and have its verdict recorded against that pair instead.
|
||||
if !auth.IsSSHCheck() {
|
||||
return nil, NewHTTPError(
|
||||
http.StatusBadRequest,
|
||||
"auth session is not for SSH check",
|
||||
fmt.Errorf("%w: %s", ErrSSHAuthSessionNotBound, authID),
|
||||
)
|
||||
}
|
||||
|
||||
binding := auth.SSHCheckBinding()
|
||||
if binding.SrcNodeID != srcNodeID || binding.DstNodeID != dstNodeID {
|
||||
return nil, NewHTTPError(
|
||||
http.StatusUnauthorized,
|
||||
"src/dst pair does not match auth session",
|
||||
fmt.Errorf(
|
||||
"%w: cached %d->%d, request %d->%d",
|
||||
ErrSSHBindingMismatch,
|
||||
binding.SrcNodeID, binding.DstNodeID,
|
||||
srcNodeID, dstNodeID,
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
reqLog.Trace().Caller().Msg("SSH action follow-up")
|
||||
|
||||
verdict := <-auth.WaitForAuth()
|
||||
|
||||
@@ -4,15 +4,19 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
)
|
||||
|
||||
// newNoiseRouterWithBodyLimit builds a chi router with the same body-limit
|
||||
@@ -193,3 +197,137 @@ func TestRegistrationHandler_OversizedBody(t *testing.T) {
|
||||
// for version 0 → returns 400.
|
||||
assert.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
|
||||
// newSSHActionRequest builds an httptest request with the chi URL params
|
||||
// SSHActionHandler reads (src_node_id and dst_node_id), so the handler
|
||||
// can be exercised directly without going through the chi router.
|
||||
func newSSHActionRequest(t *testing.T, src, dst types.NodeID) *http.Request {
|
||||
t.Helper()
|
||||
|
||||
url := fmt.Sprintf("/machine/ssh/action/from/%d/to/%d", src.Uint64(), dst.Uint64())
|
||||
req := httptest.NewRequestWithContext(context.Background(), http.MethodGet, url, nil)
|
||||
|
||||
rctx := chi.NewRouteContext()
|
||||
rctx.URLParams.Add("src_node_id", strconv.FormatUint(src.Uint64(), 10))
|
||||
rctx.URLParams.Add("dst_node_id", strconv.FormatUint(dst.Uint64(), 10))
|
||||
req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx))
|
||||
|
||||
return req
|
||||
}
|
||||
|
||||
// putTestNodeInStore creates a node via the database test helper and
|
||||
// also stages it into the in-memory NodeStore so handlers that read
|
||||
// NodeStore-backed APIs (e.g. State.GetNodeByID) can see it.
|
||||
func putTestNodeInStore(t *testing.T, app *Headscale, user *types.User, hostname string) *types.Node {
|
||||
t.Helper()
|
||||
|
||||
node := app.state.CreateNodeForTest(user, hostname)
|
||||
app.state.PutNodeInStoreForTest(*node)
|
||||
|
||||
return node
|
||||
}
|
||||
|
||||
// TestSSHActionHandler_RejectsRogueMachineKey verifies that the SSH
|
||||
// check action endpoint rejects a Noise session whose machine key does
|
||||
// not match the dst node.
|
||||
func TestSSHActionHandler_RejectsRogueMachineKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app := createTestApp(t)
|
||||
user := app.state.CreateUserForTest("ssh-handler-user")
|
||||
|
||||
src := putTestNodeInStore(t, app, user, "src-node")
|
||||
dst := putTestNodeInStore(t, app, user, "dst-node")
|
||||
|
||||
// noiseServer carries the wrong machine key — a fresh throwaway key,
|
||||
// not dst.MachineKey.
|
||||
rogue := key.NewMachine().Public()
|
||||
require.NotEqual(t, dst.MachineKey, rogue, "test sanity: rogue key must differ from dst")
|
||||
|
||||
ns := &noiseServer{
|
||||
headscale: app,
|
||||
machineKey: rogue,
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
ns.SSHActionHandler(rec, newSSHActionRequest(t, src.ID, dst.ID))
|
||||
|
||||
assert.Equal(t, http.StatusUnauthorized, rec.Code,
|
||||
"rogue machine key must be rejected with 401")
|
||||
|
||||
// And the auth cache must not have been mutated by the rejected request.
|
||||
if last, ok := app.state.GetLastSSHAuth(src.ID, dst.ID); ok {
|
||||
t.Fatalf("rejected SSH action must not record lastSSHAuth, got %v", last)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSSHActionHandler_RejectsUnknownDst verifies that the handler
|
||||
// rejects a request for a dst_node_id that does not exist with 404.
|
||||
func TestSSHActionHandler_RejectsUnknownDst(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app := createTestApp(t)
|
||||
user := app.state.CreateUserForTest("ssh-handler-unknown-user")
|
||||
src := putTestNodeInStore(t, app, user, "src-node")
|
||||
|
||||
ns := &noiseServer{
|
||||
headscale: app,
|
||||
machineKey: key.NewMachine().Public(),
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
ns.SSHActionHandler(rec, newSSHActionRequest(t, src.ID, 9999))
|
||||
|
||||
assert.Equal(t, http.StatusNotFound, rec.Code,
|
||||
"unknown dst node id must be rejected with 404")
|
||||
}
|
||||
|
||||
// TestSSHActionFollowUp_RejectsBindingMismatch verifies that the
|
||||
// follow-up handler refuses to honour an auth_id whose cached binding
|
||||
// does not match the (src, dst) pair on the request URL. Without this
|
||||
// check an attacker holding any auth_id could route its verdict to a
|
||||
// different node pair.
|
||||
func TestSSHActionFollowUp_RejectsBindingMismatch(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app := createTestApp(t)
|
||||
user := app.state.CreateUserForTest("ssh-binding-user")
|
||||
|
||||
srcCached := putTestNodeInStore(t, app, user, "src-cached")
|
||||
dstCached := putTestNodeInStore(t, app, user, "dst-cached")
|
||||
srcOther := putTestNodeInStore(t, app, user, "src-other")
|
||||
dstOther := putTestNodeInStore(t, app, user, "dst-other")
|
||||
|
||||
// Mint an SSH-check auth request bound to (srcCached, dstCached).
|
||||
authID := types.MustAuthID()
|
||||
app.state.SetAuthCacheEntry(
|
||||
authID,
|
||||
types.NewSSHCheckAuthRequest(srcCached.ID, dstCached.ID),
|
||||
)
|
||||
|
||||
// Build a follow-up that claims to be for (srcOther, dstOther) but
|
||||
// reuses the bound auth_id. The Noise machineKey matches dstOther so
|
||||
// the outer machine-key check passes — only the binding check
|
||||
// should reject it.
|
||||
ns := &noiseServer{
|
||||
headscale: app,
|
||||
machineKey: dstOther.MachineKey,
|
||||
}
|
||||
|
||||
url := fmt.Sprintf(
|
||||
"/machine/ssh/action/from/%d/to/%d?auth_id=%s",
|
||||
srcOther.ID.Uint64(), dstOther.ID.Uint64(), authID.String(),
|
||||
)
|
||||
req := httptest.NewRequestWithContext(context.Background(), http.MethodGet, url, nil)
|
||||
|
||||
rctx := chi.NewRouteContext()
|
||||
rctx.URLParams.Add("src_node_id", strconv.FormatUint(srcOther.ID.Uint64(), 10))
|
||||
rctx.URLParams.Add("dst_node_id", strconv.FormatUint(dstOther.ID.Uint64(), 10))
|
||||
req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx))
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
ns.SSHActionHandler(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusUnauthorized, rec.Code,
|
||||
"binding mismatch must be rejected with 401")
|
||||
}
|
||||
|
||||
@@ -365,8 +365,12 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
||||
return
|
||||
}
|
||||
|
||||
// If this is not a registration callback, then its a regular authentication callback
|
||||
// and we need to send a response and confirm that the access was allowed.
|
||||
// If this is not a registration callback, then it is an SSH
|
||||
// check-mode auth callback. Confirm the OIDC identity is the owner
|
||||
// of the SSH source node before recording approval; without this
|
||||
// check any tailnet user could approve a check-mode prompt for any
|
||||
// other user's node, defeating the stolen-key protection that
|
||||
// check-mode is meant to provide.
|
||||
|
||||
authReq, ok := a.h.state.GetAuthCacheEntry(authInfo.AuthID)
|
||||
if !ok {
|
||||
@@ -376,7 +380,57 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
||||
return
|
||||
}
|
||||
|
||||
// Send a finish auth verdict with no errors to let the CLI know that the authentication was successful.
|
||||
if !authReq.IsSSHCheck() {
|
||||
log.Warn().Caller().
|
||||
Str("auth_id", authInfo.AuthID.String()).
|
||||
Msg("OIDC callback hit non-registration path with auth request that is not an SSH check binding")
|
||||
httpError(writer, NewHTTPError(http.StatusBadRequest, "auth session is not for SSH check", nil))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
binding := authReq.SSHCheckBinding()
|
||||
|
||||
srcNode, ok := a.h.state.GetNodeByID(binding.SrcNodeID)
|
||||
if !ok {
|
||||
log.Warn().Caller().
|
||||
Str("auth_id", authInfo.AuthID.String()).
|
||||
Uint64("src_node_id", binding.SrcNodeID.Uint64()).
|
||||
Msg("SSH check src node no longer exists")
|
||||
httpError(writer, NewHTTPError(http.StatusGone, "src node no longer exists", nil))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Strict identity binding: only the user that owns the src node
|
||||
// may approve an SSH check for that node. Tagged source nodes are
|
||||
// rejected because they have no user owner to compare against.
|
||||
if srcNode.IsTagged() || !srcNode.UserID().Valid() {
|
||||
log.Warn().Caller().
|
||||
Str("auth_id", authInfo.AuthID.String()).
|
||||
Uint64("src_node_id", binding.SrcNodeID.Uint64()).
|
||||
Bool("src_is_tagged", srcNode.IsTagged()).
|
||||
Str("oidc_user", user.Username()).
|
||||
Msg("SSH check rejected: src node has no user owner")
|
||||
httpError(writer, NewHTTPError(http.StatusForbidden, "src node has no user owner", nil))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if srcNode.UserID().Get() != user.ID {
|
||||
log.Warn().Caller().
|
||||
Str("auth_id", authInfo.AuthID.String()).
|
||||
Uint64("src_node_id", binding.SrcNodeID.Uint64()).
|
||||
Uint("src_owner_id", srcNode.UserID().Get()).
|
||||
Uint("oidc_user_id", user.ID).
|
||||
Str("oidc_user", user.Username()).
|
||||
Msg("SSH check rejected: OIDC user is not the owner of src node")
|
||||
httpError(writer, NewHTTPError(http.StatusForbidden, "OIDC user is not the owner of the SSH source node", nil))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Identity verified — record the verdict for the waiting follow-up.
|
||||
authReq.FinishAuth(types.AuthVerdict{})
|
||||
|
||||
content := renderAuthSuccessTemplate(user)
|
||||
|
||||
@@ -1214,6 +1214,15 @@ func (s *State) CreateNodeForTest(user *types.User, hostname ...string) *types.N
|
||||
return s.db.CreateNodeForTest(user, hostname...)
|
||||
}
|
||||
|
||||
// PutNodeInStoreForTest writes a test node into the in-memory NodeStore
|
||||
// so handlers backed by NodeStore lookups (e.g. GetNodeByID) can see it.
|
||||
// CreateNodeForTest only saves to the database, which is fine for tests
|
||||
// that exercise the DB layer directly but insufficient for handler tests
|
||||
// that go through State.
|
||||
func (s *State) PutNodeInStoreForTest(node types.Node) types.NodeView {
|
||||
return s.nodeStore.PutNode(node)
|
||||
}
|
||||
|
||||
// CreateRegisteredNodeForTest creates a test node with allocated IPs. This is a convenience wrapper around the database layer.
|
||||
func (s *State) CreateRegisteredNodeForTest(user *types.User, hostname ...string) *types.Node {
|
||||
return s.db.CreateRegisteredNodeForTest(user, hostname...)
|
||||
|
||||
@@ -221,26 +221,42 @@ func (r AuthID) Validate() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SSHCheckBinding identifies the (source, destination) node pair an SSH
|
||||
// check-mode auth request is bound to. It is captured at HoldAndDelegate
|
||||
// time so the follow-up request and OIDC callback can verify that no
|
||||
// other (src, dst) pair has been substituted via tampered URL parameters.
|
||||
type SSHCheckBinding struct {
|
||||
SrcNodeID NodeID
|
||||
DstNodeID NodeID
|
||||
}
|
||||
|
||||
// AuthRequest represents a pending authentication request from a user or a
|
||||
// node. It carries the minimum data needed to either complete a node
|
||||
// registration (regData populated) or signal the verdict of an interactive
|
||||
// auth flow (no payload). Verdict delivery is via the finished channel; the
|
||||
// closed flag guards FinishAuth against double-close.
|
||||
// registration (regData populated) or an SSH check-mode auth (sshBinding
|
||||
// populated), and signals the verdict via the finished channel. The closed
|
||||
// flag guards FinishAuth against double-close.
|
||||
//
|
||||
// AuthRequest is always handled by pointer so the channel and atomic flag
|
||||
// have a single canonical instance even when stored in caches that
|
||||
// internally copy values.
|
||||
type AuthRequest struct {
|
||||
// regData is populated for node-registration flows (interactive web
|
||||
// or OIDC). It carries only the minimal subset of registration data
|
||||
// the auth callback needs to promote this request into a real node;
|
||||
// see RegistrationData for the rationale behind keeping the payload
|
||||
// small.
|
||||
// or OIDC). It carries the cached registration payload that the
|
||||
// auth callback uses to promote this request into a real node.
|
||||
//
|
||||
// nil for non-registration flows (e.g. SSH check). Use
|
||||
// RegistrationData() to read it safely.
|
||||
// nil for non-registration flows. Use RegistrationData() to read it
|
||||
// safely.
|
||||
regData *RegistrationData
|
||||
|
||||
// sshBinding is populated for SSH check-mode flows. It captures the
|
||||
// (src, dst) node pair the request was minted for so the follow-up
|
||||
// and OIDC callback can refuse to record a verdict for any other
|
||||
// pair.
|
||||
//
|
||||
// nil for non-SSH-check flows. Use SSHCheckBinding() to read it
|
||||
// safely.
|
||||
sshBinding *SSHCheckBinding
|
||||
|
||||
finished chan AuthVerdict
|
||||
closed *atomic.Bool
|
||||
}
|
||||
@@ -265,9 +281,24 @@ func NewRegisterAuthRequest(data *RegistrationData) *AuthRequest {
|
||||
}
|
||||
}
|
||||
|
||||
// NewSSHCheckAuthRequest creates a pending auth request bound to a
|
||||
// specific (src, dst) SSH check-mode pair. The follow-up handler and
|
||||
// OIDC callback must verify their incoming request matches this binding
|
||||
// before recording any verdict.
|
||||
func NewSSHCheckAuthRequest(src, dst NodeID) *AuthRequest {
|
||||
return &AuthRequest{
|
||||
sshBinding: &SSHCheckBinding{
|
||||
SrcNodeID: src,
|
||||
DstNodeID: dst,
|
||||
},
|
||||
finished: make(chan AuthVerdict, 1),
|
||||
closed: &atomic.Bool{},
|
||||
}
|
||||
}
|
||||
|
||||
// RegistrationData returns the cached registration payload. It panics if
|
||||
// called on an AuthRequest that was not created via
|
||||
// NewRegisterAuthRequest, mirroring the previous Node() contract.
|
||||
// NewRegisterAuthRequest.
|
||||
func (rn *AuthRequest) RegistrationData() *RegistrationData {
|
||||
if rn.regData == nil {
|
||||
panic("RegistrationData can only be used in registration requests")
|
||||
@@ -276,12 +307,30 @@ func (rn *AuthRequest) RegistrationData() *RegistrationData {
|
||||
return rn.regData
|
||||
}
|
||||
|
||||
// SSHCheckBinding returns the (src, dst) node pair an SSH check-mode
|
||||
// auth request is bound to. It panics if called on an AuthRequest that
|
||||
// was not created via NewSSHCheckAuthRequest.
|
||||
func (rn *AuthRequest) SSHCheckBinding() *SSHCheckBinding {
|
||||
if rn.sshBinding == nil {
|
||||
panic("SSHCheckBinding can only be used in SSH check-mode requests")
|
||||
}
|
||||
|
||||
return rn.sshBinding
|
||||
}
|
||||
|
||||
// IsRegistration reports whether this auth request carries registration
|
||||
// data (i.e. it was created via NewRegisterAuthRequest).
|
||||
func (rn *AuthRequest) IsRegistration() bool {
|
||||
return rn.regData != nil
|
||||
}
|
||||
|
||||
// IsSSHCheck reports whether this auth request is bound to an SSH
|
||||
// check-mode (src, dst) pair (i.e. it was created via
|
||||
// NewSSHCheckAuthRequest).
|
||||
func (rn *AuthRequest) IsSSHCheck() bool {
|
||||
return rn.sshBinding != nil
|
||||
}
|
||||
|
||||
func (rn *AuthRequest) FinishAuth(verdict AuthVerdict) {
|
||||
if rn.closed.Swap(true) {
|
||||
return
|
||||
|
||||
@@ -2,8 +2,61 @@ package types
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestNewSSHCheckAuthRequestBinding verifies that an SSH-check AuthRequest
|
||||
// captures the (src, dst) node pair at construction time and rejects
|
||||
// callers that try to read RegistrationData from it.
|
||||
func TestNewSSHCheckAuthRequestBinding(t *testing.T) {
|
||||
const src, dst NodeID = 7, 11
|
||||
|
||||
req := NewSSHCheckAuthRequest(src, dst)
|
||||
|
||||
require.True(t, req.IsSSHCheck(), "SSH-check request must report IsSSHCheck=true")
|
||||
require.False(t, req.IsRegistration(), "SSH-check request must not report IsRegistration")
|
||||
|
||||
binding := req.SSHCheckBinding()
|
||||
assert.Equal(t, src, binding.SrcNodeID, "SrcNodeID must match")
|
||||
assert.Equal(t, dst, binding.DstNodeID, "DstNodeID must match")
|
||||
|
||||
assert.Panics(t, func() {
|
||||
_ = req.RegistrationData()
|
||||
}, "RegistrationData() must panic on an SSH-check AuthRequest")
|
||||
}
|
||||
|
||||
// TestNewRegisterAuthRequestPayload verifies that a registration
|
||||
// AuthRequest carries the supplied RegistrationData and rejects callers
|
||||
// that try to read SSH-check binding from it.
|
||||
func TestNewRegisterAuthRequestPayload(t *testing.T) {
|
||||
data := &RegistrationData{Hostname: "node-a"}
|
||||
|
||||
req := NewRegisterAuthRequest(data)
|
||||
|
||||
require.True(t, req.IsRegistration(), "registration request must report IsRegistration=true")
|
||||
require.False(t, req.IsSSHCheck(), "registration request must not report IsSSHCheck")
|
||||
assert.Same(t, data, req.RegistrationData(), "RegistrationData() must return the supplied pointer")
|
||||
|
||||
assert.Panics(t, func() {
|
||||
_ = req.SSHCheckBinding()
|
||||
}, "SSHCheckBinding() must panic on a registration AuthRequest")
|
||||
}
|
||||
|
||||
// TestNewAuthRequestEmptyPayload verifies that a payload-less
|
||||
// AuthRequest reports both Is* helpers as false and panics on either
|
||||
// payload accessor.
|
||||
func TestNewAuthRequestEmptyPayload(t *testing.T) {
|
||||
req := NewAuthRequest()
|
||||
|
||||
assert.False(t, req.IsRegistration())
|
||||
assert.False(t, req.IsSSHCheck())
|
||||
|
||||
assert.Panics(t, func() { _ = req.RegistrationData() })
|
||||
assert.Panics(t, func() { _ = req.SSHCheckBinding() })
|
||||
}
|
||||
|
||||
func TestDefaultBatcherWorkersFor(t *testing.T) {
|
||||
tests := []struct {
|
||||
cpuCount int
|
||||
|
||||
Reference in New Issue
Block a user