This commit is contained in:
Stef Heyenrath
2026-02-12 06:41:06 +01:00
parent f9741af021
commit 3a266c3e18
2 changed files with 13 additions and 15 deletions

View File

@@ -1,11 +1,7 @@
// Copyright © WireMock.Net
using System;
using System.Linq;
using System.Net.WebSockets;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using WireMock.Logging;
using WireMock.RequestBuilders;
using WireMock.ResponseBuilders;

View File

@@ -1,5 +1,6 @@
// Copyright © WireMock.Net
using System.Buffers;
using System.Net;
using System.Net.WebSockets;
using System.Text;
@@ -44,14 +45,14 @@ internal class WebSocketResponseProvider(WebSocketBuilder builder) : IResponsePr
#endif
// Get options from HttpContext.Items (set by WireMockMiddleware)
if (!context.Items.TryGetValue(nameof(WireMockMiddlewareOptions), out var optionsObj) ||
if (!context.Items.TryGetValue(nameof(WireMockMiddlewareOptions), out var optionsObj) ||
optionsObj is not IWireMockMiddlewareOptions options)
{
throw new InvalidOperationException("WireMockMiddlewareOptions not found in HttpContext.Items");
}
// Get or create registry from options
var registry = _builder.IsBroadcast
var registry = _builder.IsBroadcast
? options.WebSocketRegistries.GetOrAdd(mapping.Guid, _ => new WebSocketConnectionRegistry())
: null;
@@ -110,13 +111,13 @@ internal class WebSocketResponseProvider(WebSocketBuilder builder) : IResponsePr
catch (Exception ex)
{
settings.Logger?.Error($"WebSocket error for mapping '{mapping.Guid}': {ex.Message}", ex);
// If we haven't upgraded yet, we can return HTTP error
if (!context.Response.HasStarted)
{
return (ResponseMessageBuilder.Create(HttpStatusCode.InternalServerError, $"WebSocket error: {ex.Message}"), null);
}
// Already upgraded - return marker
return (new WebSocketHandledResponse(), null);
}
@@ -125,9 +126,9 @@ internal class WebSocketResponseProvider(WebSocketBuilder builder) : IResponsePr
private static async Task HandleEchoAsync(WireMockWebSocketContext context)
{
var bufferSize = context.Builder.MaxMessageSize ?? WebSocketConstants.DefaultReceiveBufferSize;
var buffer = new byte[bufferSize];
using var buffer = ArrayPool<byte>.Shared.Lease(bufferSize);
var timeout = context.Builder.CloseTimeout ?? TimeSpan.FromMinutes(WebSocketConstants.DefaultCloseTimeoutMinutes);
var cts = new CancellationTokenSource(timeout);
using var cts = new CancellationTokenSource(timeout);
try
{
@@ -172,7 +173,7 @@ internal class WebSocketResponseProvider(WebSocketBuilder builder) : IResponsePr
var bufferSize = context.Builder.MaxMessageSize ?? WebSocketConstants.DefaultReceiveBufferSize;
var buffer = new byte[bufferSize];
var timeout = context.Builder.CloseTimeout ?? TimeSpan.FromMinutes(WebSocketConstants.DefaultCloseTimeoutMinutes);
var cts = new CancellationTokenSource(timeout);
using var cts = new CancellationTokenSource(timeout);
try
{
@@ -210,7 +211,7 @@ internal class WebSocketResponseProvider(WebSocketBuilder builder) : IResponsePr
private static async Task HandleProxyAsync(WireMockWebSocketContext context, ProxyAndRecordSettings settings)
{
using var clientWebSocket = new ClientWebSocket();
var targetUri = new Uri(settings.Url);
await clientWebSocket.ConnectAsync(targetUri, CancellationToken.None).ConfigureAwait(false);
@@ -219,12 +220,13 @@ internal class WebSocketResponseProvider(WebSocketBuilder builder) : IResponsePr
var serverToClient = ForwardMessagesAsync(clientWebSocket, context.WebSocket);
await Task.WhenAny(clientToServer, serverToClient).ConfigureAwait(false);
// Close both
if (context.WebSocket.State == WebSocketState.Open)
{
await context.CloseAsync(WebSocketCloseStatus.NormalClosure, "Proxy closed");
}
if (clientWebSocket.State == WebSocketState.Open)
{
await clientWebSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, "Proxy closed", CancellationToken.None);
@@ -234,7 +236,7 @@ internal class WebSocketResponseProvider(WebSocketBuilder builder) : IResponsePr
private static async Task ForwardMessagesAsync(WebSocket source, WebSocket destination)
{
var buffer = new byte[WebSocketConstants.ProxyForwardBufferSize];
while (source.State == WebSocketState.Open && destination.State == WebSocketState.Open)
{
var result = await source.ReceiveAsync(new ArraySegment<byte>(buffer), CancellationToken.None);
@@ -262,7 +264,7 @@ internal class WebSocketResponseProvider(WebSocketBuilder builder) : IResponsePr
{
var buffer = new byte[WebSocketConstants.MinimumBufferSize];
var timeout = context.Builder.CloseTimeout ?? TimeSpan.FromMinutes(WebSocketConstants.DefaultCloseTimeoutMinutes);
var cts = new CancellationTokenSource(timeout);
using var cts = new CancellationTokenSource(timeout);
try
{