Files
headscale/hscontrol/noise_test.go
Juan Font 4d427cfe2a noise: limit request body size to prevent unauthenticated OOM
The Noise handshake accepts any machine key without checking
registration, so all endpoints behind the Noise router are reachable
without credentials. Three handlers used io.ReadAll without size
limits, allowing an attacker to OOM-kill the server.

Fix:
- Add http.MaxBytesReader middleware (1 MiB) on the Noise router.
- Replace io.ReadAll + json.Unmarshal with json.NewDecoder in
  PollNetMapHandler and RegistrationHandler.
- Stop reading the body in NotImplementedHandler entirely.
2026-03-16 09:28:31 +01:00

196 lines
5.6 KiB
Go

package hscontrol
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/go-chi/chi/v5"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"tailscale.com/tailcfg"
)
// newNoiseRouterWithBodyLimit builds a chi router with the same body-limit
// middleware used in the real Noise router but wired to a test handler that
// captures the io.ReadAll result. This lets us verify the limit without
// needing a full Headscale instance.
func newNoiseRouterWithBodyLimit(readBody *[]byte, readErr *error) http.Handler {
r := chi.NewRouter()
r.Use(func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, noiseBodyLimit)
next.ServeHTTP(w, r)
})
})
handler := func(w http.ResponseWriter, r *http.Request) {
*readBody, *readErr = io.ReadAll(r.Body)
if *readErr != nil {
http.Error(w, "body too large", http.StatusRequestEntityTooLarge)
return
}
w.WriteHeader(http.StatusOK)
}
r.Post("/machine/map", handler)
r.Post("/machine/register", handler)
return r
}
// TestNoiseBodyLimit_MapEndpoint exercises the body-limit middleware on
// POST /machine/map through the router built by newNoiseRouterWithBodyLimit.
//
// Subtests:
//   - normal_map_request: a marshalled tailcfg.MapRequest passes through
//     unchanged and the handler answers 200.
//   - oversized_body_rejected: a body one byte over noiseBodyLimit makes the
//     read fail with *http.MaxBytesError and the handler answers 413.
func TestNoiseBodyLimit_MapEndpoint(t *testing.T) {
	t.Parallel()

	t.Run("normal_map_request", func(t *testing.T) {
		t.Parallel()

		var body []byte
		var readErr error
		router := newNoiseRouterWithBodyLimit(&body, &readErr)

		mapReq := tailcfg.MapRequest{Version: 100, Stream: true}
		payload, err := json.Marshal(mapReq)
		require.NoError(t, err)

		req := httptest.NewRequestWithContext(context.Background(), http.MethodPost, "/machine/map", bytes.NewReader(payload))
		rec := httptest.NewRecorder()
		router.ServeHTTP(rec, req)

		require.NoError(t, readErr)
		assert.Equal(t, http.StatusOK, rec.Code)
		// Compare content, not just length: a same-sized but corrupted
		// body must not pass.
		assert.Equal(t, payload, body)
	})

	t.Run("oversized_body_rejected", func(t *testing.T) {
		t.Parallel()

		var body []byte
		var readErr error
		router := newNoiseRouterWithBodyLimit(&body, &readErr)

		oversized := bytes.Repeat([]byte("x"), int(noiseBodyLimit)+1)
		req := httptest.NewRequestWithContext(context.Background(), http.MethodPost, "/machine/map", bytes.NewReader(oversized))
		rec := httptest.NewRecorder()
		router.ServeHTTP(rec, req)

		// Any read failure would satisfy a bare require.Error; pin the
		// failure to the size limit actually being hit.
		var maxErr *http.MaxBytesError
		require.ErrorAs(t, readErr, &maxErr)
		assert.Equal(t, int64(noiseBodyLimit), maxErr.Limit)
		assert.Equal(t, http.StatusRequestEntityTooLarge, rec.Code)
		// io.ReadAll returns the bytes read before the error; they must
		// never exceed the cap.
		assert.LessOrEqual(t, len(body), int(noiseBodyLimit))
	})
}
// TestNoiseBodyLimit_RegisterEndpoint exercises the body-limit middleware on
// POST /machine/register through the router built by
// newNoiseRouterWithBodyLimit.
//
// Subtests:
//   - normal_register_request: a marshalled tailcfg.RegisterRequest passes
//     through unchanged and the handler answers 200.
//   - oversized_body_rejected: a body one byte over noiseBodyLimit makes the
//     read fail with *http.MaxBytesError and the handler answers 413.
func TestNoiseBodyLimit_RegisterEndpoint(t *testing.T) {
	t.Parallel()

	t.Run("normal_register_request", func(t *testing.T) {
		t.Parallel()

		var body []byte
		var readErr error
		router := newNoiseRouterWithBodyLimit(&body, &readErr)

		regReq := tailcfg.RegisterRequest{Version: 100}
		payload, err := json.Marshal(regReq)
		require.NoError(t, err)

		req := httptest.NewRequestWithContext(context.Background(), http.MethodPost, "/machine/register", bytes.NewReader(payload))
		rec := httptest.NewRecorder()
		router.ServeHTTP(rec, req)

		require.NoError(t, readErr)
		assert.Equal(t, http.StatusOK, rec.Code)
		// Compare content, not just length: a same-sized but corrupted
		// body must not pass.
		assert.Equal(t, payload, body)
	})

	t.Run("oversized_body_rejected", func(t *testing.T) {
		t.Parallel()

		var body []byte
		var readErr error
		router := newNoiseRouterWithBodyLimit(&body, &readErr)

		oversized := bytes.Repeat([]byte("x"), int(noiseBodyLimit)+1)
		req := httptest.NewRequestWithContext(context.Background(), http.MethodPost, "/machine/register", bytes.NewReader(oversized))
		rec := httptest.NewRecorder()
		router.ServeHTTP(rec, req)

		// Any read failure would satisfy a bare require.Error; pin the
		// failure to the size limit actually being hit.
		var maxErr *http.MaxBytesError
		require.ErrorAs(t, readErr, &maxErr)
		assert.Equal(t, int64(noiseBodyLimit), maxErr.Limit)
		assert.Equal(t, http.StatusRequestEntityTooLarge, rec.Code)
		// io.ReadAll returns the bytes read before the error; they must
		// never exceed the cap.
		assert.LessOrEqual(t, len(body), int(noiseBodyLimit))
	})
}
// TestNoiseBodyLimit_AtExactLimit sends a body of exactly noiseBodyLimit
// bytes to POST /machine/map and expects it to be accepted: the read
// succeeds, the handler answers 200, and every byte is delivered.
func TestNoiseBodyLimit_AtExactLimit(t *testing.T) {
	t.Parallel()

	var (
		got    []byte
		gotErr error
	)
	handler := newNoiseRouterWithBodyLimit(&got, &gotErr)

	// A payload sitting right on the boundary must not trip the limit.
	limit := int(noiseBodyLimit)
	exact := bytes.Repeat([]byte("a"), limit)

	request := httptest.NewRequestWithContext(context.Background(), http.MethodPost, "/machine/map", bytes.NewReader(exact))
	recorder := httptest.NewRecorder()
	handler.ServeHTTP(recorder, request)

	require.NoError(t, gotErr)
	assert.Equal(t, http.StatusOK, recorder.Code)
	assert.Len(t, got, limit)
}
// TestPollNetMapHandler_OversizedBody calls the real handler with a
// MaxBytesReader-wrapped body to verify it fails gracefully (json decode
// error on truncated data) rather than consuming unbounded memory.
func TestPollNetMapHandler_OversizedBody(t *testing.T) {
t.Parallel()
ns := &noiseServer{}
oversized := bytes.Repeat([]byte("x"), int(noiseBodyLimit)+1)
req := httptest.NewRequestWithContext(context.Background(), http.MethodPost, "/machine/map", bytes.NewReader(oversized))
rec := httptest.NewRecorder()
req.Body = http.MaxBytesReader(rec, req.Body, noiseBodyLimit)
ns.PollNetMapHandler(rec, req)
// Body is truncated → json.Decode fails → httpError returns 500.
assert.Equal(t, http.StatusInternalServerError, rec.Code)
}
// TestRegistrationHandler_OversizedBody calls the real handler with a
// MaxBytesReader-wrapped body to verify it returns an error response
// rather than consuming unbounded memory.
func TestRegistrationHandler_OversizedBody(t *testing.T) {
t.Parallel()
ns := &noiseServer{}
oversized := bytes.Repeat([]byte("x"), int(noiseBodyLimit)+1)
req := httptest.NewRequestWithContext(context.Background(), http.MethodPost, "/machine/register", bytes.NewReader(oversized))
rec := httptest.NewRecorder()
req.Body = http.MaxBytesReader(rec, req.Body, noiseBodyLimit)
ns.RegistrationHandler(rec, req)
// json.Decode returns MaxBytesError → regErr wraps it → handler writes
// a RegisterResponse with the error and then rejectUnsupported kicks in
// for version 0 → returns 400.
assert.Equal(t, http.StatusBadRequest, rec.Code)
}