This commit is contained in:
Stef Heyenrath
2026-02-16 18:41:58 +01:00
parent fa3a33dcc6
commit 5388130009
5 changed files with 90 additions and 103 deletions

View File

@@ -12,6 +12,7 @@ using WireMock.Owin.Mappers;
using WireMock.ResponseBuilders;
using WireMock.Serialization;
using WireMock.Settings;
using WireMock.Util;
namespace WireMock.Owin;
@@ -21,7 +22,8 @@ internal class WireMockMiddleware(
IOwinRequestMapper requestMapper,
IOwinResponseMapper responseMapper,
IMappingMatcher mappingMatcher,
IWireMockMiddlewareLogger logger
IWireMockMiddlewareLogger logger,
IGuidUtils guidUtils
)
{
private readonly object _lock = new();
@@ -44,6 +46,7 @@ internal class WireMockMiddleware(
// Store options in HttpContext for providers to access (e.g., WebSocketResponseProvider)
ctx.Items[nameof(IWireMockMiddlewareOptions)] = options;
ctx.Items[nameof(IWireMockMiddlewareLogger)] = logger;
ctx.Items[nameof(IGuidUtils)] = guidUtils;
var request = await requestMapper.MapAsync(ctx, options).ConfigureAwait(false);

View File

@@ -7,9 +7,11 @@ using System.Net.WebSockets;
using System.Text;
using Microsoft.AspNetCore.Http;
using WireMock.Constants;
using WireMock.Extensions;
using WireMock.Owin;
using WireMock.Owin.ActivityTracing;
using WireMock.Settings;
using WireMock.Util;
using WireMock.WebSockets;
namespace WireMock.ResponseProviders;
@@ -43,20 +45,21 @@ internal class WebSocketResponseProvider(WebSocketBuilder builder) : IResponsePr
var webSocket = await context.WebSockets.AcceptWebSocketAsync(builder.AcceptProtocol).ConfigureAwait(false);
#endif
// Get options from HttpContext.Items (set by WireMockMiddleware)
if (!context.Items.TryGetValue(nameof(IWireMockMiddlewareOptions), out var optionsObj) ||
optionsObj is not IWireMockMiddlewareOptions options)
if (!context.Items.TryGetValue<IWireMockMiddlewareOptions>(nameof(IWireMockMiddlewareOptions), out var options))
{
throw new InvalidOperationException("IWireMockMiddlewareOptions not found in HttpContext.Items");
}
// Get logger from HttpContext.Items
if (!context.Items.TryGetValue(nameof(IWireMockMiddlewareLogger), out var loggerObj) ||
loggerObj is not IWireMockMiddlewareLogger logger)
if (!context.Items.TryGetValue<IWireMockMiddlewareLogger>(nameof(IWireMockMiddlewareLogger), out var logger))
{
throw new InvalidOperationException("IWireMockMiddlewareLogger not found in HttpContext.Items");
}
if (!context.Items.TryGetValue<IGuidUtils>(nameof(IGuidUtils), out var guidUtils))
{
throw new InvalidOperationException("IGuidUtils not found in HttpContext.Items");
}
// Get or create registry from options
var registry = builder.IsBroadcast
? options.WebSocketRegistries.GetOrAdd(mapping.Guid, _ => new WebSocketConnectionRegistry())
@@ -71,7 +74,8 @@ internal class WebSocketResponseProvider(WebSocketBuilder builder) : IResponsePr
registry,
builder,
options,
logger
logger,
guidUtils
);
// Update scenario state following the same pattern as WireMockMiddleware
@@ -141,25 +145,22 @@ internal class WebSocketResponseProvider(WebSocketBuilder builder) : IResponsePr
{
while (context.WebSocket.State == WebSocketState.Open && !cts.Token.IsCancellationRequested)
{
Activity? receiveActivity = null;
Activity? activity = null;
if (shouldTrace)
{
receiveActivity = WireMockActivitySource.StartWebSocketMessageActivity(WebSocketMessageDirection.Receive, context.Mapping.Guid);
activity = WireMockActivitySource.StartWebSocketMessageActivity(WebSocketMessageDirection.Receive, context.Mapping.Guid);
}
try
{
var result = await context.WebSocket.ReceiveAsync(
new ArraySegment<byte>(buffer),
cts.Token
).ConfigureAwait(false);
var result = await context.WebSocket.ReceiveAsync(new ArraySegment<byte>(buffer), cts.Token).ConfigureAwait(false);
if (result.MessageType == WebSocketMessageType.Close)
{
if (shouldTrace)
{
WireMockActivitySource.EnrichWithWebSocketMessage(
receiveActivity,
activity,
result.MessageType,
result.Count,
result.EndOfMessage,
@@ -168,36 +169,29 @@ internal class WebSocketResponseProvider(WebSocketBuilder builder) : IResponsePr
);
}
context.LogWebSocketMessage(WebSocketMessageDirection.Receive, result.MessageType, null, receiveActivity);
context.LogWebSocketMessage(WebSocketMessageDirection.Receive, result.MessageType, null, activity);
await context.CloseAsync(
WebSocketCloseStatus.NormalClosure,
"Closed by client"
).ConfigureAwait(false);
await context.CloseAsync(WebSocketCloseStatus.NormalClosure, "Closed by client").ConfigureAwait(false);
break;
}
// Enrich activity with message details
string? textContent = null;
if (result.MessageType == WebSocketMessageType.Text)
{
textContent = Encoding.UTF8.GetString(buffer, 0, result.Count);
}
var data = ToData(result, buffer);
if (shouldTrace)
{
WireMockActivitySource.EnrichWithWebSocketMessage(
receiveActivity,
activity,
result.MessageType,
result.Count,
result.EndOfMessage,
textContent,
data as string,
context.Options?.ActivityTracingOptions
);
}
// Log the receive operation
context.LogWebSocketMessage(WebSocketMessageDirection.Receive, result.MessageType, textContent, receiveActivity);
context.LogWebSocketMessage(WebSocketMessageDirection.Receive, result.MessageType, data, activity);
// Echo back (this will be logged by context.SendAsync)
await context.WebSocket.SendAsync(
@@ -206,15 +200,18 @@ internal class WebSocketResponseProvider(WebSocketBuilder builder) : IResponsePr
result.EndOfMessage,
cts.Token
).ConfigureAwait(false);
// Log the send (echo) operation
context.LogWebSocketMessage(WebSocketMessageDirection.Send, result.MessageType, data, activity);
}
catch (Exception ex)
{
WireMockActivitySource.RecordException(receiveActivity, ex);
WireMockActivitySource.RecordException(activity, ex);
throw;
}
finally
{
receiveActivity?.Dispose();
activity?.Dispose();
}
}
}
@@ -232,7 +229,7 @@ internal class WebSocketResponseProvider(WebSocketBuilder builder) : IResponsePr
Func<WebSocketMessage, IWebSocketContext, Task> handler)
{
var bufferSize = context.Builder.MaxMessageSize ?? WebSocketConstants.DefaultReceiveBufferSize;
var buffer = new byte[bufferSize];
using var buffer = ArrayPool<byte>.Shared.Lease(bufferSize);
var timeout = context.Builder.CloseTimeout ?? TimeSpan.FromMinutes(WebSocketConstants.DefaultCloseTimeoutMinutes);
using var cts = new CancellationTokenSource(timeout);
@@ -350,18 +347,12 @@ internal class WebSocketResponseProvider(WebSocketBuilder builder) : IResponsePr
ClientWebSocket clientWebSocket,
WebSocketMessageDirection direction)
{
var buffer = new byte[WebSocketConstants.ProxyForwardBufferSize];
using var buffer = ArrayPool<byte>.Shared.Lease(WebSocketConstants.ProxyForwardBufferSize);
// Get options for activity tracing
var options = context.HttpContext.Items.TryGetValue(nameof(WireMockMiddlewareOptions), out var optionsObj) &&
optionsObj is IWireMockMiddlewareOptions wireMockOptions
? wireMockOptions
: null;
var shouldTrace = context.Options?.ActivityTracingOptions is not null;
var shouldTrace = options?.ActivityTracingOptions is not null;
var source = direction == WebSocketMessageDirection.Receive ? context.WebSocket : (WebSocket)clientWebSocket;
var destination = direction == WebSocketMessageDirection.Receive ? (WebSocket)clientWebSocket : context.WebSocket;
var source = direction == WebSocketMessageDirection.Receive ? context.WebSocket : clientWebSocket;
var destination = direction == WebSocketMessageDirection.Receive ? clientWebSocket : context.WebSocket;
while (source.State == WebSocketState.Open && destination.State == WebSocketState.Open)
{
@@ -385,7 +376,7 @@ internal class WebSocketResponseProvider(WebSocketBuilder builder) : IResponsePr
result.Count,
result.EndOfMessage,
null,
options?.ActivityTracingOptions
context.Options?.ActivityTracingOptions
);
}
@@ -400,16 +391,7 @@ internal class WebSocketResponseProvider(WebSocketBuilder builder) : IResponsePr
}
// Enrich activity with message details
object? data = null;
if (result.MessageType == WebSocketMessageType.Text)
{
data = Encoding.UTF8.GetString(buffer, 0, result.Count);
}
else if (result.MessageType == WebSocketMessageType.Binary)
{
data = new byte[result.Count];
Array.Copy(buffer, (byte[])data, result.Count);
}
var data = ToData(result, buffer);
if (shouldTrace)
{
@@ -419,7 +401,7 @@ internal class WebSocketResponseProvider(WebSocketBuilder builder) : IResponsePr
result.Count,
result.EndOfMessage,
data as string,
options?.ActivityTracingOptions
context.Options?.ActivityTracingOptions
);
}
@@ -483,17 +465,7 @@ internal class WebSocketResponseProvider(WebSocketBuilder builder) : IResponsePr
}
// Log the receive operation
object? data = null;
if (result.MessageType == WebSocketMessageType.Text)
{
data = Encoding.UTF8.GetString(buffer, 0, result.Count);
}
else if (result.MessageType == WebSocketMessageType.Binary)
{
data = new byte[result.Count];
Array.Copy(buffer, (byte[])data, result.Count);
}
var data = ToData(result, buffer);
context.LogWebSocketMessage(WebSocketMessageDirection.Receive, result.MessageType, data, receiveActivity);
if (result.MessageType == WebSocketMessageType.Close)
@@ -543,4 +515,22 @@ internal class WebSocketResponseProvider(WebSocketBuilder builder) : IResponsePr
return message;
}
private static object? ToData(WebSocketReceiveResult result, byte[] buffer)
{
if (result.MessageType == WebSocketMessageType.Text)
{
return Encoding.UTF8.GetString(buffer, 0, result.Count);
}
if (result.MessageType == WebSocketMessageType.Binary)
{
var data = new byte[result.Count];
Array.Copy(buffer, data, result.Count);
return data;
}
return null;
}
}

View File

@@ -1,7 +1,5 @@
// Copyright © WireMock.Net
using System;
namespace WireMock.Util;
internal interface IGuidUtils

View File

@@ -1,11 +1,9 @@
// Copyright © WireMock.Net
using System.Diagnostics;
using System.Net;
using System.Net.WebSockets;
using System.Text;
using Microsoft.AspNetCore.Http;
using Stef.Validation;
using WireMock.Logging;
using WireMock.Models;
using WireMock.Owin;
@@ -21,7 +19,7 @@ namespace WireMock.WebSockets;
public class WireMockWebSocketContext : IWebSocketContext
{
/// <inheritdoc />
public Guid ConnectionId { get; } = Guid.NewGuid();
public Guid ConnectionId { get; }
/// <inheritdoc />
public HttpContext HttpContext { get; }
@@ -54,16 +52,20 @@ public class WireMockWebSocketContext : IWebSocketContext
WebSocketConnectionRegistry? registry,
WebSocketBuilder builder,
IWireMockMiddlewareOptions options,
IWireMockMiddlewareLogger logger)
IWireMockMiddlewareLogger logger,
IGuidUtils guidUtils
)
{
HttpContext = Guard.NotNull(httpContext);
WebSocket = Guard.NotNull(webSocket);
RequestMessage = Guard.NotNull(requestMessage);
Mapping = Guard.NotNull(mapping);
HttpContext = httpContext;
WebSocket = webSocket;
RequestMessage = requestMessage;
Mapping = mapping;
Registry = registry;
Builder = Guard.NotNull(builder);
Options = Guard.NotNull(options);
Logger = Guard.NotNull(logger);
Builder = builder;
Options = options;
Logger = logger;
ConnectionId = guidUtils.NewGuid();
}
/// <inheritdoc />
@@ -187,7 +189,6 @@ public class WireMockWebSocketContext : IWebSocketContext
object? data,
Activity? activity)
{
// Create body data
IBodyData bodyData;
if (messageType == WebSocketMessageType.Text && data is string textContent)
{
@@ -214,12 +215,11 @@ public class WireMockWebSocketContext : IWebSocketContext
};
}
// Create a pseudo-request or pseudo-response depending on direction
var method = $"WS_{direction.ToString().ToUpperInvariant()}";
RequestMessage? requestMessage = null;
IResponseMessage? responseMessage = null;
var method = $"WS_{direction.ToString().ToUpperInvariant()}";
if (direction == WebSocketMessageDirection.Receive)
{
// Received message - log as request
@@ -241,7 +241,6 @@ public class WireMockWebSocketContext : IWebSocketContext
responseMessage = new ResponseMessage
{
Method = method,
StatusCode = HttpStatusCode.SwitchingProtocols, // WebSocket status
BodyData = bodyData,
DateTime = DateTime.UtcNow
};

View File

@@ -1,30 +1,25 @@
// Copyright © WireMock.Net
using System.Collections.Concurrent;
using System.Diagnostics;
using System.Linq.Expressions;
using AwesomeAssertions;
using Microsoft.AspNetCore.Http;
using Moq;
using WireMock.Models;
using WireMock.Owin;
using WireMock.Owin.Mappers;
using WireMock.Util;
using WireMock.Logging;
using WireMock.Matchers;
using WireMock.Admin.Mappings;
using WireMock.Admin.Requests;
using WireMock.Settings;
using AwesomeAssertions;
using WireMock.Handlers;
using WireMock.Logging;
using WireMock.Matchers;
using WireMock.Matchers.Request;
using WireMock.ResponseBuilders;
using WireMock.RequestBuilders;
using Microsoft.AspNetCore.Http;
using Microsoft.CodeAnalysis.CSharp.Syntax;
#if NET6_0_OR_GREATER
using WireMock.Models;
using WireMock.Owin;
using WireMock.Owin.ActivityTracing;
using System.Diagnostics;
#endif
using WireMock.Owin.Mappers;
using WireMock.RequestBuilders;
using WireMock.ResponseBuilders;
using WireMock.Settings;
using WireMock.Util;
namespace WireMock.Net.Tests.Owin;
@@ -41,13 +36,16 @@ public class WireMockMiddlewareTests
private readonly Mock<IMapping> _mappingMock;
private readonly Mock<IRequestMatchResult> _requestMatchResultMock;
private readonly Mock<HttpContext> _contextMock;
private readonly Mock<IGuidUtils> _guidUtilsMock;
private readonly WireMockMiddleware _sut;
public WireMockMiddlewareTests()
{
var wireMockMiddlewareLoggerMock = new Mock<IWireMockMiddlewareLogger>();
// wreMockMiddlewareLoggerMock.Setup(g => g.NewGuid()).Returns(NewGuid);
_guidUtilsMock = new Mock<IGuidUtils>();
_guidUtilsMock.Setup(g => g.NewGuid()).Returns(NewGuid);
_optionsMock = new Mock<IWireMockMiddlewareOptions>();
_optionsMock.SetupAllProperties();
@@ -86,7 +84,8 @@ public class WireMockMiddlewareTests
_requestMapperMock.Object,
_responseMapperMock.Object,
_matcherMock.Object,
wireMockMiddlewareLoggerMock.Object
wireMockMiddlewareLoggerMock.Object,
_guidUtilsMock.Object
);
}
@@ -262,7 +261,6 @@ public class WireMockMiddlewareTests
_mappings.Should().HaveCount(1);
}
#if NET6_0_OR_GREATER
[Fact]
public async Task WireMockMiddleware_Invoke_AdminPath_WithExcludeAdminRequests_ShouldNotStartActivity()
{
@@ -346,5 +344,4 @@ public class WireMockMiddlewareTests
// Assert
activityStarted.Should().BeFalse();
}
#endif
}