Files
headscale/hscontrol/noise.go
Kristoffer Dalby 99767cf805 hscontrol: validate machine key and bind src/dst in SSH check handler
SSHActionHandler now verifies that the Noise session's machine key
matches the dst node before proceeding. The (src, dst) pair is
captured at hold-and-delegate time via a new SSHCheckBinding on
AuthRequest so sshActionFollowUp can verify the follow-up URL
matches. The OIDC non-registration callback requires the
authenticated user to own the src node before approving.
2026-04-10 14:09:57 +01:00

747 lines
22 KiB
Go

package hscontrol
import (
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"time"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"github.com/go-chi/metrics"
"github.com/juanfont/headscale/hscontrol/capver"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"golang.org/x/net/http2"
"tailscale.com/control/controlbase"
"tailscale.com/control/controlhttp/controlhttpserver"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
)
// Sentinel errors produced by the Noise (TS2021) handlers in this file.
// They are wrapped with request-specific detail before being returned, so
// match them with errors.Is rather than comparing directly.
var (
	// ErrUnsupportedClientVersion is returned when a client connects with
	// an unsupported protocol version.
	ErrUnsupportedClientVersion = errors.New("unsupported client version")

	// ErrMissingURLParameter is returned when a required URL parameter is
	// not provided.
	ErrMissingURLParameter = errors.New("missing URL parameter")

	// ErrUnsupportedURLParameterType is returned when a URL parameter is
	// requested as a type that cannot be parsed from the URL.
	ErrUnsupportedURLParameterType = errors.New("unsupported URL parameter type")

	// ErrNoAuthSession is returned when an auth_id does not match any
	// active auth session.
	ErrNoAuthSession = errors.New("no auth session found")

	// ErrSSHDstNodeNotFound is returned when the dst node id on a Noise
	// SSH action request does not match any registered node.
	ErrSSHDstNodeNotFound = errors.New("ssh action: unknown dst node id")

	// ErrSSHMachineKeyMismatch is returned when the Noise session's
	// machine key does not match the dst node referenced in the SSH
	// action URL.
	ErrSSHMachineKeyMismatch = errors.New(
		"ssh action: noise session machine key does not match dst node",
	)

	// ErrSSHAuthSessionNotBound is returned when an SSH action follow-up
	// references an auth session that is not bound to an SSH check pair.
	ErrSSHAuthSessionNotBound = errors.New(
		"ssh action: cached auth session is not an SSH-check binding",
	)

	// ErrSSHBindingMismatch is returned when an SSH action follow-up's
	// (src, dst) pair does not match the cached binding for its auth_id.
	ErrSSHBindingMismatch = errors.New(
		"ssh action: cached binding does not match request src/dst",
	)
)
const (
	// ts2021UpgradePath is the path on which the server accepts the
	// connection upgrade that starts a TS2021 (Noise) session.
	ts2021UpgradePath = "/ts2021"

	// earlyPayloadMagic marks the optional "early payload" sent from
	// server to client over Noise.
	//
	// The first 9 bytes the server sends are either an HTTP/2 settings
	// frame (a normal HTTP/2 setup) or, as Tailscale added later, an early
	// payload header of the same length: these 5 magic bytes followed by a
	// 4-byte length, and then that many bytes of JSON-encoded
	// tailcfg.EarlyNoise. Servers are not required to send it, but this
	// one does.
	earlyPayloadMagic = "\xff\xff\xffTS"

	// noiseBodyLimit caps the request body size accepted by the Noise
	// protocol handlers, so that an unauthenticated peer cannot exhaust
	// memory with an unbounded body. Legitimate requests (MapRequest,
	// RegisterRequest, etc.) are typically a few KB and stay far below it.
	noiseBodyLimit int64 = 1 << 20 // 1 MiB
)
// noiseServer holds the state for a single TS2021 (Noise) connection.
// NoiseUpgradeHandler creates one per accepted upgrade request and
// registers its methods as the handlers of the Noise-only router.
type noiseServer struct {
// headscale is the owning server; handlers reach shared state through it.
headscale *Headscale
// httpBaseConfig carries the router and header timeout handed to the
// HTTP/2 server as its base configuration.
httpBaseConfig *http.Server
// http2Server serves HTTP/2 on top of the Noise connection.
http2Server *http2.Server
// conn is the Noise connection returned by the upgrade.
conn *controlbase.Conn
// machineKey is the peer's machine public key, taken from conn.Peer().
machineKey key.MachinePublic
// nodeKey is the node key seen in the most recent register or map
// request handled on this connection.
nodeKey key.NodePublic
// EarlyNoise-related stuff
// challenge is generated per connection; its public half is sent to the
// client in the early payload.
challenge key.ChallengePrivate
// protocolVersion is the value reported by conn.ProtocolVersion().
protocolVersion int
}
// NoiseUpgradeHandler upgrades an incoming /ts2021 request to a
// Noise-based TS2021 connection and serves the Noise-only control API over
// HTTP/2 on that single connection.
//
// A request without an Upgrade header is answered with an internal error
// and a log hint about reverse-proxy configuration. If the Noise upgrade
// itself fails the error is reported through httpError.
func (h *Headscale) NoiseUpgradeHandler(
	writer http.ResponseWriter,
	req *http.Request,
) {
	log.Trace().Caller().Msgf("noise upgrade handler for client %s", req.RemoteAddr)

	if req.Header.Get("Upgrade") == "" {
		// This probably means that the user is running Headscale behind an
		// improperly configured reverse proxy. TS2021 requires WebSockets to
		// be passed to Headscale. Let's give them a hint.
		log.Warn().
			Caller().
			Msg("no upgrade header in TS2021 request. If headscale is behind a reverse proxy, make sure it is configured to pass WebSockets through.")
		http.Error(writer, "Internal error", http.StatusInternalServerError)

		return
	}

	ns := noiseServer{
		headscale: h,
		challenge: key.NewChallenge(),
	}

	noiseConn, err := controlhttpserver.AcceptHTTP(
		req.Context(), writer, req, *h.noisePrivateKey, ns.earlyNoise,
	)
	if err != nil {
		httpError(writer, fmt.Errorf("upgrading noise connection: %w", err))

		return
	}

	ns.conn = noiseConn
	ns.machineKey = noiseConn.Peer()
	ns.protocolVersion = noiseConn.ProtocolVersion()

	// limitBody caps every request body at noiseBodyLimit to prevent
	// unauthenticated OOM attacks. The Noise handshake accepts any machine
	// key without checking registration, so all endpoints behind this
	// router are reachable without credentials.
	limitBody := func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, inner *http.Request) {
			inner.Body = http.MaxBytesReader(w, inner.Body, noiseBodyLimit)
			next.ServeHTTP(w, inner)
		})
	}

	// This router is served only over the Noise connection and exposes
	// only the new API. It backs an HTTP/2 server that is created for the
	// one connection upgraded from /ts2021.
	router := chi.NewRouter()
	router.Use(limitBody)
	router.Use(metrics.Collector(metrics.CollectorOpts{
		Host:  false,
		Proto: true,
		Skip: func(inner *http.Request) bool {
			return inner.Method != http.MethodOptions
		},
	}))
	router.Use(middleware.RequestID)
	router.Use(middleware.RealIP)
	router.Use(middleware.RequestLogger(&zerologRequestLogger{}))
	router.Use(middleware.Recoverer)

	router.Handle("/metrics", metrics.Handler())

	router.Route("/machine", func(machine chi.Router) {
		machine.Post("/register", ns.RegistrationHandler)
		machine.Post("/map", ns.PollNetMapHandler)

		// SSH check mode endpoint, consulted to decide whether a given SSH
		// connection should be accepted or rejected.
		machine.Get("/ssh/action/from/{src_node_id}/to/{dst_node_id}", ns.SSHActionHandler)

		// Everything below is not implemented yet.

		// /whoami is a debug endpoint to validate that the client can
		// communicate over the connection. It is not clear whether a
		// specific response is expected; it looks like it is just logged.
		// https://github.com/tailscale/tailscale/blob/dfba01ca9bd8c4df02c3c32f400d9aeb897c5fc7/cmd/tailscale/cli/debug.go#L1138
		machine.Get("/whoami", ns.NotImplementedHandler)

		// The client sends a [tailcfg.SetDNSRequest] to this endpoint and
		// expects the server to create or update the DNS record
		// "somewhere". It is typically a TXT record for an ACME challenge.
		machine.Post("/set-dns", ns.NotImplementedHandler)

		// A patch of [tailcfg.SetDeviceAttributesRequest] to update device
		// attributes. Device attributes are currently not supported.
		machine.Patch("/set-device-attr", ns.NotImplementedHandler)

		// A [tailcfg.AuditLogRequest] carrying audit log entries that the
		// server is expected to store "somewhere". Not supported either.
		machine.Post("/audit-log", ns.NotImplementedHandler)

		// Requests for an OIDC ID token; receives a [tailcfg.TokenRequest].
		machine.Post("/id-token", ns.NotImplementedHandler)

		// Asks whether a feature is available and how to enable it. Takes a
		// [tailcfg.QueryFeatureRequest] and returns a
		// [tailcfg.QueryFeatureResponse].
		machine.Post("/feature/query", ns.NotImplementedHandler)

		machine.Post("/update-health", ns.NotImplementedHandler)
		machine.Route("/webclient", func(chi.Router) {})
		machine.Post("/c2n", ns.NotImplementedHandler)
	})

	ns.httpBaseConfig = &http.Server{
		Handler:           router,
		ReadHeaderTimeout: types.HTTPTimeout,
	}
	ns.http2Server = &http2.Server{}

	serveOpts := &http2.ServeConnOpts{BaseConfig: ns.httpBaseConfig}
	ns.http2Server.ServeConn(noiseConn, serveOpts)
}
// unsupportedClientError wraps ErrUnsupportedClientVersion with the
// Tailscale release name for version and the numeric capability version.
func unsupportedClientError(version tailcfg.CapabilityVersion) error {
	release := capver.TailscaleVersion(version)

	return fmt.Errorf("%w: %s (%d)", ErrUnsupportedClientVersion, release, version)
}
// earlyNoise writes the early payload for a new Noise connection to
// writer: the 5-byte earlyPayloadMagic, a 4-byte big-endian length, and the
// JSON-encoded tailcfg.EarlyNoise carrying this connection's node key
// challenge.
//
// It returns an unsupported-client error without writing anything when
// protocolVersion is below the supported minimum, and otherwise the first
// marshalling or write error encountered.
func (ns *noiseServer) earlyNoise(protocolVersion int, writer io.Writer) error {
	clientVersion := tailcfg.CapabilityVersion(protocolVersion)
	if !isSupportedVersion(clientVersion) {
		return unsupportedClientError(clientVersion)
	}

	payload, err := json.Marshal(&tailcfg.EarlyNoise{
		NodeKeyChallenge: ns.challenge.Public(),
	})
	if err != nil {
		return err
	}

	// 5 bytes that won't be mistaken for an HTTP/2 frame:
	// https://httpwg.org/specs/rfc7540.html#rfc.section.4.1 (Especially not
	// an HTTP/2 settings frame, which isn't of type 'T')
	var magic [5]byte
	copy(magic[:], earlyPayloadMagic)

	var size [4]byte
	binary.BigEndian.PutUint32(size[:], uint32(len(payload))) //nolint:gosec // JSON length is bounded

	// The caller buffers these writes, so emitting the three parts
	// separately is fine.
	for _, chunk := range [][]byte{magic[:], size[:], payload} {
		_, err = writer.Write(chunk)
		if err != nil {
			return err
		}
	}

	return nil
}
// isSupportedVersion reports whether version is at least
// capver.MinSupportedCapabilityVersion.
func isSupportedVersion(version tailcfg.CapabilityVersion) bool {
	return capver.MinSupportedCapabilityVersion <= version
}
// rejectUnsupported reports whether version is too old to be served.
//
// For a supported version it returns false and leaves writer untouched.
// Otherwise it logs the client and minimum versions together with the
// short forms of mkey and nkey, writes a 400 response carrying the
// unsupported-client error text, and returns true so the caller can stop.
func rejectUnsupported(
	writer http.ResponseWriter,
	version tailcfg.CapabilityVersion,
	mkey key.MachinePublic,
	nkey key.NodePublic,
) bool {
	if isSupportedVersion(version) {
		return false
	}

	log.Error().
		Caller().
		Int("minimum_cap_ver", int(capver.MinSupportedCapabilityVersion)).
		Int("client_cap_ver", int(version)).
		Str("minimum_version", capver.TailscaleVersion(capver.MinSupportedCapabilityVersion)).
		Str("client_version", capver.TailscaleVersion(version)).
		Str("node.key", nkey.ShortString()).
		Str("machine.key", mkey.ShortString()).
		Msg("unsupported client connected")

	reason := unsupportedClientError(version).Error()
	http.Error(writer, reason, http.StatusBadRequest)

	return true
}
// NotImplementedHandler is the placeholder for Noise endpoints that are
// routed but not supported: it trace-logs the requested URL and answers
// with 501 Not Implemented.
func (ns *noiseServer) NotImplementedHandler(writer http.ResponseWriter, req *http.Request) {
	requested := req.URL.String()
	log.Trace().Caller().Str("path", requested).Msg("not implemented handler hit")

	http.Error(writer, "Not implemented yet", http.StatusNotImplemented)
}
// urlParam reads the chi URL parameter named key from req and converts it
// to T. Supported target types are string (returned verbatim) and
// types.NodeID (parsed with types.ParseNodeID).
//
// It returns ErrMissingURLParameter when the parameter is empty, the
// wrapped parse error when a node ID cannot be parsed, and
// ErrUnsupportedURLParameterType for any other T. The zero value of T
// accompanies every error.
func urlParam[T any](req *http.Request, key string) (T, error) {
	var zero, value T

	raw := chi.URLParam(req, key)
	if raw == "" {
		return zero, fmt.Errorf("%w: %s", ErrMissingURLParameter, key)
	}

	// Switching on a pointer to the result lets each case assign through
	// its concrete type without converting back to T.
	switch dst := any(&value).(type) {
	case *string:
		*dst = raw
	case *types.NodeID:
		id, err := types.ParseNodeID(raw)
		if err != nil {
			return zero, fmt.Errorf("parsing %s: %w", key, err)
		}

		*dst = id
	default:
		return zero, fmt.Errorf("%w: %T", ErrUnsupportedURLParameterType, value)
	}

	return value, nil
}
// SSHActionHandler serves the SSH check endpoint registered at
// /machine/ssh/action/from/{src_node_id}/to/{dst_node_id}. It answers with
// a JSON-encoded [tailcfg.SSHAction] carrying the verdict for an SSH access
// request.
//
// Invalid node IDs yield 400, an unknown dst node 404, and a Noise session
// whose machine key is not the dst node's 401. Errors from resolving the
// action are passed to httpError unchanged.
func (ns *noiseServer) SSHActionHandler(
	writer http.ResponseWriter,
	req *http.Request,
) {
	src, err := urlParam[types.NodeID](req, "src_node_id")
	if err != nil {
		httpError(writer, NewHTTPError(http.StatusBadRequest, "Invalid src_node_id", err))

		return
	}

	dst, err := urlParam[types.NodeID](req, "dst_node_id")
	if err != nil {
		httpError(writer, NewHTTPError(http.StatusBadRequest, "Invalid dst_node_id", err))

		return
	}

	// Authenticate the Noise session. The destination node is the
	// tailscaled instance asking whether to permit an incoming SSH
	// connection, so the session must belong to dst. Without this check
	// any unauthenticated client could open a Noise tunnel with a
	// throwaway machine key and pollute lastSSHAuth for arbitrary
	// (src, dst) pairs, defeating SSH check-mode's stolen-key protections.
	dstNode, found := ns.headscale.state.GetNodeByID(dst)
	if !found {
		notFound := fmt.Errorf("%w: %d", ErrSSHDstNodeNotFound, dst)
		httpError(writer, NewHTTPError(http.StatusNotFound, "dst node not found", notFound))

		return
	}

	if dstNode.MachineKey() != ns.machineKey {
		mismatch := fmt.Errorf(
			"%w: machine key %s, dst node %d",
			ErrSSHMachineKeyMismatch, ns.machineKey.ShortString(), dst,
		)
		httpError(writer, NewHTTPError(
			http.StatusUnauthorized,
			"machine key does not match dst node",
			mismatch,
		))

		return
	}

	// Parse the query string once; it feeds both the log context and the
	// optional auth_id of a follow-up request.
	query := req.URL.Query()

	reqLog := log.With().
		Uint64("src_node_id", src.Uint64()).
		Uint64("dst_node_id", dst.Uint64()).
		Str("ssh_user", query.Get("ssh_user")).
		Str("local_user", query.Get("local_user")).
		Logger()

	reqLog.Trace().Caller().Msg("SSH action request")

	action, err := ns.sshAction(reqLog, src, dst, query.Get("auth_id"))
	if err != nil {
		httpError(writer, err)

		return
	}

	writer.Header().Set("Content-Type", "application/json; charset=utf-8")
	writer.WriteHeader(http.StatusOK)

	err = json.NewEncoder(writer).Encode(action)
	if err != nil {
		reqLog.Error().Caller().Err(err).
			Msg("failed to encode SSH action response")

		return
	}

	flusher, canFlush := writer.(http.Flusher)
	if canFlush {
		flusher.Flush()
	}
}
// sshAction decides which SSH action to send for one request. It returns
// the action for the client, or an HTTPError on failure.
//
// The outcome is one of:
//   - Follow-up request (authIDStr is set): wait for the verdict of that
//     auth session and accept or reject accordingly.
//   - Initial request, auto-approved: the policy has a positive check
//     period for the pair and the last recorded auth is still within it,
//     so accept immediately.
//   - Initial request, needs auth: create an auth session and return a
//     HoldAndDelegate action so the user can authenticate.
func (ns *noiseServer) sshAction(
	reqLog zerolog.Logger,
	srcNodeID, dstNodeID types.NodeID,
	authIDStr string,
) (*tailcfg.SSHAction, error) {
	action := tailcfg.SSHAction{
		AllowAgentForwarding:      true,
		AllowLocalPortForwarding:  true,
		AllowRemotePortForwarding: true,
	}

	// The check parameters come from the server's own policy, never from
	// URL parameters, which the client could tamper with.
	checkPeriod, checkFound := ns.headscale.state.SSHCheckParams(srcNodeID, dstNodeID)

	if authIDStr != "" {
		return ns.sshActionFollowUp(
			reqLog, &action, authIDStr,
			srcNodeID, dstNodeID,
			checkFound,
		)
	}

	// Without a positive check period there is nothing to auto-approve.
	if !checkFound || checkPeriod <= 0 {
		return ns.sshActionHoldAndDelegate(reqLog, &action, srcNodeID, dstNodeID)
	}

	lastAuth, ok := ns.headscale.state.GetLastSSHAuth(srcNodeID, dstNodeID)
	if !ok || time.Since(lastAuth) >= checkPeriod {
		return ns.sshActionHoldAndDelegate(reqLog, &action, srcNodeID, dstNodeID)
	}

	reqLog.Trace().Caller().
		Dur("check_period", checkPeriod).
		Time("last_auth", lastAuth).
		Msg("auto-approved within check period")

	action.Accept = true

	return &action, nil
}
// sshActionHoldAndDelegate creates a new auth session bound to the
// (src, dst) pair and returns a HoldAndDelegate action that directs the
// client to authenticate.
//
// The hold URL points back at the SSH action endpoint and keeps the
// $SRC_NODE_ID, $DST_NODE_ID, $SSH_USER and $LOCAL_USER placeholders
// literal, with the new auth_id appended as a query parameter. The
// action's Message carries the URL the user has to visit, obtained from
// the auth provider.
//
// It returns an HTTPError with status 500 when the hold URL cannot be
// parsed or no auth ID can be generated.
func (ns *noiseServer) sshActionHoldAndDelegate(
	reqLog zerolog.Logger,
	action *tailcfg.SSHAction,
	srcNodeID, dstNodeID types.NodeID,
) (*tailcfg.SSHAction, error) {
	holdURL, err := url.Parse(
		ns.headscale.cfg.ServerURL +
			"/machine/ssh/action/from/$SRC_NODE_ID/to/$DST_NODE_ID" +
			"?ssh_user=$SSH_USER&local_user=$LOCAL_USER",
	)
	if err != nil {
		return nil, NewHTTPError(
			http.StatusInternalServerError,
			"Internal error",
			fmt.Errorf("parsing SSH action URL: %w", err),
		)
	}

	authID, err := types.NewAuthID()
	if err != nil {
		return nil, NewHTTPError(
			http.StatusInternalServerError,
			"Internal error",
			fmt.Errorf("generating auth ID: %w", err),
		)
	}

	ns.headscale.state.SetAuthCacheEntry(
		authID,
		types.NewSSHCheckAuthRequest(srcNodeID, dstNodeID),
	)

	authURL := ns.headscale.authProvider.AuthURL(authID)

	// Append auth_id to the raw query by hand. Round-tripping the query
	// through holdURL.Query() and url.Values.Encode percent-encodes the
	// "$" of $SSH_USER and $LOCAL_USER (producing "%24SSH_USER"), so the
	// placeholders documented for SSHAction.HoldAndDelegate would no
	// longer appear literally in the URL handed to the client.
	authParam := "auth_id=" + url.QueryEscape(authID.String())
	if holdURL.RawQuery == "" {
		holdURL.RawQuery = authParam
	} else {
		holdURL.RawQuery += "&" + authParam
	}

	action.HoldAndDelegate = holdURL.String()

	// TODO(kradalby): here we can also send a very tiny mapresponse
	// "popping" the url and opening it for the user.
	action.Message = fmt.Sprintf(
		"# Headscale SSH requires an additional check.\n"+
			"# To authenticate, visit: %s\n"+
			"# Authentication checked with Headscale SSH.\n",
		authURL,
	)

	reqLog.Info().Caller().
		Str("auth_id", authID.String()).
		Msg("SSH check pending, waiting for auth")

	return action, nil
}
// sshActionFollowUp serves a follow-up request, i.e. one that carries an
// auth_id. After validating the auth session it blocks until that session
// produces a verdict, then marks action as accepted or rejected.
//
// It returns an HTTPError with status 400 when authIDStr cannot be parsed,
// matches no cached session, or names a session that is not an SSH check,
// and 401 when the session is bound to a different (src, dst) pair. On
// acceptance the auth time is recorded for future auto-approval, but only
// if checkFound is true.
//
// NOTE(review): the wait on the verdict channel is not tied to the request
// context here — confirm that WaitForAuth always resolves eventually.
func (ns *noiseServer) sshActionFollowUp(
	reqLog zerolog.Logger,
	action *tailcfg.SSHAction,
	authIDStr string,
	srcNodeID, dstNodeID types.NodeID,
	checkFound bool,
) (*tailcfg.SSHAction, error) {
	authID, err := types.AuthIDFromString(authIDStr)
	if err != nil {
		return nil, NewHTTPError(
			http.StatusBadRequest,
			"Invalid auth_id",
			fmt.Errorf("parsing auth_id: %w", err),
		)
	}

	reqLog = reqLog.With().Str("auth_id", authID.String()).Logger()

	session, known := ns.headscale.state.GetAuthCacheEntry(authID)
	if !known {
		return nil, NewHTTPError(
			http.StatusBadRequest,
			"Invalid auth_id",
			fmt.Errorf("%w: %s", ErrNoAuthSession, authID),
		)
	}

	// The cached binding must match the (src, dst) pair the follow-up URL
	// claims. Otherwise an attacker who knew an auth_id could submit a
	// follow-up for any other pair and have the verdict recorded against
	// that pair instead.
	if !session.IsSSHCheck() {
		return nil, NewHTTPError(
			http.StatusBadRequest,
			"auth session is not for SSH check",
			fmt.Errorf("%w: %s", ErrSSHAuthSessionNotBound, authID),
		)
	}

	bound := session.SSHCheckBinding()

	samePair := bound.SrcNodeID == srcNodeID && bound.DstNodeID == dstNodeID
	if !samePair {
		return nil, NewHTTPError(
			http.StatusUnauthorized,
			"src/dst pair does not match auth session",
			fmt.Errorf(
				"%w: cached %d->%d, request %d->%d",
				ErrSSHBindingMismatch,
				bound.SrcNodeID, bound.DstNodeID,
				srcNodeID, dstNodeID,
			),
		)
	}

	reqLog.Trace().Caller().Msg("SSH action follow-up")

	verdict := <-session.WaitForAuth()

	if verdict.Accept() {
		action.Accept = true

		// Remember the successful auth so later requests for this pair can
		// be auto-approved.
		if checkFound {
			ns.headscale.state.SetLastSSHAuth(srcNodeID, dstNodeID)
			reqLog.Trace().Caller().
				Msg("auth recorded for auto-approval")
		}

		return action, nil
	}

	action.Reject = true
	reqLog.Trace().Caller().Err(verdict.Err).
		Msg("authentication rejected")

	return action, nil
}
// PollNetMapHandler serves /machine/map over the Noise protocol.
//
// This is the busiest endpoint, as it keeps the HTTP long poll that updates
// the clients when something in the network changes. Clients also POST
// things like HostInfo and their Endpoints here, but only after their
// first request (marked with the ReadOnly field).
//
// The request body is decoded as a tailcfg.MapRequest. Unsupported client
// versions are rejected, and the node named by the request must belong to
// this session's machine key. Streaming requests are then served as a long
// poll; all others get a single response.
func (ns *noiseServer) PollNetMapHandler(
	writer http.ResponseWriter,
	req *http.Request,
) {
	var mapRequest tailcfg.MapRequest

	decodeErr := json.NewDecoder(req.Body).Decode(&mapRequest)
	if decodeErr != nil {
		httpError(writer, decodeErr)

		return
	}

	if rejectUnsupported(writer, mapRequest.Version, ns.machineKey, mapRequest.NodeKey) {
		return
	}

	node, err := ns.getAndValidateNode(mapRequest)
	if err != nil {
		httpError(writer, err)

		return
	}

	ns.nodeKey = node.NodeKey()

	sess := ns.headscale.newMapSession(req.Context(), mapRequest, writer, node.AsStruct())
	sess.log.Trace().Caller().Msg("a node sending a MapRequest with Noise protocol")

	if sess.isStreaming() {
		sess.serveLongPoll()

		return
	}

	sess.serve()
}
// regErr builds a RegisterResponse whose Error field carries err's message.
// err must not be nil.
func regErr(err error) *tailcfg.RegisterResponse {
	resp := new(tailcfg.RegisterResponse)
	resp.Error = err.Error()

	return resp
}
// RegistrationHandler handles the registration of a node over the Noise
// connection.
//
// The body is decoded as a tailcfg.RegisterRequest. Clients announcing an
// unsupported capability version are rejected with 400 before any
// registration work happens. Otherwise the request is passed to
// handleRegister and the outcome is written as a JSON
// tailcfg.RegisterResponse with status 200; decode and registration
// failures are reported through the response's Error field, using the
// HTTPError message when the failure is an HTTPError.
func (ns *noiseServer) RegistrationHandler(
	writer http.ResponseWriter,
	req *http.Request,
) {
	if req.Method != http.MethodPost {
		httpError(writer, errMethodNotAllowed)

		return
	}

	var regReq tailcfg.RegisterRequest

	decodeErr := json.NewDecoder(req.Body).Decode(&regReq)
	if decodeErr == nil {
		ns.nodeKey = regReq.NodeKey
	}

	// Reject unsupported versions before handleRegister runs, so a client
	// that is answered with an error never has its registration processed.
	// After a failed decode regReq holds whatever was decoded so far, which
	// is the same value the check has always been applied to in that case.
	if rejectUnsupported(writer, regReq.Version, ns.machineKey, regReq.NodeKey) {
		return
	}

	var registerResponse *tailcfg.RegisterResponse

	if decodeErr != nil {
		registerResponse = regErr(decodeErr)
	} else {
		resp, err := ns.headscale.handleRegister(req.Context(), regReq, ns.conn.Peer()) //nolint:contextcheck

		switch httpErr, isHTTPErr := errors.AsType[HTTPError](err); {
		case err == nil:
			registerResponse = resp
		case isHTTPErr:
			// Expose only the HTTPError's message to the client.
			registerResponse = &tailcfg.RegisterResponse{Error: httpErr.Msg}
		default:
			registerResponse = regErr(err)
		}
	}

	writer.Header().Set("Content-Type", "application/json; charset=utf-8")
	writer.WriteHeader(http.StatusOK)

	err := json.NewEncoder(writer).Encode(registerResponse)
	if err != nil {
		log.Error().Caller().Err(err).Msg("noise registration handler: failed to encode RegisterResponse")

		return
	}

	// Ensure response is flushed to client
	if flusher, ok := writer.(http.Flusher); ok {
		flusher.Flush()
	}
}
// getAndValidateNode looks up the node identified by mapRequest.NodeKey and
// checks that it belongs to the machine key of this Noise session.
//
// It returns an HTTPError with status 404 both when no node has that node
// key and when the node's machine key differs from the session's.
func (ns *noiseServer) getAndValidateNode(mapRequest tailcfg.MapRequest) (types.NodeView, error) {
	node, found := ns.headscale.state.GetNodeByNodeKey(mapRequest.NodeKey)
	if !found {
		return types.NodeView{}, NewHTTPError(http.StatusNotFound, "node not found", nil)
	}

	// The machine key that authenticated the Noise session must be the one
	// registered for the requested node key.
	if node.MachineKey() != ns.machineKey {
		return types.NodeView{}, NewHTTPError(
			http.StatusNotFound,
			"node key in request does not match the one associated with this machine key",
			nil,
		)
	}

	return node, nil
}