mirror of
https://github.com/juanfont/headscale.git
synced 2026-03-19 07:54:17 +01:00
noise: limit request body size to prevent unauthenticated OOM
The Noise handshake accepts any machine key without checking registration, so all endpoints behind the Noise router are reachable without credentials. Three handlers used io.ReadAll without size limits, allowing an attacker to OOM-kill the server. Fix: - Add http.MaxBytesReader middleware (1 MiB) on the Noise router. - Replace io.ReadAll + json.Unmarshal with json.NewDecoder in PollNetMapHandler and RegistrationHandler. - Stop reading the body in NotImplementedHandler entirely.
This commit is contained in:
@@ -46,6 +46,12 @@ const (
|
||||
// of length. Then that many bytes of JSON-encoded tailcfg.EarlyNoise.
|
||||
// The early payload is optional. Some servers may not send it... But we do!
|
||||
earlyPayloadMagic = "\xff\xff\xffTS"
|
||||
|
||||
// noiseBodyLimit is the maximum allowed request body size for Noise protocol
|
||||
// handlers. This prevents unauthenticated OOM attacks via unbounded io.ReadAll.
|
||||
// No legitimate Noise request (MapRequest, RegisterRequest, etc.) comes close
|
||||
// to this limit; typical payloads are a few KB.
|
||||
noiseBodyLimit int64 = 1048576 // 1 MiB
|
||||
)
|
||||
|
||||
type noiseServer struct {
|
||||
@@ -110,6 +116,17 @@ func (h *Headscale) NoiseUpgradeHandler(
|
||||
// a single hijacked connection from /ts2021, using netutil.NewOneConnListener
|
||||
|
||||
r := chi.NewRouter()
|
||||
|
||||
// Limit request body size to prevent unauthenticated OOM attacks.
|
||||
// The Noise handshake accepts any machine key without checking
|
||||
// registration, so all endpoints behind this router are reachable
|
||||
// without credentials.
|
||||
r.Use(func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
r.Body = http.MaxBytesReader(w, r.Body, noiseBodyLimit)
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
})
|
||||
r.Use(metrics.Collector(metrics.CollectorOpts{
|
||||
Host: false,
|
||||
Proto: true,
|
||||
@@ -251,8 +268,7 @@ func rejectUnsupported(
|
||||
}
|
||||
|
||||
func (ns *noiseServer) NotImplementedHandler(writer http.ResponseWriter, req *http.Request) {
|
||||
d, _ := io.ReadAll(req.Body)
|
||||
log.Trace().Caller().Str("path", req.URL.String()).Bytes("body", d).Msgf("not implemented handler hit")
|
||||
log.Trace().Caller().Str("path", req.URL.String()).Msg("not implemented handler hit")
|
||||
http.Error(writer, "Not implemented yet", http.StatusNotImplemented)
|
||||
}
|
||||
|
||||
@@ -535,10 +551,10 @@ func (ns *noiseServer) PollNetMapHandler(
|
||||
writer http.ResponseWriter,
|
||||
req *http.Request,
|
||||
) {
|
||||
body, _ := io.ReadAll(req.Body)
|
||||
|
||||
var mapRequest tailcfg.MapRequest
|
||||
if err := json.Unmarshal(body, &mapRequest); err != nil { //nolint:noinlineerr
|
||||
|
||||
err := json.NewDecoder(req.Body).Decode(&mapRequest)
|
||||
if err != nil {
|
||||
httpError(writer, err)
|
||||
return
|
||||
}
|
||||
@@ -584,13 +600,10 @@ func (ns *noiseServer) RegistrationHandler(
|
||||
registerRequest, registerResponse := func() (*tailcfg.RegisterRequest, *tailcfg.RegisterResponse) { //nolint:contextcheck
|
||||
var resp *tailcfg.RegisterResponse
|
||||
|
||||
body, err := io.ReadAll(req.Body)
|
||||
if err != nil {
|
||||
return &tailcfg.RegisterRequest{}, regErr(err)
|
||||
}
|
||||
|
||||
var regReq tailcfg.RegisterRequest
|
||||
if err := json.Unmarshal(body, ®Req); err != nil { //nolint:noinlineerr
|
||||
|
||||
err := json.NewDecoder(req.Body).Decode(®Req)
|
||||
if err != nil {
|
||||
return ®Req, regErr(err)
|
||||
}
|
||||
|
||||
|
||||
195
hscontrol/noise_test.go
Normal file
195
hscontrol/noise_test.go
Normal file
@@ -0,0 +1,195 @@
|
||||
package hscontrol
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
// newNoiseRouterWithBodyLimit builds a chi router with the same body-limit
|
||||
// middleware used in the real Noise router but wired to a test handler that
|
||||
// captures the io.ReadAll result. This lets us verify the limit without
|
||||
// needing a full Headscale instance.
|
||||
func newNoiseRouterWithBodyLimit(readBody *[]byte, readErr *error) http.Handler {
|
||||
r := chi.NewRouter()
|
||||
r.Use(func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
r.Body = http.MaxBytesReader(w, r.Body, noiseBodyLimit)
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
})
|
||||
|
||||
handler := func(w http.ResponseWriter, r *http.Request) {
|
||||
*readBody, *readErr = io.ReadAll(r.Body)
|
||||
if *readErr != nil {
|
||||
http.Error(w, "body too large", http.StatusRequestEntityTooLarge)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
r.Post("/machine/map", handler)
|
||||
r.Post("/machine/register", handler)
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
func TestNoiseBodyLimit_MapEndpoint(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("normal_map_request", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var body []byte
|
||||
|
||||
var readErr error
|
||||
|
||||
router := newNoiseRouterWithBodyLimit(&body, &readErr)
|
||||
|
||||
mapReq := tailcfg.MapRequest{Version: 100, Stream: true}
|
||||
payload, err := json.Marshal(mapReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequestWithContext(context.Background(), http.MethodPost, "/machine/map", bytes.NewReader(payload))
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.NoError(t, readErr)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Len(t, body, len(payload))
|
||||
})
|
||||
|
||||
t.Run("oversized_body_rejected", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var body []byte
|
||||
|
||||
var readErr error
|
||||
|
||||
router := newNoiseRouterWithBodyLimit(&body, &readErr)
|
||||
|
||||
oversized := bytes.Repeat([]byte("x"), int(noiseBodyLimit)+1)
|
||||
req := httptest.NewRequestWithContext(context.Background(), http.MethodPost, "/machine/map", bytes.NewReader(oversized))
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Error(t, readErr)
|
||||
assert.Equal(t, http.StatusRequestEntityTooLarge, rec.Code)
|
||||
assert.LessOrEqual(t, len(body), int(noiseBodyLimit))
|
||||
})
|
||||
}
|
||||
|
||||
func TestNoiseBodyLimit_RegisterEndpoint(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("normal_register_request", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var body []byte
|
||||
|
||||
var readErr error
|
||||
|
||||
router := newNoiseRouterWithBodyLimit(&body, &readErr)
|
||||
|
||||
regReq := tailcfg.RegisterRequest{Version: 100}
|
||||
payload, err := json.Marshal(regReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequestWithContext(context.Background(), http.MethodPost, "/machine/register", bytes.NewReader(payload))
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.NoError(t, readErr)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Len(t, body, len(payload))
|
||||
})
|
||||
|
||||
t.Run("oversized_body_rejected", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var body []byte
|
||||
|
||||
var readErr error
|
||||
|
||||
router := newNoiseRouterWithBodyLimit(&body, &readErr)
|
||||
|
||||
oversized := bytes.Repeat([]byte("x"), int(noiseBodyLimit)+1)
|
||||
req := httptest.NewRequestWithContext(context.Background(), http.MethodPost, "/machine/register", bytes.NewReader(oversized))
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Error(t, readErr)
|
||||
assert.Equal(t, http.StatusRequestEntityTooLarge, rec.Code)
|
||||
assert.LessOrEqual(t, len(body), int(noiseBodyLimit))
|
||||
})
|
||||
}
|
||||
|
||||
func TestNoiseBodyLimit_AtExactLimit(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var body []byte
|
||||
|
||||
var readErr error
|
||||
|
||||
router := newNoiseRouterWithBodyLimit(&body, &readErr)
|
||||
|
||||
payload := bytes.Repeat([]byte("a"), int(noiseBodyLimit))
|
||||
req := httptest.NewRequestWithContext(context.Background(), http.MethodPost, "/machine/map", bytes.NewReader(payload))
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.NoError(t, readErr)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Len(t, body, int(noiseBodyLimit))
|
||||
}
|
||||
|
||||
// TestPollNetMapHandler_OversizedBody calls the real handler with a
|
||||
// MaxBytesReader-wrapped body to verify it fails gracefully (json decode
|
||||
// error on truncated data) rather than consuming unbounded memory.
|
||||
func TestPollNetMapHandler_OversizedBody(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ns := &noiseServer{}
|
||||
|
||||
oversized := bytes.Repeat([]byte("x"), int(noiseBodyLimit)+1)
|
||||
req := httptest.NewRequestWithContext(context.Background(), http.MethodPost, "/machine/map", bytes.NewReader(oversized))
|
||||
rec := httptest.NewRecorder()
|
||||
req.Body = http.MaxBytesReader(rec, req.Body, noiseBodyLimit)
|
||||
|
||||
ns.PollNetMapHandler(rec, req)
|
||||
|
||||
// Body is truncated → json.Decode fails → httpError returns 500.
|
||||
assert.Equal(t, http.StatusInternalServerError, rec.Code)
|
||||
}
|
||||
|
||||
// TestRegistrationHandler_OversizedBody calls the real handler with a
|
||||
// MaxBytesReader-wrapped body to verify it returns an error response
|
||||
// rather than consuming unbounded memory.
|
||||
func TestRegistrationHandler_OversizedBody(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ns := &noiseServer{}
|
||||
|
||||
oversized := bytes.Repeat([]byte("x"), int(noiseBodyLimit)+1)
|
||||
req := httptest.NewRequestWithContext(context.Background(), http.MethodPost, "/machine/register", bytes.NewReader(oversized))
|
||||
rec := httptest.NewRecorder()
|
||||
req.Body = http.MaxBytesReader(rec, req.Body, noiseBodyLimit)
|
||||
|
||||
ns.RegistrationHandler(rec, req)
|
||||
|
||||
// json.Decode returns MaxBytesError → regErr wraps it → handler writes
|
||||
// a RegisterResponse with the error and then rejectUnsupported kicks in
|
||||
// for version 0 → returns 400.
|
||||
assert.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
Reference in New Issue
Block a user