diff --git a/hscontrol/noise.go b/hscontrol/noise.go index cd2f2036..51020f9f 100644 --- a/hscontrol/noise.go +++ b/hscontrol/noise.go @@ -46,6 +46,12 @@ const ( // of length. Then that many bytes of JSON-encoded tailcfg.EarlyNoise. // The early payload is optional. Some servers may not send it... But we do! earlyPayloadMagic = "\xff\xff\xffTS" + + // noiseBodyLimit is the maximum allowed request body size for Noise protocol + // handlers. This prevents unauthenticated OOM attacks via unbounded io.ReadAll. + // No legitimate Noise request (MapRequest, RegisterRequest, etc.) comes close + // to this limit; typical payloads are a few KB. + noiseBodyLimit int64 = 1048576 // 1 MiB ) type noiseServer struct { @@ -110,6 +116,17 @@ func (h *Headscale) NoiseUpgradeHandler( // a single hijacked connection from /ts2021, using netutil.NewOneConnListener r := chi.NewRouter() + + // Limit request body size to prevent unauthenticated OOM attacks. + // The Noise handshake accepts any machine key without checking + // registration, so all endpoints behind this router are reachable + // without credentials. + r.Use(func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, noiseBodyLimit) + next.ServeHTTP(w, r) + }) + }) r.Use(metrics.Collector(metrics.CollectorOpts{ Host: false, Proto: true, @@ -251,8 +268,7 @@ func rejectUnsupported( } func (ns *noiseServer) NotImplementedHandler(writer http.ResponseWriter, req *http.Request) { - d, _ := io.ReadAll(req.Body) - log.Trace().Caller().Str("path", req.URL.String()).Bytes("body", d).Msgf("not implemented handler hit") + log.Trace().Caller().Str("path", req.URL.String()).Msg("not implemented handler hit") http.Error(writer, "Not implemented yet", http.StatusNotImplemented) } @@ -535,10 +551,10 @@ func (ns *noiseServer) PollNetMapHandler( writer http.ResponseWriter, req *http.Request, ) { - body, _ := io.ReadAll(req.Body) - var mapRequest tailcfg.MapRequest - if err := json.Unmarshal(body, &mapRequest); err != nil { //nolint:noinlineerr + + err := json.NewDecoder(req.Body).Decode(&mapRequest) + if err != nil { httpError(writer, err) return } @@ -584,13 +600,10 @@ func (ns *noiseServer) RegistrationHandler( registerRequest, registerResponse := func() (*tailcfg.RegisterRequest, *tailcfg.RegisterResponse) { //nolint:contextcheck var resp *tailcfg.RegisterResponse - body, err := io.ReadAll(req.Body) - if err != nil { - return &tailcfg.RegisterRequest{}, regErr(err) - } - var regReq tailcfg.RegisterRequest - if err := json.Unmarshal(body, ®Req); err != nil { //nolint:noinlineerr + + err := json.NewDecoder(req.Body).Decode(®Req) + if err != nil { return ®Req, regErr(err) } diff --git a/hscontrol/noise_test.go b/hscontrol/noise_test.go new file mode 100644 index 00000000..594521f5 --- /dev/null +++ b/hscontrol/noise_test.go @@ -0,0 +1,195 @@ +package hscontrol + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/go-chi/chi/v5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "tailscale.com/tailcfg" +) + +// newNoiseRouterWithBodyLimit builds a chi router with the same body-limit +// middleware used in the real Noise router but wired to a test handler that +// captures the io.ReadAll result. This lets us verify the limit without +// needing a full Headscale instance. +func newNoiseRouterWithBodyLimit(readBody *[]byte, readErr *error) http.Handler { + r := chi.NewRouter() + r.Use(func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, noiseBodyLimit) + next.ServeHTTP(w, r) + }) + }) + + handler := func(w http.ResponseWriter, r *http.Request) { + *readBody, *readErr = io.ReadAll(r.Body) + if *readErr != nil { + http.Error(w, "body too large", http.StatusRequestEntityTooLarge) + + return + } + + w.WriteHeader(http.StatusOK) + } + + r.Post("/machine/map", handler) + r.Post("/machine/register", handler) + + return r +} + +func TestNoiseBodyLimit_MapEndpoint(t *testing.T) { + t.Parallel() + + t.Run("normal_map_request", func(t *testing.T) { + t.Parallel() + + var body []byte + + var readErr error + + router := newNoiseRouterWithBodyLimit(&body, &readErr) + + mapReq := tailcfg.MapRequest{Version: 100, Stream: true} + payload, err := json.Marshal(mapReq) + require.NoError(t, err) + + req := httptest.NewRequestWithContext(context.Background(), http.MethodPost, "/machine/map", bytes.NewReader(payload)) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.NoError(t, readErr) + assert.Equal(t, http.StatusOK, rec.Code) + assert.Len(t, body, len(payload)) + }) + + t.Run("oversized_body_rejected", func(t *testing.T) { + t.Parallel() + + var body []byte + + var readErr error + + router := newNoiseRouterWithBodyLimit(&body, &readErr) + + oversized := bytes.Repeat([]byte("x"), int(noiseBodyLimit)+1) + req := httptest.NewRequestWithContext(context.Background(), http.MethodPost, "/machine/map", bytes.NewReader(oversized)) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Error(t, readErr) + assert.Equal(t, http.StatusRequestEntityTooLarge, rec.Code) + assert.LessOrEqual(t, len(body), int(noiseBodyLimit)) + }) +} + +func TestNoiseBodyLimit_RegisterEndpoint(t *testing.T) { + t.Parallel() + + t.Run("normal_register_request", func(t *testing.T) { + t.Parallel() + + var body []byte + + var readErr error + + router := newNoiseRouterWithBodyLimit(&body, &readErr) + + regReq := tailcfg.RegisterRequest{Version: 100} + payload, err := json.Marshal(regReq) + require.NoError(t, err) + + req := httptest.NewRequestWithContext(context.Background(), http.MethodPost, "/machine/register", bytes.NewReader(payload)) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.NoError(t, readErr) + assert.Equal(t, http.StatusOK, rec.Code) + assert.Len(t, body, len(payload)) + }) + + t.Run("oversized_body_rejected", func(t *testing.T) { + t.Parallel() + + var body []byte + + var readErr error + + router := newNoiseRouterWithBodyLimit(&body, &readErr) + + oversized := bytes.Repeat([]byte("x"), int(noiseBodyLimit)+1) + req := httptest.NewRequestWithContext(context.Background(), http.MethodPost, "/machine/register", bytes.NewReader(oversized)) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Error(t, readErr) + assert.Equal(t, http.StatusRequestEntityTooLarge, rec.Code) + assert.LessOrEqual(t, len(body), int(noiseBodyLimit)) + }) +} + +func TestNoiseBodyLimit_AtExactLimit(t *testing.T) { + t.Parallel() + + var body []byte + + var readErr error + + router := newNoiseRouterWithBodyLimit(&body, &readErr) + + payload := bytes.Repeat([]byte("a"), int(noiseBodyLimit)) + req := httptest.NewRequestWithContext(context.Background(), http.MethodPost, "/machine/map", bytes.NewReader(payload)) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.NoError(t, readErr) + assert.Equal(t, http.StatusOK, rec.Code) + assert.Len(t, body, int(noiseBodyLimit)) +} + +// TestPollNetMapHandler_OversizedBody calls the real handler with a +// MaxBytesReader-wrapped body to verify it fails gracefully (json decode +// error on truncated data) rather than consuming unbounded memory. +func TestPollNetMapHandler_OversizedBody(t *testing.T) { + t.Parallel() + + ns := &noiseServer{} + + oversized := bytes.Repeat([]byte("x"), int(noiseBodyLimit)+1) + req := httptest.NewRequestWithContext(context.Background(), http.MethodPost, "/machine/map", bytes.NewReader(oversized)) + rec := httptest.NewRecorder() + req.Body = http.MaxBytesReader(rec, req.Body, noiseBodyLimit) + + ns.PollNetMapHandler(rec, req) + + // Body is truncated → json.Decode fails → httpError returns 500. + assert.Equal(t, http.StatusInternalServerError, rec.Code) +} + +// TestRegistrationHandler_OversizedBody calls the real handler with a +// MaxBytesReader-wrapped body to verify it returns an error response +// rather than consuming unbounded memory. +func TestRegistrationHandler_OversizedBody(t *testing.T) { + t.Parallel() + + ns := &noiseServer{} + + oversized := bytes.Repeat([]byte("x"), int(noiseBodyLimit)+1) + req := httptest.NewRequestWithContext(context.Background(), http.MethodPost, "/machine/register", bytes.NewReader(oversized)) + rec := httptest.NewRecorder() + req.Body = http.MaxBytesReader(rec, req.Body, noiseBodyLimit) + + ns.RegistrationHandler(rec, req) + + // json.Decode returns MaxBytesError → regErr wraps it → handler writes + // a RegisterResponse with the error and then rejectUnsupported kicks in + // for version 0 → returns 400. + assert.Equal(t, http.StatusBadRequest, rec.Code) +}