This commit is contained in:
Stef Heyenrath
2026-02-16 18:41:58 +01:00
parent fa3a33dcc6
commit 5388130009
5 changed files with 90 additions and 103 deletions

View File

@@ -12,6 +12,7 @@ using WireMock.Owin.Mappers;
using WireMock.ResponseBuilders; using WireMock.ResponseBuilders;
using WireMock.Serialization; using WireMock.Serialization;
using WireMock.Settings; using WireMock.Settings;
using WireMock.Util;
namespace WireMock.Owin; namespace WireMock.Owin;
@@ -21,7 +22,8 @@ internal class WireMockMiddleware(
IOwinRequestMapper requestMapper, IOwinRequestMapper requestMapper,
IOwinResponseMapper responseMapper, IOwinResponseMapper responseMapper,
IMappingMatcher mappingMatcher, IMappingMatcher mappingMatcher,
IWireMockMiddlewareLogger logger IWireMockMiddlewareLogger logger,
IGuidUtils guidUtils
) )
{ {
private readonly object _lock = new(); private readonly object _lock = new();
@@ -44,6 +46,7 @@ internal class WireMockMiddleware(
// Store options in HttpContext for providers to access (e.g., WebSocketResponseProvider) // Store options in HttpContext for providers to access (e.g., WebSocketResponseProvider)
ctx.Items[nameof(IWireMockMiddlewareOptions)] = options; ctx.Items[nameof(IWireMockMiddlewareOptions)] = options;
ctx.Items[nameof(IWireMockMiddlewareLogger)] = logger; ctx.Items[nameof(IWireMockMiddlewareLogger)] = logger;
ctx.Items[nameof(IGuidUtils)] = guidUtils;
var request = await requestMapper.MapAsync(ctx, options).ConfigureAwait(false); var request = await requestMapper.MapAsync(ctx, options).ConfigureAwait(false);

View File

@@ -7,9 +7,11 @@ using System.Net.WebSockets;
using System.Text; using System.Text;
using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http;
using WireMock.Constants; using WireMock.Constants;
using WireMock.Extensions;
using WireMock.Owin; using WireMock.Owin;
using WireMock.Owin.ActivityTracing; using WireMock.Owin.ActivityTracing;
using WireMock.Settings; using WireMock.Settings;
using WireMock.Util;
using WireMock.WebSockets; using WireMock.WebSockets;
namespace WireMock.ResponseProviders; namespace WireMock.ResponseProviders;
@@ -43,20 +45,21 @@ internal class WebSocketResponseProvider(WebSocketBuilder builder) : IResponsePr
var webSocket = await context.WebSockets.AcceptWebSocketAsync(builder.AcceptProtocol).ConfigureAwait(false); var webSocket = await context.WebSockets.AcceptWebSocketAsync(builder.AcceptProtocol).ConfigureAwait(false);
#endif #endif
// Get options from HttpContext.Items (set by WireMockMiddleware) if (!context.Items.TryGetValue<IWireMockMiddlewareOptions>(nameof(IWireMockMiddlewareOptions), out var options))
if (!context.Items.TryGetValue(nameof(IWireMockMiddlewareOptions), out var optionsObj) ||
optionsObj is not IWireMockMiddlewareOptions options)
{ {
throw new InvalidOperationException("IWireMockMiddlewareOptions not found in HttpContext.Items"); throw new InvalidOperationException("IWireMockMiddlewareOptions not found in HttpContext.Items");
} }
// Get logger from HttpContext.Items if (!context.Items.TryGetValue<IWireMockMiddlewareLogger>(nameof(IWireMockMiddlewareLogger), out var logger))
if (!context.Items.TryGetValue(nameof(IWireMockMiddlewareLogger), out var loggerObj) ||
loggerObj is not IWireMockMiddlewareLogger logger)
{ {
throw new InvalidOperationException("IWireMockMiddlewareLogger not found in HttpContext.Items"); throw new InvalidOperationException("IWireMockMiddlewareLogger not found in HttpContext.Items");
} }
if (!context.Items.TryGetValue<IGuidUtils>(nameof(IGuidUtils), out var guidUtils))
{
throw new InvalidOperationException("IGuidUtils not found in HttpContext.Items");
}
// Get or create registry from options // Get or create registry from options
var registry = builder.IsBroadcast var registry = builder.IsBroadcast
? options.WebSocketRegistries.GetOrAdd(mapping.Guid, _ => new WebSocketConnectionRegistry()) ? options.WebSocketRegistries.GetOrAdd(mapping.Guid, _ => new WebSocketConnectionRegistry())
@@ -71,7 +74,8 @@ internal class WebSocketResponseProvider(WebSocketBuilder builder) : IResponsePr
registry, registry,
builder, builder,
options, options,
logger logger,
guidUtils
); );
// Update scenario state following the same pattern as WireMockMiddleware // Update scenario state following the same pattern as WireMockMiddleware
@@ -141,25 +145,22 @@ internal class WebSocketResponseProvider(WebSocketBuilder builder) : IResponsePr
{ {
while (context.WebSocket.State == WebSocketState.Open && !cts.Token.IsCancellationRequested) while (context.WebSocket.State == WebSocketState.Open && !cts.Token.IsCancellationRequested)
{ {
Activity? receiveActivity = null; Activity? activity = null;
if (shouldTrace) if (shouldTrace)
{ {
receiveActivity = WireMockActivitySource.StartWebSocketMessageActivity(WebSocketMessageDirection.Receive, context.Mapping.Guid); activity = WireMockActivitySource.StartWebSocketMessageActivity(WebSocketMessageDirection.Receive, context.Mapping.Guid);
} }
try try
{ {
var result = await context.WebSocket.ReceiveAsync( var result = await context.WebSocket.ReceiveAsync(new ArraySegment<byte>(buffer), cts.Token).ConfigureAwait(false);
new ArraySegment<byte>(buffer),
cts.Token
).ConfigureAwait(false);
if (result.MessageType == WebSocketMessageType.Close) if (result.MessageType == WebSocketMessageType.Close)
{ {
if (shouldTrace) if (shouldTrace)
{ {
WireMockActivitySource.EnrichWithWebSocketMessage( WireMockActivitySource.EnrichWithWebSocketMessage(
receiveActivity, activity,
result.MessageType, result.MessageType,
result.Count, result.Count,
result.EndOfMessage, result.EndOfMessage,
@@ -168,36 +169,29 @@ internal class WebSocketResponseProvider(WebSocketBuilder builder) : IResponsePr
); );
} }
context.LogWebSocketMessage(WebSocketMessageDirection.Receive, result.MessageType, null, receiveActivity); context.LogWebSocketMessage(WebSocketMessageDirection.Receive, result.MessageType, null, activity);
await context.CloseAsync( await context.CloseAsync(WebSocketCloseStatus.NormalClosure, "Closed by client").ConfigureAwait(false);
WebSocketCloseStatus.NormalClosure,
"Closed by client"
).ConfigureAwait(false);
break; break;
} }
// Enrich activity with message details // Enrich activity with message details
string? textContent = null; var data = ToData(result, buffer);
if (result.MessageType == WebSocketMessageType.Text)
{
textContent = Encoding.UTF8.GetString(buffer, 0, result.Count);
}
if (shouldTrace) if (shouldTrace)
{ {
WireMockActivitySource.EnrichWithWebSocketMessage( WireMockActivitySource.EnrichWithWebSocketMessage(
receiveActivity, activity,
result.MessageType, result.MessageType,
result.Count, result.Count,
result.EndOfMessage, result.EndOfMessage,
textContent, data as string,
context.Options?.ActivityTracingOptions context.Options?.ActivityTracingOptions
); );
} }
// Log the receive operation // Log the receive operation
context.LogWebSocketMessage(WebSocketMessageDirection.Receive, result.MessageType, textContent, receiveActivity); context.LogWebSocketMessage(WebSocketMessageDirection.Receive, result.MessageType, data, activity);
// Echo back (this will be logged by context.SendAsync) // Echo back (this will be logged by context.SendAsync)
await context.WebSocket.SendAsync( await context.WebSocket.SendAsync(
@@ -206,15 +200,18 @@ internal class WebSocketResponseProvider(WebSocketBuilder builder) : IResponsePr
result.EndOfMessage, result.EndOfMessage,
cts.Token cts.Token
).ConfigureAwait(false); ).ConfigureAwait(false);
// Log the send (echo) operation
context.LogWebSocketMessage(WebSocketMessageDirection.Send, result.MessageType, data, activity);
} }
catch (Exception ex) catch (Exception ex)
{ {
WireMockActivitySource.RecordException(receiveActivity, ex); WireMockActivitySource.RecordException(activity, ex);
throw; throw;
} }
finally finally
{ {
receiveActivity?.Dispose(); activity?.Dispose();
} }
} }
} }
@@ -232,7 +229,7 @@ internal class WebSocketResponseProvider(WebSocketBuilder builder) : IResponsePr
Func<WebSocketMessage, IWebSocketContext, Task> handler) Func<WebSocketMessage, IWebSocketContext, Task> handler)
{ {
var bufferSize = context.Builder.MaxMessageSize ?? WebSocketConstants.DefaultReceiveBufferSize; var bufferSize = context.Builder.MaxMessageSize ?? WebSocketConstants.DefaultReceiveBufferSize;
var buffer = new byte[bufferSize]; using var buffer = ArrayPool<byte>.Shared.Lease(bufferSize);
var timeout = context.Builder.CloseTimeout ?? TimeSpan.FromMinutes(WebSocketConstants.DefaultCloseTimeoutMinutes); var timeout = context.Builder.CloseTimeout ?? TimeSpan.FromMinutes(WebSocketConstants.DefaultCloseTimeoutMinutes);
using var cts = new CancellationTokenSource(timeout); using var cts = new CancellationTokenSource(timeout);
@@ -350,18 +347,12 @@ internal class WebSocketResponseProvider(WebSocketBuilder builder) : IResponsePr
ClientWebSocket clientWebSocket, ClientWebSocket clientWebSocket,
WebSocketMessageDirection direction) WebSocketMessageDirection direction)
{ {
var buffer = new byte[WebSocketConstants.ProxyForwardBufferSize]; using var buffer = ArrayPool<byte>.Shared.Lease(WebSocketConstants.ProxyForwardBufferSize);
// Get options for activity tracing var shouldTrace = context.Options?.ActivityTracingOptions is not null;
var options = context.HttpContext.Items.TryGetValue(nameof(WireMockMiddlewareOptions), out var optionsObj) &&
optionsObj is IWireMockMiddlewareOptions wireMockOptions
? wireMockOptions
: null;
var shouldTrace = options?.ActivityTracingOptions is not null; var source = direction == WebSocketMessageDirection.Receive ? context.WebSocket : clientWebSocket;
var destination = direction == WebSocketMessageDirection.Receive ? clientWebSocket : context.WebSocket;
var source = direction == WebSocketMessageDirection.Receive ? context.WebSocket : (WebSocket)clientWebSocket;
var destination = direction == WebSocketMessageDirection.Receive ? (WebSocket)clientWebSocket : context.WebSocket;
while (source.State == WebSocketState.Open && destination.State == WebSocketState.Open) while (source.State == WebSocketState.Open && destination.State == WebSocketState.Open)
{ {
@@ -385,7 +376,7 @@ internal class WebSocketResponseProvider(WebSocketBuilder builder) : IResponsePr
result.Count, result.Count,
result.EndOfMessage, result.EndOfMessage,
null, null,
options?.ActivityTracingOptions context.Options?.ActivityTracingOptions
); );
} }
@@ -400,16 +391,7 @@ internal class WebSocketResponseProvider(WebSocketBuilder builder) : IResponsePr
} }
// Enrich activity with message details // Enrich activity with message details
object? data = null; var data = ToData(result, buffer);
if (result.MessageType == WebSocketMessageType.Text)
{
data = Encoding.UTF8.GetString(buffer, 0, result.Count);
}
else if (result.MessageType == WebSocketMessageType.Binary)
{
data = new byte[result.Count];
Array.Copy(buffer, (byte[])data, result.Count);
}
if (shouldTrace) if (shouldTrace)
{ {
@@ -419,7 +401,7 @@ internal class WebSocketResponseProvider(WebSocketBuilder builder) : IResponsePr
result.Count, result.Count,
result.EndOfMessage, result.EndOfMessage,
data as string, data as string,
options?.ActivityTracingOptions context.Options?.ActivityTracingOptions
); );
} }
@@ -483,17 +465,7 @@ internal class WebSocketResponseProvider(WebSocketBuilder builder) : IResponsePr
} }
// Log the receive operation // Log the receive operation
object? data = null; var data = ToData(result, buffer);
if (result.MessageType == WebSocketMessageType.Text)
{
data = Encoding.UTF8.GetString(buffer, 0, result.Count);
}
else if (result.MessageType == WebSocketMessageType.Binary)
{
data = new byte[result.Count];
Array.Copy(buffer, (byte[])data, result.Count);
}
context.LogWebSocketMessage(WebSocketMessageDirection.Receive, result.MessageType, data, receiveActivity); context.LogWebSocketMessage(WebSocketMessageDirection.Receive, result.MessageType, data, receiveActivity);
if (result.MessageType == WebSocketMessageType.Close) if (result.MessageType == WebSocketMessageType.Close)
@@ -543,4 +515,22 @@ internal class WebSocketResponseProvider(WebSocketBuilder builder) : IResponsePr
return message; return message;
} }
private static object? ToData(WebSocketReceiveResult result, byte[] buffer)
{
if (result.MessageType == WebSocketMessageType.Text)
{
return Encoding.UTF8.GetString(buffer, 0, result.Count);
}
if (result.MessageType == WebSocketMessageType.Binary)
{
var data = new byte[result.Count];
Array.Copy(buffer, data, result.Count);
return data;
}
return null;
}
} }

View File

@@ -1,7 +1,5 @@
// Copyright © WireMock.Net // Copyright © WireMock.Net
using System;
namespace WireMock.Util; namespace WireMock.Util;
internal interface IGuidUtils internal interface IGuidUtils

View File

@@ -1,11 +1,9 @@
// Copyright © WireMock.Net // Copyright © WireMock.Net
using System.Diagnostics; using System.Diagnostics;
using System.Net;
using System.Net.WebSockets; using System.Net.WebSockets;
using System.Text; using System.Text;
using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http;
using Stef.Validation;
using WireMock.Logging; using WireMock.Logging;
using WireMock.Models; using WireMock.Models;
using WireMock.Owin; using WireMock.Owin;
@@ -21,7 +19,7 @@ namespace WireMock.WebSockets;
public class WireMockWebSocketContext : IWebSocketContext public class WireMockWebSocketContext : IWebSocketContext
{ {
/// <inheritdoc /> /// <inheritdoc />
public Guid ConnectionId { get; } = Guid.NewGuid(); public Guid ConnectionId { get; }
/// <inheritdoc /> /// <inheritdoc />
public HttpContext HttpContext { get; } public HttpContext HttpContext { get; }
@@ -54,16 +52,20 @@ public class WireMockWebSocketContext : IWebSocketContext
WebSocketConnectionRegistry? registry, WebSocketConnectionRegistry? registry,
WebSocketBuilder builder, WebSocketBuilder builder,
IWireMockMiddlewareOptions options, IWireMockMiddlewareOptions options,
IWireMockMiddlewareLogger logger) IWireMockMiddlewareLogger logger,
IGuidUtils guidUtils
)
{ {
HttpContext = Guard.NotNull(httpContext); HttpContext = httpContext;
WebSocket = Guard.NotNull(webSocket); WebSocket = webSocket;
RequestMessage = Guard.NotNull(requestMessage); RequestMessage = requestMessage;
Mapping = Guard.NotNull(mapping); Mapping = mapping;
Registry = registry; Registry = registry;
Builder = Guard.NotNull(builder); Builder = builder;
Options = Guard.NotNull(options); Options = options;
Logger = Guard.NotNull(logger); Logger = logger;
ConnectionId = guidUtils.NewGuid();
} }
/// <inheritdoc /> /// <inheritdoc />
@@ -187,7 +189,6 @@ public class WireMockWebSocketContext : IWebSocketContext
object? data, object? data,
Activity? activity) Activity? activity)
{ {
// Create body data
IBodyData bodyData; IBodyData bodyData;
if (messageType == WebSocketMessageType.Text && data is string textContent) if (messageType == WebSocketMessageType.Text && data is string textContent)
{ {
@@ -214,12 +215,11 @@ public class WireMockWebSocketContext : IWebSocketContext
}; };
} }
// Create a pseudo-request or pseudo-response depending on direction var method = $"WS_{direction.ToString().ToUpperInvariant()}";
RequestMessage? requestMessage = null; RequestMessage? requestMessage = null;
IResponseMessage? responseMessage = null; IResponseMessage? responseMessage = null;
var method = $"WS_{direction.ToString().ToUpperInvariant()}";
if (direction == WebSocketMessageDirection.Receive) if (direction == WebSocketMessageDirection.Receive)
{ {
// Received message - log as request // Received message - log as request
@@ -241,7 +241,6 @@ public class WireMockWebSocketContext : IWebSocketContext
responseMessage = new ResponseMessage responseMessage = new ResponseMessage
{ {
Method = method, Method = method,
StatusCode = HttpStatusCode.SwitchingProtocols, // WebSocket status
BodyData = bodyData, BodyData = bodyData,
DateTime = DateTime.UtcNow DateTime = DateTime.UtcNow
}; };

View File

@@ -1,30 +1,25 @@
// Copyright © WireMock.Net // Copyright © WireMock.Net
using System.Collections.Concurrent; using System.Collections.Concurrent;
using System.Diagnostics;
using System.Linq.Expressions; using System.Linq.Expressions;
using AwesomeAssertions;
using Microsoft.AspNetCore.Http;
using Moq; using Moq;
using WireMock.Models;
using WireMock.Owin;
using WireMock.Owin.Mappers;
using WireMock.Util;
using WireMock.Logging;
using WireMock.Matchers;
using WireMock.Admin.Mappings; using WireMock.Admin.Mappings;
using WireMock.Admin.Requests; using WireMock.Admin.Requests;
using WireMock.Settings;
using AwesomeAssertions;
using WireMock.Handlers; using WireMock.Handlers;
using WireMock.Logging;
using WireMock.Matchers;
using WireMock.Matchers.Request; using WireMock.Matchers.Request;
using WireMock.ResponseBuilders; using WireMock.Models;
using WireMock.RequestBuilders; using WireMock.Owin;
using Microsoft.AspNetCore.Http;
using Microsoft.CodeAnalysis.CSharp.Syntax;
#if NET6_0_OR_GREATER
using WireMock.Owin.ActivityTracing; using WireMock.Owin.ActivityTracing;
using System.Diagnostics; using WireMock.Owin.Mappers;
#endif using WireMock.RequestBuilders;
using WireMock.ResponseBuilders;
using WireMock.Settings;
using WireMock.Util;
namespace WireMock.Net.Tests.Owin; namespace WireMock.Net.Tests.Owin;
@@ -41,13 +36,16 @@ public class WireMockMiddlewareTests
private readonly Mock<IMapping> _mappingMock; private readonly Mock<IMapping> _mappingMock;
private readonly Mock<IRequestMatchResult> _requestMatchResultMock; private readonly Mock<IRequestMatchResult> _requestMatchResultMock;
private readonly Mock<HttpContext> _contextMock; private readonly Mock<HttpContext> _contextMock;
private readonly Mock<IGuidUtils> _guidUtilsMock;
private readonly WireMockMiddleware _sut; private readonly WireMockMiddleware _sut;
public WireMockMiddlewareTests() public WireMockMiddlewareTests()
{ {
var wireMockMiddlewareLoggerMock = new Mock<IWireMockMiddlewareLogger>(); var wireMockMiddlewareLoggerMock = new Mock<IWireMockMiddlewareLogger>();
// wreMockMiddlewareLoggerMock.Setup(g => g.NewGuid()).Returns(NewGuid);
_guidUtilsMock = new Mock<IGuidUtils>();
_guidUtilsMock.Setup(g => g.NewGuid()).Returns(NewGuid);
_optionsMock = new Mock<IWireMockMiddlewareOptions>(); _optionsMock = new Mock<IWireMockMiddlewareOptions>();
_optionsMock.SetupAllProperties(); _optionsMock.SetupAllProperties();
@@ -86,7 +84,8 @@ public class WireMockMiddlewareTests
_requestMapperMock.Object, _requestMapperMock.Object,
_responseMapperMock.Object, _responseMapperMock.Object,
_matcherMock.Object, _matcherMock.Object,
wireMockMiddlewareLoggerMock.Object wireMockMiddlewareLoggerMock.Object,
_guidUtilsMock.Object
); );
} }
@@ -262,7 +261,6 @@ public class WireMockMiddlewareTests
_mappings.Should().HaveCount(1); _mappings.Should().HaveCount(1);
} }
#if NET6_0_OR_GREATER
[Fact] [Fact]
public async Task WireMockMiddleware_Invoke_AdminPath_WithExcludeAdminRequests_ShouldNotStartActivity() public async Task WireMockMiddleware_Invoke_AdminPath_WithExcludeAdminRequests_ShouldNotStartActivity()
{ {
@@ -346,5 +344,4 @@ public class WireMockMiddlewareTests
// Assert // Assert
activityStarted.Should().BeFalse(); activityStarted.Should().BeFalse();
} }
#endif
} }