From 5388130009ac43ca8fcd3f449f5dd0dc5657dc32 Mon Sep 17 00:00:00 2001 From: Stef Heyenrath Date: Mon, 16 Feb 2026 18:41:58 +0100 Subject: [PATCH] . --- .../Owin/WireMockMiddleware.cs | 5 +- .../WebSocketResponseProvider.cs | 116 ++++++++---------- src/WireMock.Net.Minimal/Util/GuidUtils.cs | 2 - .../WebSockets/WireMockWebSocketContext.cs | 31 +++-- .../Owin/WireMockMiddlewareTests.cs | 39 +++--- 5 files changed, 90 insertions(+), 103 deletions(-) diff --git a/src/WireMock.Net.Minimal/Owin/WireMockMiddleware.cs b/src/WireMock.Net.Minimal/Owin/WireMockMiddleware.cs index 20f9c267..5905ae7d 100644 --- a/src/WireMock.Net.Minimal/Owin/WireMockMiddleware.cs +++ b/src/WireMock.Net.Minimal/Owin/WireMockMiddleware.cs @@ -12,6 +12,7 @@ using WireMock.Owin.Mappers; using WireMock.ResponseBuilders; using WireMock.Serialization; using WireMock.Settings; +using WireMock.Util; namespace WireMock.Owin; @@ -21,7 +22,8 @@ internal class WireMockMiddleware( IOwinRequestMapper requestMapper, IOwinResponseMapper responseMapper, IMappingMatcher mappingMatcher, - IWireMockMiddlewareLogger logger + IWireMockMiddlewareLogger logger, + IGuidUtils guidUtils ) { private readonly object _lock = new(); @@ -44,6 +46,7 @@ internal class WireMockMiddleware( // Store options in HttpContext for providers to access (e.g., WebSocketResponseProvider) ctx.Items[nameof(IWireMockMiddlewareOptions)] = options; ctx.Items[nameof(IWireMockMiddlewareLogger)] = logger; + ctx.Items[nameof(IGuidUtils)] = guidUtils; var request = await requestMapper.MapAsync(ctx, options).ConfigureAwait(false); diff --git a/src/WireMock.Net.Minimal/ResponseProviders/WebSocketResponseProvider.cs b/src/WireMock.Net.Minimal/ResponseProviders/WebSocketResponseProvider.cs index 2817052e..0089cfe9 100644 --- a/src/WireMock.Net.Minimal/ResponseProviders/WebSocketResponseProvider.cs +++ b/src/WireMock.Net.Minimal/ResponseProviders/WebSocketResponseProvider.cs @@ -7,9 +7,11 @@ using System.Net.WebSockets; using System.Text; using Microsoft.AspNetCore.Http; using WireMock.Constants; +using WireMock.Extensions; using WireMock.Owin; using WireMock.Owin.ActivityTracing; using WireMock.Settings; +using WireMock.Util; using WireMock.WebSockets; namespace WireMock.ResponseProviders; @@ -43,20 +45,21 @@ internal class WebSocketResponseProvider(WebSocketBuilder builder) : IResponsePr var webSocket = await context.WebSockets.AcceptWebSocketAsync(builder.AcceptProtocol).ConfigureAwait(false); #endif - // Get options from HttpContext.Items (set by WireMockMiddleware) - if (!context.Items.TryGetValue(nameof(IWireMockMiddlewareOptions), out var optionsObj) || - optionsObj is not IWireMockMiddlewareOptions options) + if (!context.Items.TryGetValue(nameof(IWireMockMiddlewareOptions), out var options)) { throw new InvalidOperationException("IWireMockMiddlewareOptions not found in HttpContext.Items"); } - // Get logger from HttpContext.Items - if (!context.Items.TryGetValue(nameof(IWireMockMiddlewareLogger), out var loggerObj) || - loggerObj is not IWireMockMiddlewareLogger logger) + if (!context.Items.TryGetValue(nameof(IWireMockMiddlewareLogger), out var logger)) { throw new InvalidOperationException("IWireMockMiddlewareLogger not found in HttpContext.Items"); } + if (!context.Items.TryGetValue(nameof(IGuidUtils), out var guidUtils)) + { + throw new InvalidOperationException("IGuidUtils not found in HttpContext.Items"); + } + // Get or create registry from options var registry = builder.IsBroadcast ? options.WebSocketRegistries.GetOrAdd(mapping.Guid, _ => new WebSocketConnectionRegistry()) @@ -71,7 +74,8 @@ internal class WebSocketResponseProvider(WebSocketBuilder builder) : IResponsePr registry, builder, options, - logger + logger, + guidUtils ); // Update scenario state following the same pattern as WireMockMiddleware @@ -141,25 +145,22 @@ internal class WebSocketResponseProvider(WebSocketBuilder builder) : IResponsePr { while (context.WebSocket.State == WebSocketState.Open && !cts.Token.IsCancellationRequested) { - Activity? receiveActivity = null; + Activity? activity = null; if (shouldTrace) { - receiveActivity = WireMockActivitySource.StartWebSocketMessageActivity(WebSocketMessageDirection.Receive, context.Mapping.Guid); + activity = WireMockActivitySource.StartWebSocketMessageActivity(WebSocketMessageDirection.Receive, context.Mapping.Guid); } try { - var result = await context.WebSocket.ReceiveAsync( - new ArraySegment(buffer), - cts.Token - ).ConfigureAwait(false); + var result = await context.WebSocket.ReceiveAsync(new ArraySegment(buffer), cts.Token).ConfigureAwait(false); if (result.MessageType == WebSocketMessageType.Close) { if (shouldTrace) { WireMockActivitySource.EnrichWithWebSocketMessage( - receiveActivity, + activity, result.MessageType, result.Count, result.EndOfMessage, @@ -168,36 +169,29 @@ internal class WebSocketResponseProvider(WebSocketBuilder builder) : IResponsePr ); } - context.LogWebSocketMessage(WebSocketMessageDirection.Receive, result.MessageType, null, receiveActivity); + context.LogWebSocketMessage(WebSocketMessageDirection.Receive, result.MessageType, null, activity); - await context.CloseAsync( - WebSocketCloseStatus.NormalClosure, - "Closed by client" - ).ConfigureAwait(false); + await context.CloseAsync(WebSocketCloseStatus.NormalClosure, "Closed by client").ConfigureAwait(false); break; } // Enrich activity with message details - string? textContent = null; - if (result.MessageType == WebSocketMessageType.Text) - { - textContent = Encoding.UTF8.GetString(buffer, 0, result.Count); - } + var data = ToData(result, buffer); if (shouldTrace) { WireMockActivitySource.EnrichWithWebSocketMessage( - receiveActivity, + activity, result.MessageType, result.Count, result.EndOfMessage, - textContent, + data as string, context.Options?.ActivityTracingOptions ); } // Log the receive operation - context.LogWebSocketMessage(WebSocketMessageDirection.Receive, result.MessageType, textContent, receiveActivity); + context.LogWebSocketMessage(WebSocketMessageDirection.Receive, result.MessageType, data, activity); // Echo back (this will be logged by context.SendAsync) await context.WebSocket.SendAsync( @@ -206,15 +200,18 @@ internal class WebSocketResponseProvider(WebSocketBuilder builder) : IResponsePr result.EndOfMessage, cts.Token ).ConfigureAwait(false); + + // Log the send (echo) operation + context.LogWebSocketMessage(WebSocketMessageDirection.Send, result.MessageType, data, activity); } catch (Exception ex) { - WireMockActivitySource.RecordException(receiveActivity, ex); + WireMockActivitySource.RecordException(activity, ex); throw; } finally { - receiveActivity?.Dispose(); + activity?.Dispose(); } } } @@ -232,7 +229,7 @@ internal class WebSocketResponseProvider(WebSocketBuilder builder) : IResponsePr Func handler) { var bufferSize = context.Builder.MaxMessageSize ?? WebSocketConstants.DefaultReceiveBufferSize; - var buffer = new byte[bufferSize]; + using var buffer = ArrayPool.Shared.Lease(bufferSize); var timeout = context.Builder.CloseTimeout ?? TimeSpan.FromMinutes(WebSocketConstants.DefaultCloseTimeoutMinutes); using var cts = new CancellationTokenSource(timeout); @@ -350,18 +347,12 @@ internal class WebSocketResponseProvider(WebSocketBuilder builder) : IResponsePr ClientWebSocket clientWebSocket, WebSocketMessageDirection direction) { - var buffer = new byte[WebSocketConstants.ProxyForwardBufferSize]; + using var buffer = ArrayPool.Shared.Lease(WebSocketConstants.ProxyForwardBufferSize); - // Get options for activity tracing - var options = context.HttpContext.Items.TryGetValue(nameof(WireMockMiddlewareOptions), out var optionsObj) && - optionsObj is IWireMockMiddlewareOptions wireMockOptions - ? wireMockOptions - : null; + var shouldTrace = context.Options?.ActivityTracingOptions is not null; - var shouldTrace = options?.ActivityTracingOptions is not null; - - var source = direction == WebSocketMessageDirection.Receive ? context.WebSocket : (WebSocket)clientWebSocket; - var destination = direction == WebSocketMessageDirection.Receive ? (WebSocket)clientWebSocket : context.WebSocket; + var source = direction == WebSocketMessageDirection.Receive ? context.WebSocket : clientWebSocket; + var destination = direction == WebSocketMessageDirection.Receive ? clientWebSocket : context.WebSocket; while (source.State == WebSocketState.Open && destination.State == WebSocketState.Open) { @@ -385,7 +376,7 @@ internal class WebSocketResponseProvider(WebSocketBuilder builder) : IResponsePr result.Count, result.EndOfMessage, null, - options?.ActivityTracingOptions + context.Options?.ActivityTracingOptions ); } @@ -400,16 +391,7 @@ internal class WebSocketResponseProvider(WebSocketBuilder builder) : IResponsePr } // Enrich activity with message details - object? data = null; - if (result.MessageType == WebSocketMessageType.Text) - { - data = Encoding.UTF8.GetString(buffer, 0, result.Count); - } - else if (result.MessageType == WebSocketMessageType.Binary) - { - data = new byte[result.Count]; - Array.Copy(buffer, (byte[])data, result.Count); - } + var data = ToData(result, buffer); if (shouldTrace) { @@ -419,7 +401,7 @@ internal class WebSocketResponseProvider(WebSocketBuilder builder) : IResponsePr result.Count, result.EndOfMessage, data as string, - options?.ActivityTracingOptions + context.Options?.ActivityTracingOptions ); } @@ -483,17 +465,7 @@ internal class WebSocketResponseProvider(WebSocketBuilder builder) : IResponsePr } // Log the receive operation - object? data = null; - if (result.MessageType == WebSocketMessageType.Text) - { - data = Encoding.UTF8.GetString(buffer, 0, result.Count); - } - else if (result.MessageType == WebSocketMessageType.Binary) - { - data = new byte[result.Count]; - Array.Copy(buffer, (byte[])data, result.Count); - } - + var data = ToData(result, buffer); context.LogWebSocketMessage(WebSocketMessageDirection.Receive, result.MessageType, data, receiveActivity); if (result.MessageType == WebSocketMessageType.Close) @@ -543,4 +515,22 @@ internal class WebSocketResponseProvider(WebSocketBuilder builder) : IResponsePr return message; } + + private static object? ToData(WebSocketReceiveResult result, byte[] buffer) + { + if (result.MessageType == WebSocketMessageType.Text) + { + return Encoding.UTF8.GetString(buffer, 0, result.Count); + } + + if (result.MessageType == WebSocketMessageType.Binary) + { + var data = new byte[result.Count]; + Array.Copy(buffer, data, result.Count); + + return data; + } + + return null; + } } \ No newline at end of file diff --git a/src/WireMock.Net.Minimal/Util/GuidUtils.cs b/src/WireMock.Net.Minimal/Util/GuidUtils.cs index 9f3be53e..5a5cce3d 100644 --- a/src/WireMock.Net.Minimal/Util/GuidUtils.cs +++ b/src/WireMock.Net.Minimal/Util/GuidUtils.cs @@ -1,7 +1,5 @@ // Copyright © WireMock.Net -using System; - namespace WireMock.Util; internal interface IGuidUtils diff --git a/src/WireMock.Net.Minimal/WebSockets/WireMockWebSocketContext.cs b/src/WireMock.Net.Minimal/WebSockets/WireMockWebSocketContext.cs index 511e2ded..08385897 100644 --- a/src/WireMock.Net.Minimal/WebSockets/WireMockWebSocketContext.cs +++ b/src/WireMock.Net.Minimal/WebSockets/WireMockWebSocketContext.cs @@ -1,11 +1,9 @@ // Copyright © WireMock.Net using System.Diagnostics; -using System.Net; using System.Net.WebSockets; using System.Text; using Microsoft.AspNetCore.Http; -using Stef.Validation; using WireMock.Logging; using WireMock.Models; using WireMock.Owin; @@ -21,7 +19,7 @@ namespace WireMock.WebSockets; public class WireMockWebSocketContext : IWebSocketContext { /// - public Guid ConnectionId { get; } = Guid.NewGuid(); + public Guid ConnectionId { get; } /// public HttpContext HttpContext { get; } @@ -54,16 +52,20 @@ public class WireMockWebSocketContext : IWebSocketContext WebSocketConnectionRegistry? registry, WebSocketBuilder builder, IWireMockMiddlewareOptions options, - IWireMockMiddlewareLogger logger) + IWireMockMiddlewareLogger logger, + IGuidUtils guidUtils + ) { - HttpContext = Guard.NotNull(httpContext); - WebSocket = Guard.NotNull(webSocket); - RequestMessage = Guard.NotNull(requestMessage); - Mapping = Guard.NotNull(mapping); + HttpContext = httpContext; + WebSocket = webSocket; + RequestMessage = requestMessage; + Mapping = mapping; Registry = registry; - Builder = Guard.NotNull(builder); - Options = Guard.NotNull(options); - Logger = Guard.NotNull(logger); + Builder = builder; + Options = options; + Logger = logger; + + ConnectionId = guidUtils.NewGuid(); } /// @@ -187,7 +189,6 @@ public class WireMockWebSocketContext : IWebSocketContext object? data, Activity? activity) { - // Create body data IBodyData bodyData; if (messageType == WebSocketMessageType.Text && data is string textContent) { @@ -214,12 +215,11 @@ public class WireMockWebSocketContext : IWebSocketContext }; } - // Create a pseudo-request or pseudo-response depending on direction + var method = $"WS_{direction.ToString().ToUpperInvariant()}"; + RequestMessage? requestMessage = null; IResponseMessage? responseMessage = null; - var method = $"WS_{direction.ToString().ToUpperInvariant()}"; - if (direction == WebSocketMessageDirection.Receive) { // Received message - log as request @@ -241,7 +241,6 @@ public class WireMockWebSocketContext : IWebSocketContext responseMessage = new ResponseMessage { Method = method, - StatusCode = HttpStatusCode.SwitchingProtocols, // WebSocket status BodyData = bodyData, DateTime = DateTime.UtcNow }; diff --git a/test/WireMock.Net.Tests/Owin/WireMockMiddlewareTests.cs b/test/WireMock.Net.Tests/Owin/WireMockMiddlewareTests.cs index 9455a20d..4cc961c4 100644 --- a/test/WireMock.Net.Tests/Owin/WireMockMiddlewareTests.cs +++ b/test/WireMock.Net.Tests/Owin/WireMockMiddlewareTests.cs @@ -1,30 +1,25 @@ // Copyright © WireMock.Net using System.Collections.Concurrent; +using System.Diagnostics; using System.Linq.Expressions; +using AwesomeAssertions; +using Microsoft.AspNetCore.Http; using Moq; -using WireMock.Models; -using WireMock.Owin; -using WireMock.Owin.Mappers; -using WireMock.Util; -using WireMock.Logging; -using WireMock.Matchers; using WireMock.Admin.Mappings; using WireMock.Admin.Requests; -using WireMock.Settings; -using AwesomeAssertions; using WireMock.Handlers; +using WireMock.Logging; +using WireMock.Matchers; using WireMock.Matchers.Request; -using WireMock.ResponseBuilders; -using WireMock.RequestBuilders; -using Microsoft.AspNetCore.Http; -using Microsoft.CodeAnalysis.CSharp.Syntax; - - -#if NET6_0_OR_GREATER +using WireMock.Models; +using WireMock.Owin; using WireMock.Owin.ActivityTracing; -using System.Diagnostics; -#endif +using WireMock.Owin.Mappers; +using WireMock.RequestBuilders; +using WireMock.ResponseBuilders; +using WireMock.Settings; +using WireMock.Util; namespace WireMock.Net.Tests.Owin; @@ -41,13 +36,16 @@ public class WireMockMiddlewareTests private readonly Mock _mappingMock; private readonly Mock _requestMatchResultMock; private readonly Mock _contextMock; + private readonly Mock _guidUtilsMock; private readonly WireMockMiddleware _sut; public WireMockMiddlewareTests() { var wireMockMiddlewareLoggerMock = new Mock(); - // wreMockMiddlewareLoggerMock.Setup(g => g.NewGuid()).Returns(NewGuid); + + _guidUtilsMock = new Mock(); + _guidUtilsMock.Setup(g => g.NewGuid()).Returns(NewGuid); _optionsMock = new Mock(); _optionsMock.SetupAllProperties(); @@ -86,7 +84,8 @@ public class WireMockMiddlewareTests _requestMapperMock.Object, _responseMapperMock.Object, _matcherMock.Object, - wireMockMiddlewareLoggerMock.Object + wireMockMiddlewareLoggerMock.Object, + _guidUtilsMock.Object ); } @@ -262,7 +261,6 @@ public class WireMockMiddlewareTests _mappings.Should().HaveCount(1); } -#if NET6_0_OR_GREATER [Fact] public async Task WireMockMiddleware_Invoke_AdminPath_WithExcludeAdminRequests_ShouldNotStartActivity() { @@ -346,5 +344,4 @@ public class WireMockMiddlewareTests // Assert activityStarted.Should().BeFalse(); } -#endif } \ No newline at end of file