From f31b1b5ed3a1e50ffadbaa0f091c41c94b57cac7 Mon Sep 17 00:00:00 2001 From: yusing Date: Sun, 14 Sep 2025 00:03:27 +0800 Subject: [PATCH] refactor(misc): enhance performance on bytes pool, entrypoint, access log and route context handling - Introduced benchmark tests for Entrypoint and ReverseProxy to evaluate performance. - Updated Entrypoint's ServeHTTP method to improve route context management. - Added new test file for entrypoint benchmarks and refined existing tests for route handling. --- internal/entrypoint/entrypoint.go | 25 +- .../entrypoint/entrypoint_benchmark_test.go | 154 ++++++++++ internal/entrypoint/entrypoint_test.go | 3 + internal/logging/accesslog/access_logger.go | 2 +- internal/logging/accesslog/rotate.go | 2 +- .../reverse_proxy_benchmark_test.go | 49 ++++ internal/route/routes/context.go | 71 ++++- internal/utils/strutils/string.go | 4 +- internal/utils/synk/pool.go | 55 +++- internal/utils/synk/pool_bench_test.go | 28 +- internal/utils/synk/pool_test.go | 263 ++++++++++++++++++ 11 files changed, 623 insertions(+), 33 deletions(-) create mode 100644 internal/entrypoint/entrypoint_benchmark_test.go create mode 100644 internal/net/gphttp/reverseproxy/reverse_proxy_benchmark_test.go create mode 100644 internal/utils/synk/pool_test.go diff --git a/internal/entrypoint/entrypoint.go b/internal/entrypoint/entrypoint.go index 2278d4bc..05a97944 100644 --- a/internal/entrypoint/entrypoint.go +++ b/internal/entrypoint/entrypoint.go @@ -13,7 +13,6 @@ import ( "github.com/yusing/go-proxy/internal/route/routes" "github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/types" - "github.com/yusing/go-proxy/internal/utils/strutils" ) type Entrypoint struct { @@ -73,12 +72,13 @@ func (ep *Entrypoint) ServeHTTP(w http.ResponseWriter, r *http.Request) { w = accesslog.NewResponseRecorder(w) defer ep.accessLogger.Log(r, w.(*accesslog.ResponseRecorder).Response()) } - mux, err := ep.findRouteFunc(r.Host) + route, err := ep.findRouteFunc(r.Host) if err == nil { + r = routes.WithRouteContext(r, route) if ep.middleware != nil { - ep.middleware.ServeHTTP(mux.ServeHTTP, w, routes.WithRouteContext(r, mux)) + ep.middleware.ServeHTTP(route.ServeHTTP, w, r) } else { - mux.ServeHTTP(w, r) + route.ServeHTTP(w, r) } return } @@ -106,20 +106,23 @@ func (ep *Entrypoint) ServeHTTP(w http.ResponseWriter, r *http.Request) { } func findRouteAnyDomain(host string) (types.HTTPRoute, error) { - hostSplit := strutils.SplitRune(host, '.') - target := hostSplit[0] - - if r, ok := routes.GetHTTPRouteOrExact(target, host); ok { + idx := strings.IndexByte(host, '.') + if idx != -1 { + target := host[:idx] + if r, ok := routes.HTTP.Get(target); ok { + return r, nil + } + } + if r, ok := routes.HTTP.Get(host); ok { return r, nil } - return nil, fmt.Errorf("%w: %s", ErrNoSuchRoute, target) + return nil, fmt.Errorf("%w: %s", ErrNoSuchRoute, host) } func findRouteByDomains(domains []string) func(host string) (types.HTTPRoute, error) { return func(host string) (types.HTTPRoute, error) { for _, domain := range domains { - if strings.HasSuffix(host, domain) { - target := strings.TrimSuffix(host, domain) + if target, ok := strings.CutSuffix(host, domain); ok { if r, ok := routes.HTTP.Get(target); ok { return r, nil } diff --git a/internal/entrypoint/entrypoint_benchmark_test.go b/internal/entrypoint/entrypoint_benchmark_test.go new file mode 100644 index 00000000..0c5c4eab --- /dev/null +++ b/internal/entrypoint/entrypoint_benchmark_test.go @@ -0,0 +1,154 @@ +package entrypoint + +import ( + "io" + "net" + "net/http" + "net/http/httptest" + "net/url" + "strconv" + "strings" + "testing" + + "github.com/yusing/go-proxy/internal/route" + "github.com/yusing/go-proxy/internal/route/routes" + "github.com/yusing/go-proxy/internal/task" + "github.com/yusing/go-proxy/internal/types" +) + +type noopResponseWriter struct { + statusCode int + written []byte +} + +func (w *noopResponseWriter) Header() http.Header { + return http.Header{} +} +func (w *noopResponseWriter) Write(b []byte) (int, error) { + w.written = b + return len(b), nil +} +func (w *noopResponseWriter) WriteHeader(statusCode int) { + w.statusCode = statusCode +} + +type noopTransport struct{} + +func (t noopTransport) RoundTrip(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader("1")), + Request: req, + Header: http.Header{}, + }, nil +} + +func BenchmarkEntrypointReal(b *testing.B) { + var ep Entrypoint + var req = http.Request{ + Method: "GET", + URL: &url.URL{Path: "/", RawPath: "/"}, + Host: "test.domain.tld", + } + ep.SetFindRouteDomains([]string{}) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Length", "1") + w.Write([]byte("1")) + })) + defer srv.Close() + + url, err := url.Parse(srv.URL) + if err != nil { + b.Fatal(err) + } + + host, port, err := net.SplitHostPort(url.Host) + if err != nil { + b.Fatal(err) + } + + portInt, err := strconv.Atoi(port) + if err != nil { + b.Fatal(err) + } + + r := &route.Route{ + Alias: "test", + Scheme: "http", + Host: host, + Port: route.Port{Proxy: portInt}, + HealthCheck: &types.HealthCheckConfig{Disable: true}, + } + + err = r.Validate() + if err != nil { + b.Fatal(err) + } + + err = r.Start(task.RootTask("test", false)) + if err != nil { + b.Fatal(err) + } + + var w noopResponseWriter + + b.ResetTimer() + for b.Loop() { + ep.ServeHTTP(&w, &req) + // if w.statusCode != http.StatusOK { + // b.Fatalf("status code is not 200: %d", w.statusCode) + // } + // if string(w.written) != "1" { + // b.Fatalf("written is not 1: %s", string(w.written)) + // } + } +} + +func BenchmarkEntrypoint(b *testing.B) { + var ep Entrypoint + var req = http.Request{ + Method: "GET", + URL: &url.URL{Path: "/", RawPath: "/"}, + Host: "test.domain.tld", + } + ep.SetFindRouteDomains([]string{}) + + r := &route.Route{ + Alias: "test", + Scheme: "http", + Host: "localhost", + Port: route.Port{ + Proxy: 8080, + }, + HealthCheck: &types.HealthCheckConfig{ + Disable: true, + }, + } + + err := r.Validate() + if err != nil { + b.Fatal(err) + } + + err = r.Start(task.RootTask("test", false)) + if err != nil { + b.Fatal(err) + } + + rev, ok := routes.HTTP.Get("test") + if !ok { + b.Fatal("route not found") + } + rev.(types.ReverseProxyRoute).ReverseProxy().Transport = noopTransport{} + + var w noopResponseWriter + + b.ResetTimer() + for b.Loop() { + ep.ServeHTTP(&w, &req) + if w.statusCode != http.StatusOK { + b.Fatalf("status code is not 200: %d", w.statusCode) + } + } +} diff --git a/internal/entrypoint/entrypoint_test.go b/internal/entrypoint/entrypoint_test.go index eb98f442..75db4f18 100644 --- a/internal/entrypoint/entrypoint_test.go +++ b/internal/entrypoint/entrypoint_test.go @@ -15,6 +15,9 @@ func addRoute(alias string) { routes.HTTP.Add(&route.ReveseProxyRoute{ Route: &route.Route{ Alias: alias, + Port: route.Port{ + Proxy: 80, + }, }, }) } diff --git a/internal/logging/accesslog/access_logger.go b/internal/logging/accesslog/access_logger.go index ecae79c4..989f509e 100644 --- a/internal/logging/accesslog/access_logger.go +++ b/internal/logging/accesslog/access_logger.go @@ -76,7 +76,7 @@ const ( errBurst = 5 ) -var lineBufPool = synk.GetBytesPool() +var lineBufPool = synk.GetBytesPoolWithUniqueMemory() func NewAccessLogger(parent task.Parent, cfg AnyConfig) (*AccessLogger, error) { io, err := cfg.IO() diff --git a/internal/logging/accesslog/rotate.go b/internal/logging/accesslog/rotate.go index b472e252..88be872f 100644 --- a/internal/logging/accesslog/rotate.go +++ b/internal/logging/accesslog/rotate.go @@ -66,7 +66,7 @@ type lineInfo struct { Size int64 // Size of this line } -var rotateBytePool = synk.GetBytesPool() +var rotateBytePool = synk.GetBytesPoolWithUniqueMemory() // rotateLogFile rotates the log file based on the retention policy. // It returns the result of the rotation and an error if any. diff --git a/internal/net/gphttp/reverseproxy/reverse_proxy_benchmark_test.go b/internal/net/gphttp/reverseproxy/reverse_proxy_benchmark_test.go new file mode 100644 index 00000000..83091ec3 --- /dev/null +++ b/internal/net/gphttp/reverseproxy/reverse_proxy_benchmark_test.go @@ -0,0 +1,49 @@ +package reverseproxy + +import ( + "io" + "net/http" + "net/url" + "strings" + "testing" + + nettypes "github.com/yusing/go-proxy/internal/net/types" +) + +type noopTransport struct{} + +func (t noopTransport) RoundTrip(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader("Hello, world!")), + Request: req, + ContentLength: int64(len("Hello, world!")), + Header: http.Header{}, + }, nil +} + +type noopResponseWriter struct{} + +func (w noopResponseWriter) Header() http.Header { + return http.Header{} +} + +func (w noopResponseWriter) Write(b []byte) (int, error) { + return len(b), nil +} + +func (w noopResponseWriter) WriteHeader(statusCode int) { +} + +func BenchmarkReverseProxy(b *testing.B) { + var w noopResponseWriter + var req = http.Request{ + Method: "GET", + URL: &url.URL{Scheme: "http", Host: "test"}, + Body: io.NopCloser(strings.NewReader("Hello, world!")), + } + proxy := NewReverseProxy("test", nettypes.MustParseURL("http://localhost:8080"), noopTransport{}) + for b.Loop() { + proxy.ServeHTTP(w, &req) + } +} diff --git a/internal/route/routes/context.go b/internal/route/routes/context.go index ccddcdf4..811a2cbf 100644 --- a/internal/route/routes/context.go +++ b/internal/route/routes/context.go @@ -2,18 +2,42 @@ package routes import ( "context" + "crypto/tls" + "fmt" + "io" + "mime/multipart" "net/http" "net/url" + "reflect" + "unsafe" "github.com/yusing/go-proxy/internal/types" ) -type RouteContext struct{} +type RouteContextKey struct{} -var routeContextKey = RouteContext{} +type RouteContext struct { + context.Context + Route types.HTTPRoute +} + +var routeContextKey = RouteContextKey{} + +func (r *RouteContext) Value(key any) any { + if key == routeContextKey { + return r.Route + } + return r.Context.Value(key) +} func WithRouteContext(r *http.Request, route types.HTTPRoute) *http.Request { - return r.WithContext(context.WithValue(r.Context(), routeContextKey, route)) + // we don't want to copy the request object every fucking requests + // return r.WithContext(context.WithValue(r.Context(), routeContextKey, route)) + (*requestInternal)(unsafe.Pointer(r)).ctx = &RouteContext{ + Context: r.Context(), + Route: route, + } + return r } func TryGetRoute(r *http.Request) types.HTTPRoute { @@ -74,3 +98,44 @@ func TryGetUpstreamURL(r *http.Request) string { } return "" } + +type requestInternal struct { + Method string + URL *url.URL + Proto string + ProtoMajor int + ProtoMinor int + Header http.Header + Body io.ReadCloser + GetBody func() (io.ReadCloser, error) + ContentLength int64 + TransferEncoding []string + Close bool + Host string + Form url.Values + PostForm url.Values + MultipartForm *multipart.Form + Trailer http.Header + RemoteAddr string + RequestURI string + TLS *tls.ConnectionState + Cancel <-chan struct{} + Response *http.Response + Pattern string + ctx context.Context +} + +func init() { + // make sure ctx has the same offset as http.Request + f, ok := reflect.TypeFor[requestInternal]().FieldByName("ctx") + if !ok { + panic("ctx field not found") + } + f2, ok := reflect.TypeFor[http.Request]().FieldByName("ctx") + if !ok { + panic("ctx field not found") + } + if f.Offset != f2.Offset { + panic(fmt.Sprintf("ctx has different offset than http.Request: %d != %d", f.Offset, f2.Offset)) + } +} diff --git a/internal/utils/strutils/string.go b/internal/utils/strutils/string.go index b0fe354b..8fef6a18 100644 --- a/internal/utils/strutils/string.go +++ b/internal/utils/strutils/string.go @@ -20,8 +20,10 @@ func CommaSeperatedList(s string) []string { return res } +var caseTitle = cases.Title(language.AmericanEnglish) + func Title(s string) string { - return cases.Title(language.AmericanEnglish).String(s) + return caseTitle.String(s) } func ContainsFold(s, substr string) bool { diff --git a/internal/utils/synk/pool.go b/internal/utils/synk/pool.go index 61ab9b1a..30ffaa0a 100644 --- a/internal/utils/synk/pool.go +++ b/internal/utils/synk/pool.go @@ -41,6 +41,28 @@ type BytesPoolWithMemory struct { pool chan weakBuf } +type sliceInternal struct { + ptr unsafe.Pointer + len int + cap int +} + +func sliceStruct(b *[]byte) *sliceInternal { + return (*sliceInternal)(unsafe.Pointer(b)) +} + +func underlyingPtr(b []byte) unsafe.Pointer { + return sliceStruct(&b).ptr +} + +func setCap(b *[]byte, cap int) { + sliceStruct(b).cap = cap +} + +func setLen(b *[]byte, len int) { + sliceStruct(b).len = len +} + const ( kb = 1024 mb = 1024 * kb @@ -88,7 +110,7 @@ func (p *BytesPool) Get() []byte { addReused(cap(bPtr)) return bPtr default: - return make([]byte, 0) + return make([]byte, 0, p.initSize) } } } @@ -113,10 +135,6 @@ func (p *BytesPoolWithMemory) Get() []byte { } func (p *BytesPool) GetSized(size int) []byte { - if size <= SizedPoolThreshold { - addNonPooled(size) - return make([]byte, size) - } for { select { case bWeak := <-p.sizedPool: @@ -125,10 +143,26 @@ func (p *BytesPool) GetSized(size int) []byte { continue } capB := cap(bPtr) - if capB >= size { + + remainingSize := capB - size + if remainingSize == 0 { addReused(capB) - return (bPtr)[:size] + return bPtr[:size] } + + if remainingSize > 0 { // capB > size (buffer larger than requested) + addReused(size) + + p.Put(bPtr[size:capB]) + + // return the first part and limit the capacity to the requested size + ret := bPtr[:size] + setLen(&ret, size) + setCap(&ret, size) + return ret + } + + // size is not enough select { case p.sizedPool <- bWeak: default: @@ -147,11 +181,10 @@ func (p *BytesPool) Put(b []byte) { return } b = b[:0] - w := makeWeak(&b) - if size <= SizedPoolThreshold { - p.put(w, p.unsizedPool) + if size >= SizedPoolThreshold { + p.put(makeWeak(&b), p.sizedPool) } else { - p.put(w, p.sizedPool) + p.put(makeWeak(&b), p.unsizedPool) } } diff --git a/internal/utils/synk/pool_bench_test.go b/internal/utils/synk/pool_bench_test.go index d2f2273d..b43a9ac2 100644 --- a/internal/utils/synk/pool_bench_test.go +++ b/internal/utils/synk/pool_bench_test.go @@ -21,13 +21,25 @@ func BenchmarkBytesPool_MakeSmall(b *testing.B) { func BenchmarkBytesPool_GetLarge(b *testing.B) { for b.Loop() { - bytesPool.Put(bytesPool.GetSized(1024 * 1024)) + buf := bytesPool.GetSized(DropThreshold / 2) + buf[0] = 1 + bytesPool.Put(buf) + } +} + +func BenchmarkBytesPool_GetLargeUnsized(b *testing.B) { + for b.Loop() { + buf := slices.Grow(bytesPool.Get(), DropThreshold/2) + buf = append(buf, 1) + bytesPool.Put(buf) } } func BenchmarkBytesPool_MakeLarge(b *testing.B) { for b.Loop() { - _ = make([]byte, 1024*1024) + buf := make([]byte, DropThreshold/2) + buf[0] = 1 + _ = buf } } @@ -37,10 +49,9 @@ func BenchmarkBytesPool_GetAll(b *testing.B) { } } -func BenchmarkBytesPoolWithMemory(b *testing.B) { - pool := GetBytesPoolWithUniqueMemory() +func BenchmarkBytesPool_GetAllUnsized(b *testing.B) { for i := range b.N { - pool.Put(slices.Grow(pool.Get(), sizes[i%len(sizes)])) + bytesPool.Put(slices.Grow(bytesPool.Get(), sizes[i%len(sizes)])) } } @@ -49,3 +60,10 @@ func BenchmarkBytesPool_MakeAll(b *testing.B) { _ = make([]byte, sizes[i%len(sizes)]) } } + +func BenchmarkBytesPoolWithMemory(b *testing.B) { + pool := GetBytesPoolWithUniqueMemory() + for i := range b.N { + pool.Put(slices.Grow(pool.Get(), sizes[i%len(sizes)])) + } +} diff --git a/internal/utils/synk/pool_test.go b/internal/utils/synk/pool_test.go new file mode 100644 index 00000000..76682506 --- /dev/null +++ b/internal/utils/synk/pool_test.go @@ -0,0 +1,263 @@ +package synk + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSized(t *testing.T) { + b := bytesPool.GetSized(2 * SizedPoolThreshold) + assert.Equal(t, cap(b), 2*SizedPoolThreshold) + bytesPool.Put(b) + assert.Equal(t, underlyingPtr(b), underlyingPtr(bytesPool.GetSized(SizedPoolThreshold))) +} + +func TestUnsized(t *testing.T) { + b := bytesPool.Get() + assert.Equal(t, cap(b), UnsizedAvg) + bytesPool.Put(b) + assert.Equal(t, underlyingPtr(b), underlyingPtr(bytesPool.Get())) +} + +func TestGetSizedExactMatch(t *testing.T) { + // Test exact size match reuse + size := SizedPoolThreshold + b1 := bytesPool.GetSized(size) + assert.Equal(t, size, len(b1)) + assert.Equal(t, size, cap(b1)) + + // Put back into pool + bytesPool.Put(b1) + + // Get same size - should reuse the same buffer + b2 := bytesPool.GetSized(size) + assert.Equal(t, size, len(b2)) + assert.Equal(t, size, cap(b2)) + assert.Equal(t, underlyingPtr(b1), underlyingPtr(b2)) +} + +func TestGetSizedBufferSplit(t *testing.T) { + // Test buffer splitting when capacity > requested size + largeSize := 2 * SizedPoolThreshold + requestedSize := SizedPoolThreshold + + // Create a large buffer and put it in pool + b1 := bytesPool.GetSized(largeSize) + assert.Equal(t, largeSize, len(b1)) + assert.Equal(t, largeSize, cap(b1)) + + bytesPool.Put(b1) + + // Request smaller size - should split the buffer + b2 := bytesPool.GetSized(requestedSize) + assert.Equal(t, requestedSize, len(b2)) + assert.Equal(t, requestedSize, cap(b2)) // capacity should remain the original + assert.Equal(t, underlyingPtr(b1), underlyingPtr(b2)) + + // The remaining part should be put back in pool + // Request the remaining size to verify + remainingSize := largeSize - requestedSize + b3 := bytesPool.GetSized(remainingSize) + assert.Equal(t, remainingSize, len(b3)) + assert.Equal(t, remainingSize, cap(b3)) + + // Verify the remaining buffer points to the correct memory location + originalPtr := underlyingPtr(b1) + remainingPtr := underlyingPtr(b3) + + // The remaining buffer should start at original + requestedSize + expectedOffset := uintptr(originalPtr) + uintptr(requestedSize) + actualOffset := uintptr(remainingPtr) + assert.Equal(t, expectedOffset, actualOffset, "Remaining buffer should point to correct offset") +} + +func TestGetSizedSmallRemainder(t *testing.T) { + // Test when remaining size is smaller than SizedPoolThreshold + poolSize := SizedPoolThreshold + 100 // Just slightly larger than threshold + requestedSize := SizedPoolThreshold + + // Create buffer and put in pool + b1 := bytesPool.GetSized(poolSize) + bytesPool.Put(b1) + + // Request size that leaves small remainder + b2 := bytesPool.GetSized(requestedSize) + assert.Equal(t, requestedSize, len(b2)) + assert.Equal(t, requestedSize, cap(b2)) + + // The small remainder (100 bytes) should NOT be put back in sized pool + // Try to get the remainder size - should create new buffer + b3 := bytesPool.GetSized(100) + assert.Equal(t, 100, len(b3)) + assert.Equal(t, 100, cap(b3)) + assert.NotEqual(t, underlyingPtr(b2), underlyingPtr(b3)) +} + +func TestGetSizedSmallBufferBypass(t *testing.T) { + // Test that small buffers (< SizedPoolThreshold) don't use sized pool + smallSize := SizedPoolThreshold - 1 + + b1 := bytesPool.GetSized(smallSize) + assert.Equal(t, smallSize, len(b1)) + assert.Equal(t, smallSize, cap(b1)) + + b2 := bytesPool.GetSized(smallSize) + assert.Equal(t, smallSize, len(b2)) + assert.Equal(t, smallSize, cap(b2)) + + // Should be different buffers (not pooled) + assert.NotEqual(t, underlyingPtr(b1), underlyingPtr(b2)) +} + +func TestGetSizedBufferTooSmall(t *testing.T) { + // Test when pool buffer is smaller than requested size + smallSize := SizedPoolThreshold + largeSize := 2 * SizedPoolThreshold + + // Put small buffer in pool + b1 := bytesPool.GetSized(smallSize) + bytesPool.Put(b1) + + // Request larger size - should create new buffer, not reuse small one + b2 := bytesPool.GetSized(largeSize) + assert.Equal(t, largeSize, len(b2)) + assert.Equal(t, largeSize, cap(b2)) + assert.NotEqual(t, underlyingPtr(b1), underlyingPtr(b2)) + + // The small buffer should still be in pool + b3 := bytesPool.GetSized(smallSize) + assert.Equal(t, underlyingPtr(b1), underlyingPtr(b3)) +} + +func TestGetSizedMultipleSplits(t *testing.T) { + // Test multiple sequential splits of the same buffer + hugeSize := 4 * SizedPoolThreshold + splitSize := SizedPoolThreshold + + // Create huge buffer + b1 := bytesPool.GetSized(hugeSize) + originalPtr := underlyingPtr(b1) + bytesPool.Put(b1) + + // Split it into smaller pieces + pieces := make([][]byte, 0, 4) + for i := range 4 { + piece := bytesPool.GetSized(splitSize) + pieces = append(pieces, piece) + + // Each piece should point to the correct offset + expectedOffset := uintptr(originalPtr) + uintptr(i*splitSize) + actualOffset := uintptr(underlyingPtr(piece)) + assert.Equal(t, expectedOffset, actualOffset, "Piece %d should point to correct offset", i) + assert.Equal(t, splitSize, len(piece)) + assert.Equal(t, splitSize, cap(piece)) + } + + // All pieces should have the same underlying capacity + for i, piece := range pieces { + assert.Equal(t, splitSize, cap(piece), "Piece %d should have correct capacity", i) + } +} + +func TestGetSizedMemorySafety(t *testing.T) { + // Test that split buffers don't interfere with each other + totalSize := 3 * SizedPoolThreshold + firstSize := SizedPoolThreshold + + // Create buffer and split it + b1 := bytesPool.GetSized(totalSize) + // Fill with test data + for i := range len(b1) { + b1[i] = byte(i % 256) + } + + bytesPool.Put(b1) + + // Get first part + first := bytesPool.GetSized(firstSize) + assert.Equal(t, firstSize, len(first)) + + // Verify data integrity + for i := range len(first) { + assert.Equal(t, byte(i%256), first[i], "Data should be preserved after split") + } + + // Get remaining part + remainingSize := totalSize - firstSize + remaining := bytesPool.GetSized(remainingSize) + assert.Equal(t, remainingSize, len(remaining)) + + // Verify remaining data + for i := range len(remaining) { + expected := byte((i + firstSize) % 256) + assert.Equal(t, expected, remaining[i], "Remaining data should be preserved") + } +} + +func TestGetSizedCapacityLimiting(t *testing.T) { + // Test that returned buffers have limited capacity to prevent overwrites + largeSize := 2 * SizedPoolThreshold + requestedSize := SizedPoolThreshold + + // Create large buffer and put in pool + b1 := bytesPool.GetSized(largeSize) + bytesPool.Put(b1) + + // Get smaller buffer from the split + b2 := bytesPool.GetSized(requestedSize) + assert.Equal(t, requestedSize, len(b2)) + assert.Equal(t, requestedSize, cap(b2), "Returned buffer should have limited capacity") + + // Try to append data - should not be able to overwrite beyond capacity + original := make([]byte, len(b2)) + copy(original, b2) + + // This append should force a new allocation since capacity is limited + b2 = append(b2, 1, 2, 3, 4, 5) + assert.Greater(t, len(b2), requestedSize, "Buffer should have grown") + + // Get the remaining buffer to verify it wasn't affected + remainingSize := largeSize - requestedSize + b3 := bytesPool.GetSized(remainingSize) + assert.Equal(t, remainingSize, len(b3)) + assert.Equal(t, remainingSize, cap(b3), "Remaining buffer should have limited capacity") +} + +func TestGetSizedAppendSafety(t *testing.T) { + // Test that appending to returned buffer doesn't affect remaining buffer + totalSize := 4 * SizedPoolThreshold + firstSize := SizedPoolThreshold + + // Create buffer with specific pattern + b1 := bytesPool.GetSized(totalSize) + for i := range len(b1) { + b1[i] = byte(100 + i%100) + } + bytesPool.Put(b1) + + // Get first part + first := bytesPool.GetSized(firstSize) + assert.Equal(t, firstSize, cap(first), "First part should have limited capacity") + + // Store original first part content + originalFirst := make([]byte, len(first)) + copy(originalFirst, first) + + // Get remaining part to establish its state + remaining := bytesPool.GetSized(SizedPoolThreshold) + + // Store original remaining content + originalRemaining := make([]byte, len(remaining)) + copy(originalRemaining, remaining) + + // Now try to append to first - this should not affect remaining buffers + // since capacity is limited + first = append(first, make([]byte, 1000)...) + + // Verify remaining buffer content is unchanged + for i := range len(originalRemaining) { + assert.Equal(t, originalRemaining[i], remaining[i], + "Remaining buffer should be unaffected by append to first buffer") + } +}