noise: limit request body size to prevent unauthenticated OOM

The Noise handshake accepts any machine key without checking
registration, so all endpoints behind the Noise router are reachable
without credentials. Three handlers used io.ReadAll without size
limits, allowing an attacker to OOM-kill the server.

Fix:
- Add http.MaxBytesReader middleware (1 MiB) on the Noise router.
- Replace io.ReadAll + json.Unmarshal with json.NewDecoder in
  PollNetMapHandler and RegistrationHandler.
- Stop reading the body in NotImplementedHandler entirely.
This commit is contained in:
Juan Font
2026-03-15 21:48:03 +01:00
parent afd3a6acbc
commit 4d427cfe2a
2 changed files with 219 additions and 11 deletions

View File

@@ -46,6 +46,12 @@ const (
// of length. Then that many bytes of JSON-encoded tailcfg.EarlyNoise.
// The early payload is optional. Some servers may not send it... But we do!
earlyPayloadMagic = "\xff\xff\xffTS"
// noiseBodyLimit is the maximum allowed request body size for Noise protocol
// handlers. This prevents unauthenticated OOM attacks via unbounded io.ReadAll.
// No legitimate Noise request (MapRequest, RegisterRequest, etc.) comes close
// to this limit; typical payloads are a few KB.
noiseBodyLimit int64 = 1048576 // 1 MiB
)
type noiseServer struct {
@@ -110,6 +116,17 @@ func (h *Headscale) NoiseUpgradeHandler(
// a single hijacked connection from /ts2021, using netutil.NewOneConnListener
r := chi.NewRouter()
// Limit request body size to prevent unauthenticated OOM attacks.
// The Noise handshake accepts any machine key without checking
// registration, so all endpoints behind this router are reachable
// without credentials.
r.Use(func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, noiseBodyLimit)
next.ServeHTTP(w, r)
})
})
r.Use(metrics.Collector(metrics.CollectorOpts{
Host: false,
Proto: true,
@@ -251,8 +268,7 @@ func rejectUnsupported(
}
func (ns *noiseServer) NotImplementedHandler(writer http.ResponseWriter, req *http.Request) {
d, _ := io.ReadAll(req.Body)
log.Trace().Caller().Str("path", req.URL.String()).Bytes("body", d).Msgf("not implemented handler hit")
log.Trace().Caller().Str("path", req.URL.String()).Msg("not implemented handler hit")
http.Error(writer, "Not implemented yet", http.StatusNotImplemented)
}
@@ -535,10 +551,10 @@ func (ns *noiseServer) PollNetMapHandler(
writer http.ResponseWriter,
req *http.Request,
) {
body, _ := io.ReadAll(req.Body)
var mapRequest tailcfg.MapRequest
if err := json.Unmarshal(body, &mapRequest); err != nil { //nolint:noinlineerr
err := json.NewDecoder(req.Body).Decode(&mapRequest)
if err != nil {
httpError(writer, err)
return
}
@@ -584,13 +600,10 @@ func (ns *noiseServer) RegistrationHandler(
registerRequest, registerResponse := func() (*tailcfg.RegisterRequest, *tailcfg.RegisterResponse) { //nolint:contextcheck
var resp *tailcfg.RegisterResponse
body, err := io.ReadAll(req.Body)
if err != nil {
return &tailcfg.RegisterRequest{}, regErr(err)
}
var regReq tailcfg.RegisterRequest
if err := json.Unmarshal(body, &regReq); err != nil { //nolint:noinlineerr
err := json.NewDecoder(req.Body).Decode(&regReq)
if err != nil {
return &regReq, regErr(err)
}

195
hscontrol/noise_test.go Normal file
View File

@@ -0,0 +1,195 @@
package hscontrol
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/go-chi/chi/v5"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"tailscale.com/tailcfg"
)
// newNoiseRouterWithBodyLimit builds a chi router with the same body-limit
// middleware used in the real Noise router but wired to a test handler that
// captures the io.ReadAll result. This lets us verify the limit without
// needing a full Headscale instance.
func newNoiseRouterWithBodyLimit(readBody *[]byte, readErr *error) http.Handler {
r := chi.NewRouter()
r.Use(func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, noiseBodyLimit)
next.ServeHTTP(w, r)
})
})
handler := func(w http.ResponseWriter, r *http.Request) {
*readBody, *readErr = io.ReadAll(r.Body)
if *readErr != nil {
http.Error(w, "body too large", http.StatusRequestEntityTooLarge)
return
}
w.WriteHeader(http.StatusOK)
}
r.Post("/machine/map", handler)
r.Post("/machine/register", handler)
return r
}
func TestNoiseBodyLimit_MapEndpoint(t *testing.T) {
t.Parallel()
t.Run("normal_map_request", func(t *testing.T) {
t.Parallel()
var body []byte
var readErr error
router := newNoiseRouterWithBodyLimit(&body, &readErr)
mapReq := tailcfg.MapRequest{Version: 100, Stream: true}
payload, err := json.Marshal(mapReq)
require.NoError(t, err)
req := httptest.NewRequestWithContext(context.Background(), http.MethodPost, "/machine/map", bytes.NewReader(payload))
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.NoError(t, readErr)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Len(t, body, len(payload))
})
t.Run("oversized_body_rejected", func(t *testing.T) {
t.Parallel()
var body []byte
var readErr error
router := newNoiseRouterWithBodyLimit(&body, &readErr)
oversized := bytes.Repeat([]byte("x"), int(noiseBodyLimit)+1)
req := httptest.NewRequestWithContext(context.Background(), http.MethodPost, "/machine/map", bytes.NewReader(oversized))
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Error(t, readErr)
assert.Equal(t, http.StatusRequestEntityTooLarge, rec.Code)
assert.LessOrEqual(t, len(body), int(noiseBodyLimit))
})
}
func TestNoiseBodyLimit_RegisterEndpoint(t *testing.T) {
t.Parallel()
t.Run("normal_register_request", func(t *testing.T) {
t.Parallel()
var body []byte
var readErr error
router := newNoiseRouterWithBodyLimit(&body, &readErr)
regReq := tailcfg.RegisterRequest{Version: 100}
payload, err := json.Marshal(regReq)
require.NoError(t, err)
req := httptest.NewRequestWithContext(context.Background(), http.MethodPost, "/machine/register", bytes.NewReader(payload))
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.NoError(t, readErr)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Len(t, body, len(payload))
})
t.Run("oversized_body_rejected", func(t *testing.T) {
t.Parallel()
var body []byte
var readErr error
router := newNoiseRouterWithBodyLimit(&body, &readErr)
oversized := bytes.Repeat([]byte("x"), int(noiseBodyLimit)+1)
req := httptest.NewRequestWithContext(context.Background(), http.MethodPost, "/machine/register", bytes.NewReader(oversized))
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Error(t, readErr)
assert.Equal(t, http.StatusRequestEntityTooLarge, rec.Code)
assert.LessOrEqual(t, len(body), int(noiseBodyLimit))
})
}
func TestNoiseBodyLimit_AtExactLimit(t *testing.T) {
t.Parallel()
var body []byte
var readErr error
router := newNoiseRouterWithBodyLimit(&body, &readErr)
payload := bytes.Repeat([]byte("a"), int(noiseBodyLimit))
req := httptest.NewRequestWithContext(context.Background(), http.MethodPost, "/machine/map", bytes.NewReader(payload))
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.NoError(t, readErr)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Len(t, body, int(noiseBodyLimit))
}
// TestPollNetMapHandler_OversizedBody calls the real handler with a
// MaxBytesReader-wrapped body to verify it fails gracefully (json decode
// error on truncated data) rather than consuming unbounded memory.
func TestPollNetMapHandler_OversizedBody(t *testing.T) {
t.Parallel()
ns := &noiseServer{}
oversized := bytes.Repeat([]byte("x"), int(noiseBodyLimit)+1)
req := httptest.NewRequestWithContext(context.Background(), http.MethodPost, "/machine/map", bytes.NewReader(oversized))
rec := httptest.NewRecorder()
req.Body = http.MaxBytesReader(rec, req.Body, noiseBodyLimit)
ns.PollNetMapHandler(rec, req)
// Body is truncated → json.Decode fails → httpError returns 500.
assert.Equal(t, http.StatusInternalServerError, rec.Code)
}
// TestRegistrationHandler_OversizedBody calls the real handler with a
// MaxBytesReader-wrapped body to verify it returns an error response
// rather than consuming unbounded memory.
func TestRegistrationHandler_OversizedBody(t *testing.T) {
t.Parallel()
ns := &noiseServer{}
oversized := bytes.Repeat([]byte("x"), int(noiseBodyLimit)+1)
req := httptest.NewRequestWithContext(context.Background(), http.MethodPost, "/machine/register", bytes.NewReader(oversized))
rec := httptest.NewRecorder()
req.Body = http.MaxBytesReader(rec, req.Body, noiseBodyLimit)
ns.RegistrationHandler(rec, req)
// json.Decode returns MaxBytesError → regErr wraps it → handler writes
// a RegisterResponse with the error and then rejectUnsupported kicks in
// for version 0 → returns 400.
assert.Equal(t, http.StatusBadRequest, rec.Code)
}