diff --git a/examples/WireMock.Net.WebSocketExamples/Program.cs b/examples/WireMock.Net.WebSocketExamples/Program.cs index 04b0bb1d..74083e9f 100644 --- a/examples/WireMock.Net.WebSocketExamples/Program.cs +++ b/examples/WireMock.Net.WebSocketExamples/Program.cs @@ -1,11 +1,7 @@ // Copyright © WireMock.Net -using System; -using System.Linq; using System.Net.WebSockets; using System.Text; -using System.Threading; -using System.Threading.Tasks; using WireMock.Logging; using WireMock.RequestBuilders; using WireMock.ResponseBuilders; diff --git a/src/WireMock.Net.Minimal/ResponseProviders/WebSocketResponseProvider.cs b/src/WireMock.Net.Minimal/ResponseProviders/WebSocketResponseProvider.cs index 18b1d2ec..ba332e00 100644 --- a/src/WireMock.Net.Minimal/ResponseProviders/WebSocketResponseProvider.cs +++ b/src/WireMock.Net.Minimal/ResponseProviders/WebSocketResponseProvider.cs @@ -1,5 +1,6 @@ // Copyright © WireMock.Net +using System.Buffers; using System.Net; using System.Net.WebSockets; using System.Text; @@ -44,14 +45,14 @@ internal class WebSocketResponseProvider(WebSocketBuilder builder) : IResponsePr #endif // Get options from HttpContext.Items (set by WireMockMiddleware) - if (!context.Items.TryGetValue(nameof(WireMockMiddlewareOptions), out var optionsObj) || + if (!context.Items.TryGetValue(nameof(WireMockMiddlewareOptions), out var optionsObj) || optionsObj is not IWireMockMiddlewareOptions options) { throw new InvalidOperationException("WireMockMiddlewareOptions not found in HttpContext.Items"); } // Get or create registry from options - var registry = _builder.IsBroadcast + var registry = _builder.IsBroadcast ? options.WebSocketRegistries.GetOrAdd(mapping.Guid, _ => new WebSocketConnectionRegistry()) : null; @@ -110,13 +111,13 @@ internal class WebSocketResponseProvider(WebSocketBuilder builder) : IResponsePr catch (Exception ex) { settings.Logger?.Error($"WebSocket error for mapping '{mapping.Guid}': {ex.Message}", ex); - + // If we haven't upgraded yet, we can return HTTP error if (!context.Response.HasStarted) { return (ResponseMessageBuilder.Create(HttpStatusCode.InternalServerError, $"WebSocket error: {ex.Message}"), null); } - + // Already upgraded - return marker return (new WebSocketHandledResponse(), null); } @@ -125,9 +126,9 @@ internal class WebSocketResponseProvider(WebSocketBuilder builder) : IResponsePr private static async Task HandleEchoAsync(WireMockWebSocketContext context) { var bufferSize = context.Builder.MaxMessageSize ?? WebSocketConstants.DefaultReceiveBufferSize; - var buffer = new byte[bufferSize]; + using var buffer = ArrayPool.Shared.Lease(bufferSize); var timeout = context.Builder.CloseTimeout ?? TimeSpan.FromMinutes(WebSocketConstants.DefaultCloseTimeoutMinutes); - var cts = new CancellationTokenSource(timeout); + using var cts = new CancellationTokenSource(timeout); try { @@ -172,7 +173,7 @@ internal class WebSocketResponseProvider(WebSocketBuilder builder) : IResponsePr var bufferSize = context.Builder.MaxMessageSize ?? WebSocketConstants.DefaultReceiveBufferSize; var buffer = new byte[bufferSize]; var timeout = context.Builder.CloseTimeout ?? TimeSpan.FromMinutes(WebSocketConstants.DefaultCloseTimeoutMinutes); - var cts = new CancellationTokenSource(timeout); + using var cts = new CancellationTokenSource(timeout); try { @@ -210,7 +211,7 @@ internal class WebSocketResponseProvider(WebSocketBuilder builder) : IResponsePr private static async Task HandleProxyAsync(WireMockWebSocketContext context, ProxyAndRecordSettings settings) { using var clientWebSocket = new ClientWebSocket(); - + var targetUri = new Uri(settings.Url); await clientWebSocket.ConnectAsync(targetUri, CancellationToken.None).ConfigureAwait(false); @@ -219,12 +220,13 @@ internal class WebSocketResponseProvider(WebSocketBuilder builder) : IResponsePr var serverToClient = ForwardMessagesAsync(clientWebSocket, context.WebSocket); await Task.WhenAny(clientToServer, serverToClient).ConfigureAwait(false); - + // Close both if (context.WebSocket.State == WebSocketState.Open) { await context.CloseAsync(WebSocketCloseStatus.NormalClosure, "Proxy closed"); } + if (clientWebSocket.State == WebSocketState.Open) { await clientWebSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, "Proxy closed", CancellationToken.None); @@ -234,7 +236,7 @@ internal class WebSocketResponseProvider(WebSocketBuilder builder) : IResponsePr private static async Task ForwardMessagesAsync(WebSocket source, WebSocket destination) { var buffer = new byte[WebSocketConstants.ProxyForwardBufferSize]; - + while (source.State == WebSocketState.Open && destination.State == WebSocketState.Open) { var result = await source.ReceiveAsync(new ArraySegment(buffer), CancellationToken.None); @@ -262,7 +264,7 @@ internal class WebSocketResponseProvider(WebSocketBuilder builder) : IResponsePr { var buffer = new byte[WebSocketConstants.MinimumBufferSize]; var timeout = context.Builder.CloseTimeout ?? TimeSpan.FromMinutes(WebSocketConstants.DefaultCloseTimeoutMinutes); - var cts = new CancellationTokenSource(timeout); + using var cts = new CancellationTokenSource(timeout); try {