policy, noise: implement SSH check action

Implement the SSH "check" action which requires additional
verification before allowing SSH access. The policy compiler generates
a HoldAndDelegate URL that the Tailscale client calls back to
headscale. The SSHActionHandler creates an auth session and waits for
approval via the generalised auth flow.

Sort check (HoldAndDelegate) rules before accept rules to match
Tailscale's first-match-wins evaluation order.

Updates #1850
This commit is contained in:
Kristoffer Dalby
2026-02-24 18:50:18 +00:00
parent 4a7e1475c0
commit 107c2f2f70
10 changed files with 500 additions and 71 deletions

View File

@@ -7,12 +7,14 @@ import (
"fmt"
"io"
"net/http"
"net/url"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"github.com/go-chi/metrics"
"github.com/juanfont/headscale/hscontrol/capver"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"golang.org/x/net/http2"
"tailscale.com/control/controlbase"
@@ -30,6 +32,9 @@ var ErrMissingURLParameter = errors.New("missing URL parameter")
// ErrUnsupportedURLParameterType is returned when a URL parameter has an unsupported type.
var ErrUnsupportedURLParameterType = errors.New("unsupported URL parameter type")
// ErrNoAuthSession is returned when an auth_id does not match any active auth session.
var ErrNoAuthSession = errors.New("no auth session found")
const (
// ts2021UpgradePath is the path that the server listens on for the WebSockets upgrade.
ts2021UpgradePath = "/ts2021"
@@ -113,7 +118,7 @@ func (h *Headscale) NoiseUpgradeHandler(
}))
r.Use(middleware.RequestID)
r.Use(middleware.RealIP)
r.Use(middleware.Logger)
r.Use(middleware.RequestLogger(&zerologRequestLogger{}))
r.Use(middleware.Recoverer)
r.Handle("/metrics", metrics.Handler())
@@ -122,6 +127,9 @@ func (h *Headscale) NoiseUpgradeHandler(
r.Post("/register", ns.RegistrationHandler)
r.Post("/map", ns.PollNetMapHandler)
// SSH Check mode endpoint, consulted to validate if a given SSH connection should be accepted or rejected.
r.Get("/ssh/action/from/{src_node_id}/to/{dst_node_id}", ns.SSHActionHandler)
// Not implemented yet
//
// /whoami is a debug endpoint to validate that the client can communicate over the connection,
@@ -153,12 +161,10 @@ func (h *Headscale) NoiseUpgradeHandler(
r.Post("/update-health", ns.NotImplementedHandler)
r.Route("/webclient", func(r chi.Router) {})
r.Post("/c2n", ns.NotImplementedHandler)
})
r.Post("/c2n", ns.NotImplementedHandler)
r.Get("/ssh-action", ns.SSHAction)
ns.httpBaseConfig = &http.Server{
Handler: r,
ReadHeaderTimeout: types.HTTPTimeout,
@@ -249,10 +255,233 @@ func (ns *noiseServer) NotImplementedHandler(writer http.ResponseWriter, req *ht
http.Error(writer, "Not implemented yet", http.StatusNotImplemented)
}
// SSHAction handles the /ssh-action endpoint, it returns a [tailcfg.SSHAction]
// to the client with the verdict of an SSH access request.
func (ns *noiseServer) SSHAction(writer http.ResponseWriter, req *http.Request) {
log.Trace().Caller().Str("path", req.URL.String()).Msg("got SSH action request")
func urlParam[T any](req *http.Request, key string) (T, error) {
var zero T
param := chi.URLParam(req, key)
if param == "" {
return zero, fmt.Errorf("%w: %s", ErrMissingURLParameter, key)
}
var value T
switch any(value).(type) {
case string:
v, ok := any(param).(T)
if !ok {
return zero, fmt.Errorf("%w: %T", ErrUnsupportedURLParameterType, value)
}
value = v
case types.NodeID:
id, err := types.ParseNodeID(param)
if err != nil {
return zero, fmt.Errorf("parsing %s: %w", key, err)
}
v, ok := any(id).(T)
if !ok {
return zero, fmt.Errorf("%w: %T", ErrUnsupportedURLParameterType, value)
}
value = v
default:
return zero, fmt.Errorf("%w: %T", ErrUnsupportedURLParameterType, value)
}
return value, nil
}
// SSHActionHandler handles the /ssh-action endpoint, returning a
// [tailcfg.SSHAction] to the client with the verdict of an SSH access
// request.
func (ns *noiseServer) SSHActionHandler(
writer http.ResponseWriter,
req *http.Request,
) {
srcNodeID, err := urlParam[types.NodeID](req, "src_node_id")
if err != nil {
httpError(writer, NewHTTPError(
http.StatusBadRequest,
"Invalid src_node_id",
err,
))
return
}
dstNodeID, err := urlParam[types.NodeID](req, "dst_node_id")
if err != nil {
httpError(writer, NewHTTPError(
http.StatusBadRequest,
"Invalid dst_node_id",
err,
))
return
}
reqLog := log.With().
Uint64("src_node_id", srcNodeID.Uint64()).
Uint64("dst_node_id", dstNodeID.Uint64()).
Str("ssh_user", req.URL.Query().Get("ssh_user")).
Str("local_user", req.URL.Query().Get("local_user")).
Logger()
reqLog.Trace().Caller().Msg("SSH action request")
action, err := ns.sshAction(
reqLog,
req.URL.Query().Get("auth_id"),
)
if err != nil {
httpError(writer, err)
return
}
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
writer.WriteHeader(http.StatusOK)
err = json.NewEncoder(writer).Encode(action)
if err != nil {
reqLog.Error().Caller().Err(err).
Msg("failed to encode SSH action response")
return
}
if flusher, ok := writer.(http.Flusher); ok {
flusher.Flush()
}
}
// sshAction resolves the SSH action for the given request parameters.
// It returns the action to send to the client, or an HTTPError on
// failure.
//
// Two cases:
// 1. Initial request — build a HoldAndDelegate URL and wait for the
// user to authenticate.
// 2. Follow-up request — an auth_id is present, wait for the auth
// verdict and accept or reject.
func (ns *noiseServer) sshAction(
reqLog zerolog.Logger,
authIDStr string,
) (*tailcfg.SSHAction, error) {
action := tailcfg.SSHAction{
AllowAgentForwarding: true,
AllowLocalPortForwarding: true,
AllowRemotePortForwarding: true,
}
// Follow-up request with auth_id — wait for the auth verdict.
if authIDStr != "" {
return ns.sshActionFollowUp(
reqLog, &action, authIDStr,
)
}
// Initial request — create an auth session and hold.
return ns.sshActionHoldAndDelegate(reqLog, &action)
}
// sshActionHoldAndDelegate creates a new auth session and returns a
// HoldAndDelegate action that directs the client to authenticate.
func (ns *noiseServer) sshActionHoldAndDelegate(
reqLog zerolog.Logger,
action *tailcfg.SSHAction,
) (*tailcfg.SSHAction, error) {
holdURL, err := url.Parse(
ns.headscale.cfg.ServerURL +
"/machine/ssh/action/from/$SRC_NODE_ID/to/$DST_NODE_ID" +
"?ssh_user=$SSH_USER&local_user=$LOCAL_USER",
)
if err != nil {
return nil, NewHTTPError(
http.StatusInternalServerError,
"Internal error",
fmt.Errorf("parsing SSH action URL: %w", err),
)
}
authID, err := types.NewAuthID()
if err != nil {
return nil, NewHTTPError(
http.StatusInternalServerError,
"Internal error",
fmt.Errorf("generating auth ID: %w", err),
)
}
ns.headscale.state.SetAuthCacheEntry(authID, types.NewAuthRequest())
authURL := ns.headscale.authProvider.AuthURL(authID)
q := holdURL.Query()
q.Set("auth_id", authID.String())
holdURL.RawQuery = q.Encode()
action.HoldAndDelegate = holdURL.String()
// TODO(kradalby): here we can also send a very tiny mapresponse
// "popping" the url and opening it for the user.
action.Message = fmt.Sprintf(
"# Headscale SSH requires an additional check.\n"+
"# To authenticate, visit: %s\n"+
"# Authentication checked with Headscale SSH.\n",
authURL,
)
reqLog.Info().Caller().
Str("auth_id", authID.String()).
Msg("SSH check pending, waiting for auth")
return action, nil
}
// sshActionFollowUp handles follow-up requests where the client
// provides an auth_id. It blocks until the auth session resolves.
func (ns *noiseServer) sshActionFollowUp(
reqLog zerolog.Logger,
action *tailcfg.SSHAction,
authIDStr string,
) (*tailcfg.SSHAction, error) {
authID, err := types.AuthIDFromString(authIDStr)
if err != nil {
return nil, NewHTTPError(
http.StatusBadRequest,
"Invalid auth_id",
fmt.Errorf("parsing auth_id: %w", err),
)
}
reqLog = reqLog.With().Str("auth_id", authID.String()).Logger()
auth, ok := ns.headscale.state.GetAuthCacheEntry(authID)
if !ok {
return nil, NewHTTPError(
http.StatusBadRequest,
"Invalid auth_id",
fmt.Errorf("%w: %s", ErrNoAuthSession, authID),
)
}
reqLog.Trace().Caller().Msg("SSH action follow-up")
verdict := <-auth.WaitForAuth()
if !verdict.Accept() {
action.Reject = true
reqLog.Trace().Caller().Err(verdict.Err).
Msg("authentication rejected")
return action, nil
}
action.Accept = true
return action, nil
}
// PollNetMapHandler takes care of /machine/:id/map using the Noise protocol
@@ -380,28 +609,3 @@ func (ns *noiseServer) getAndValidateNode(mapRequest tailcfg.MapRequest) (types.
return nv, nil
}
// urlParam extracts a typed URL parameter from a chi router request.
func urlParam[T any](req *http.Request, key string) (T, error) {
var zero T
param := chi.URLParam(req, key)
if param == "" {
return zero, fmt.Errorf("%w: %s", ErrMissingURLParameter, key)
}
var value T
switch any(value).(type) {
case string:
v, ok := any(param).(T)
if !ok {
return zero, fmt.Errorf("%w: %T", ErrUnsupportedURLParameterType, value)
}
value = v
default:
return zero, fmt.Errorf("%w: %T", ErrUnsupportedURLParameterType, value)
}
return value, nil
}