This commit is contained in:
Stef Heyenrath
2026-02-22 21:48:33 +01:00
parent 8287ae79ec
commit ba3d1d758c
9 changed files with 629 additions and 98 deletions

View File

@@ -48,7 +48,6 @@ internal class WebSocketResponseProvider(WebSocketBuilder builder, IGuidUtils gu
{ {
SubProtocol = builder.AcceptProtocol, SubProtocol = builder.AcceptProtocol,
KeepAliveInterval = builder.KeepAliveIntervalSeconds ?? TimeSpan.FromSeconds(WebSocketConstants.DefaultKeepAliveIntervalSeconds) KeepAliveInterval = builder.KeepAliveIntervalSeconds ?? TimeSpan.FromSeconds(WebSocketConstants.DefaultKeepAliveIntervalSeconds)
}; };
var webSocket = await context.WebSockets.AcceptWebSocketAsync(acceptContext).ConfigureAwait(false); var webSocket = await context.WebSockets.AcceptWebSocketAsync(acceptContext).ConfigureAwait(false);
#else #else
@@ -56,9 +55,7 @@ internal class WebSocketResponseProvider(WebSocketBuilder builder, IGuidUtils gu
#endif #endif
// Get or create registry from options // Get or create registry from options
var registry = builder.IsBroadcast var registry = options.WebSocketRegistries.GetOrAdd(mapping.Guid, _ => new WebSocketConnectionRegistry());
? options.WebSocketRegistries.GetOrAdd(mapping.Guid, _ => new WebSocketConnectionRegistry())
: null;
// Create WebSocket context // Create WebSocket context
var wsContext = new WireMockWebSocketContext( var wsContext = new WireMockWebSocketContext(
@@ -73,8 +70,8 @@ internal class WebSocketResponseProvider(WebSocketBuilder builder, IGuidUtils gu
guidUtils guidUtils
); );
// Add to registry if broadcast is enabled // Add to registry
registry?.AddConnection(wsContext); registry.AddConnection(wsContext);
try try
{ {
@@ -100,7 +97,7 @@ internal class WebSocketResponseProvider(WebSocketBuilder builder, IGuidUtils gu
finally finally
{ {
// Remove from registry // Remove from registry
registry?.RemoveConnection(wsContext.ConnectionId); registry.RemoveConnection(wsContext.ConnectionId);
} }
// Return special marker to indicate WebSocket was handled // Return special marker to indicate WebSocket was handled
@@ -213,9 +210,7 @@ internal class WebSocketResponseProvider(WebSocketBuilder builder, IGuidUtils gu
} }
} }
private static async Task HandleCustomAsync( private static async Task HandleCustomAsync(WireMockWebSocketContext context, Func<WebSocketMessage, IWebSocketContext, Task> handler)
WireMockWebSocketContext context,
Func<WebSocketMessage, IWebSocketContext, Task> handler)
{ {
var bufferSize = context.Builder.MaxMessageSize ?? WebSocketConstants.DefaultReceiveBufferSize; var bufferSize = context.Builder.MaxMessageSize ?? WebSocketConstants.DefaultReceiveBufferSize;
using var buffer = ArrayPool<byte>.Shared.Lease(bufferSize); using var buffer = ArrayPool<byte>.Shared.Lease(bufferSize);
@@ -236,10 +231,7 @@ internal class WebSocketResponseProvider(WebSocketBuilder builder, IGuidUtils gu
try try
{ {
var result = await context.WebSocket.ReceiveAsync( var result = await context.WebSocket.ReceiveAsync(new ArraySegment<byte>(buffer), cts.Token).ConfigureAwait(false);
new ArraySegment<byte>(buffer),
cts.Token
).ConfigureAwait(false);
if (result.MessageType == WebSocketMessageType.Close) if (result.MessageType == WebSocketMessageType.Close)
{ {
@@ -257,10 +249,7 @@ internal class WebSocketResponseProvider(WebSocketBuilder builder, IGuidUtils gu
context.LogWebSocketMessage(WebSocketMessageDirection.Receive, result.MessageType, null, receiveActivity); context.LogWebSocketMessage(WebSocketMessageDirection.Receive, result.MessageType, null, receiveActivity);
await context.CloseAsync( await context.CloseAsync(WebSocketCloseStatus.NormalClosure, "Closed by client").ConfigureAwait(false);
WebSocketCloseStatus.NormalClosure,
"Closed by client"
).ConfigureAwait(false);
break; break;
} }
@@ -331,10 +320,7 @@ internal class WebSocketResponseProvider(WebSocketBuilder builder, IGuidUtils gu
} }
} }
private static async Task ForwardMessagesAsync( private static async Task ForwardMessagesAsync(WireMockWebSocketContext context, ClientWebSocket clientWebSocket, WebSocketMessageDirection direction)
WireMockWebSocketContext context,
ClientWebSocket clientWebSocket,
WebSocketMessageDirection direction)
{ {
using var buffer = ArrayPool<byte>.Shared.Lease(WebSocketConstants.ProxyForwardBufferSize); using var buffer = ArrayPool<byte>.Shared.Lease(WebSocketConstants.ProxyForwardBufferSize);

View File

@@ -1,6 +1,5 @@
// Copyright © WireMock.Net // Copyright © WireMock.Net
using System.Net.WebSockets;
using JetBrains.Annotations; using JetBrains.Annotations;
using WireMock.WebSockets; using WireMock.WebSockets;
@@ -32,17 +31,16 @@ public partial class WireMockServer
/// Close a specific WebSocket connection /// Close a specific WebSocket connection
/// </summary> /// </summary>
[PublicAPI] [PublicAPI]
public async Task CloseWebSocketConnectionAsync( public async Task AbortWebSocketConnectionAsync(Guid connectionId, string statusDescription = "Closed by server", CancellationToken cancellationToken = default)
Guid connectionId,
WebSocketCloseStatus closeStatus = WebSocketCloseStatus.NormalClosure,
string statusDescription = "Closed by server",
CancellationToken cancellationToken = default)
{ {
foreach (var registry in _options.WebSocketRegistries.Values) foreach (var registry in _options.WebSocketRegistries.Values)
{ {
if (registry.TryGetConnection(connectionId, out var connection) && !cancellationToken.IsCancellationRequested) if (registry.TryGetConnection(connectionId, out var connection))
{ {
await connection.CloseAsync(closeStatus, statusDescription, cancellationToken); connection.Abort(statusDescription);
registry.RemoveConnection(connectionId);
await Task.Delay(100, cancellationToken); // Give the connection some time to close gracefully
return; return;
} }
} }
@@ -52,11 +50,11 @@ public partial class WireMockServer
/// Broadcast a text message to all WebSocket connections in a specific mapping /// Broadcast a text message to all WebSocket connections in a specific mapping
/// </summary> /// </summary>
[PublicAPI] [PublicAPI]
public async Task BroadcastToWebSocketsAsync(Guid mappingGuid, string text, Guid? excludeConnectionId = null, CancellationToken cancellationToken = default) public async Task BroadcastToWebSocketsAsync(Guid mappingGuid, string text, CancellationToken cancellationToken = default)
{ {
if (_options.WebSocketRegistries.TryGetValue(mappingGuid, out var registry)) if (_options.WebSocketRegistries.TryGetValue(mappingGuid, out var registry))
{ {
await registry.BroadcastAsync(text, excludeConnectionId, cancellationToken); await registry.BroadcastAsync(text, null, cancellationToken);
} }
} }
@@ -64,11 +62,11 @@ public partial class WireMockServer
/// Broadcast a text message to all WebSocket connections /// Broadcast a text message to all WebSocket connections
/// </summary> /// </summary>
[PublicAPI] [PublicAPI]
public async Task BroadcastToAllWebSocketsAsync(string text, Guid? excludeConnectionId = null, CancellationToken cancellationToken = default) public async Task BroadcastToAllWebSocketsAsync(string text, CancellationToken cancellationToken = default)
{ {
foreach (var registry in _options.WebSocketRegistries.Values) foreach (var registry in _options.WebSocketRegistries.Values)
{ {
await registry.BroadcastAsync(text, excludeConnectionId, cancellationToken); await registry.BroadcastAsync(text, null, cancellationToken);
} }
} }
@@ -76,11 +74,11 @@ public partial class WireMockServer
/// Broadcast a binary message to all WebSocket connections in a specific mapping /// Broadcast a binary message to all WebSocket connections in a specific mapping
/// </summary> /// </summary>
[PublicAPI] [PublicAPI]
public async Task BroadcastToWebSocketsAsync(Guid mappingGuid, byte[] bytes, Guid? excludeConnectionId = null, CancellationToken cancellationToken = default) public async Task BroadcastToWebSocketsAsync(Guid mappingGuid, byte[] bytes, CancellationToken cancellationToken = default)
{ {
if (_options.WebSocketRegistries.TryGetValue(mappingGuid, out var registry)) if (_options.WebSocketRegistries.TryGetValue(mappingGuid, out var registry))
{ {
await registry.BroadcastAsync(bytes, excludeConnectionId, cancellationToken); await registry.BroadcastAsync(bytes, null, cancellationToken);
} }
} }
@@ -88,11 +86,11 @@ public partial class WireMockServer
/// Broadcast a binary message to all WebSocket connections /// Broadcast a binary message to all WebSocket connections
/// </summary> /// </summary>
[PublicAPI] [PublicAPI]
public async Task BroadcastToAllWebSocketsAsync(byte[] bytes, Guid? excludeConnectionId = null, CancellationToken cancellationToken = default) public async Task BroadcastToAllWebSocketsAsync(byte[] bytes, CancellationToken cancellationToken = default)
{ {
foreach (var registry in _options.WebSocketRegistries.Values) foreach (var registry in _options.WebSocketRegistries.Values)
{ {
await registry.BroadcastAsync(bytes, excludeConnectionId, cancellationToken); await registry.BroadcastAsync(bytes, null, cancellationToken);
} }
} }
} }

View File

@@ -13,34 +13,22 @@ internal class WebSocketBuilder(Response response) : IWebSocketBuilder
{ {
private readonly List<(IMatcher matcher, List<WebSocketMessageBuilder> messages)> _conditionalMessages = []; private readonly List<(IMatcher matcher, List<WebSocketMessageBuilder> messages)> _conditionalMessages = [];
/// <inheritdoc />
public string? AcceptProtocol { get; private set; } public string? AcceptProtocol { get; private set; }
/// <inheritdoc />
public bool IsEcho { get; private set; } public bool IsEcho { get; private set; }
/// <inheritdoc />
public bool IsBroadcast { get; private set; }
/// <inheritdoc />
public Func<WebSocketMessage, IWebSocketContext, Task>? MessageHandler { get; private set; } public Func<WebSocketMessage, IWebSocketContext, Task>? MessageHandler { get; private set; }
/// <inheritdoc />
public ProxyAndRecordSettings? ProxySettings { get; private set; } public ProxyAndRecordSettings? ProxySettings { get; private set; }
/// <inheritdoc />
public TimeSpan? CloseTimeout { get; private set; } public TimeSpan? CloseTimeout { get; private set; }
/// <inheritdoc />
public int? MaxMessageSize { get; private set; } public int? MaxMessageSize { get; private set; }
/// <inheritdoc />
public int? ReceiveBufferSize { get; private set; } public int? ReceiveBufferSize { get; private set; }
/// <inheritdoc />
public TimeSpan? KeepAliveIntervalSeconds { get; private set; } public TimeSpan? KeepAliveIntervalSeconds { get; private set; }
/// <inheritdoc />
public IWebSocketBuilder WithAcceptProtocol(string protocol) public IWebSocketBuilder WithAcceptProtocol(string protocol)
{ {
AcceptProtocol = Guard.NotNull(protocol); AcceptProtocol = Guard.NotNull(protocol);
@@ -117,12 +105,6 @@ internal class WebSocketBuilder(Response response) : IWebSocketBuilder
return this; return this;
} }
public IWebSocketBuilder WithBroadcast()
{
IsBroadcast = true;
return this;
}
public IWebSocketBuilder WithProxy(ProxyAndRecordSettings settings) public IWebSocketBuilder WithProxy(ProxyAndRecordSettings settings)
{ {
ProxySettings = Guard.NotNull(settings); ProxySettings = Guard.NotNull(settings);

View File

@@ -3,7 +3,6 @@
using System.Collections.Concurrent; using System.Collections.Concurrent;
using System.Diagnostics.CodeAnalysis; using System.Diagnostics.CodeAnalysis;
using System.Net.WebSockets; using System.Net.WebSockets;
using static System.Net.Mime.MediaTypeNames;
namespace WireMock.WebSockets; namespace WireMock.WebSockets;
@@ -27,7 +26,7 @@ internal class WebSocketConnectionRegistry
/// </summary> /// </summary>
public void RemoveConnection(Guid connectionId) public void RemoveConnection(Guid connectionId)
{ {
_connections.TryRemove(connectionId, out _); _ = _connections.TryRemove(connectionId, out _);
} }
/// <summary> /// <summary>
@@ -67,6 +66,6 @@ internal class WebSocketConnectionRegistry
private IEnumerable<WireMockWebSocketContext> Filter(Guid? excludeConnectionId) private IEnumerable<WireMockWebSocketContext> Filter(Guid? excludeConnectionId)
{ {
return _connections.Values return _connections.Values
.Where(c =>c.WebSocket.State == WebSocketState.Open && (!excludeConnectionId.HasValue || c.ConnectionId != excludeConnectionId)); .Where(c => c.WebSocket.State == WebSocketState.Open && (!excludeConnectionId.HasValue || c.ConnectionId != excludeConnectionId));
} }
} }

View File

@@ -1,5 +1,6 @@
// Copyright © WireMock.Net // Copyright © WireMock.Net
using System.Buffers;
using System.Diagnostics; using System.Diagnostics;
using System.Net.WebSockets; using System.Net.WebSockets;
using System.Text; using System.Text;
@@ -33,7 +34,7 @@ public class WireMockWebSocketContext : IWebSocketContext
/// <inheritdoc /> /// <inheritdoc />
public IMapping Mapping { get; } public IMapping Mapping { get; }
internal WebSocketConnectionRegistry? Registry { get; } internal WebSocketConnectionRegistry Registry { get; }
internal WebSocketBuilder Builder { get; } internal WebSocketBuilder Builder { get; }
@@ -49,7 +50,7 @@ public class WireMockWebSocketContext : IWebSocketContext
WebSocket webSocket, WebSocket webSocket,
IRequestMessage requestMessage, IRequestMessage requestMessage,
IMapping mapping, IMapping mapping,
WebSocketConnectionRegistry? registry, WebSocketConnectionRegistry registry,
WebSocketBuilder builder, WebSocketBuilder builder,
IWireMockMiddlewareOptions options, IWireMockMiddlewareOptions options,
IWireMockMiddlewareLogger logger, IWireMockMiddlewareLogger logger,
@@ -96,29 +97,31 @@ public class WireMockWebSocketContext : IWebSocketContext
/// <inheritdoc /> /// <inheritdoc />
public async Task CloseAsync(WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken = default) public async Task CloseAsync(WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken = default)
{ {
await WebSocket.CloseAsync(closeStatus, statusDescription, cancellationToken); await WebSocket.CloseAsync(closeStatus, statusDescription, cancellationToken).ConfigureAwait(false);
LogWebSocketMessage(WebSocketMessageDirection.Send, WebSocketMessageType.Close, $"CloseStatus: {closeStatus}, Description: {statusDescription}", null); LogWebSocketMessage(WebSocketMessageDirection.Send, WebSocketMessageType.Close, $"CloseStatus: {closeStatus}, Description: {statusDescription}", null);
} }
/// <inheritdoc />
public void Abort(string? statusDescription = null)
{
WebSocket.Abort();
LogWebSocketMessage(WebSocketMessageDirection.Send, WebSocketMessageType.Close, $"CloseStatus: Abort, Description: {statusDescription}", null);
}
/// <inheritdoc /> /// <inheritdoc />
public async Task BroadcastAsync(string text, bool excludeSender = false, CancellationToken cancellationToken = default) public async Task BroadcastAsync(string text, bool excludeSender = false, CancellationToken cancellationToken = default)
{ {
if (Registry != null) Guid? excludeConnectionId = excludeSender ? ConnectionId : null;
{ await Registry.BroadcastAsync(text, excludeConnectionId, cancellationToken);
Guid? excludeConnectionId = excludeSender ? ConnectionId : null;
await Registry.BroadcastAsync(text, excludeConnectionId, cancellationToken);
}
} }
/// <inheritdoc /> /// <inheritdoc />
public async Task BroadcastAsync(byte[] bytes, bool excludeSender = false, CancellationToken cancellationToken = default) public async Task BroadcastAsync(byte[] bytes, bool excludeSender = false, CancellationToken cancellationToken = default)
{ {
if (Registry != null) Guid? excludeConnectionId = excludeSender ? ConnectionId : null;
{ await Registry.BroadcastAsync(bytes, excludeConnectionId, cancellationToken);
Guid? excludeConnectionId = excludeSender ? ConnectionId : null;
await Registry.BroadcastAsync(bytes, excludeConnectionId, cancellationToken);
}
} }
internal void LogWebSocketMessage( internal void LogWebSocketMessage(

View File

@@ -64,12 +64,6 @@ public interface IWebSocketBuilder
[PublicAPI] [PublicAPI]
IWebSocketBuilder WithMessageHandler(Func<WebSocketMessage, IWebSocketContext, Task> handler); IWebSocketBuilder WithMessageHandler(Func<WebSocketMessage, IWebSocketContext, Task> handler);
/// <summary>
/// Enable broadcast mode for this mapping
/// </summary>
[PublicAPI]
IWebSocketBuilder WithBroadcast();
/// <summary> /// <summary>
/// Proxy to another WebSocket server /// Proxy to another WebSocket server
/// </summary> /// </summary>

View File

@@ -50,6 +50,11 @@ public interface IWebSocketContext
/// </summary> /// </summary>
Task CloseAsync(WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken = default); Task CloseAsync(WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken = default);
/// <summary>
/// Abort the WebSocket connection to immediately close the connection without waiting for the close handshake
/// </summary>
void Abort(string? statusDescription = null);
/// <summary> /// <summary>
/// Broadcast text message to all connections in this mapping /// Broadcast text message to all connections in this mapping
/// </summary> /// </summary>

View File

@@ -400,29 +400,29 @@ public class WebSocketIntegrationTests(ITestOutputHelper output, ITestContextAcc
if (text.StartsWith("/help")) if (text.StartsWith("/help"))
{ {
await context.SendAsync("Available commands: /help, /time, /echo <text>, /upper <text>, /reverse <text>"); await context.SendAsync("Available commands: /help, /time, /echo <text>, /upper <text>, /reverse <text>", _ct);
} }
else if (text.StartsWith("/time")) else if (text.StartsWith("/time"))
{ {
await context.SendAsync($"Server time: {DateTime.UtcNow:yyyy-MM-dd HH:mm:ss} UTC"); await context.SendAsync($"Server time: {DateTime.UtcNow:yyyy-MM-dd HH:mm:ss} UTC", _ct);
} }
else if (text.StartsWith("/echo ")) else if (text.StartsWith("/echo "))
{ {
await context.SendAsync(text.Substring(6)); await context.SendAsync(text.Substring(6), _ct);
} }
else if (text.StartsWith("/upper ")) else if (text.StartsWith("/upper "))
{ {
await context.SendAsync(text.Substring(7).ToUpper()); await context.SendAsync(text.Substring(7).ToUpper(), _ct);
} }
else if (text.StartsWith("/reverse ")) else if (text.StartsWith("/reverse "))
{ {
var toReverse = text.Substring(9); var toReverse = text.Substring(9);
var reversed = new string(toReverse.Reverse().ToArray()); var reversed = new string(toReverse.Reverse().ToArray());
await context.SendAsync(reversed); await context.SendAsync(reversed, _ct);
} }
else if (text.StartsWith("/close")) else if (text.StartsWith("/close"))
{ {
await context.CloseAsync(WebSocketCloseStatus.NormalClosure, "Closing connection"); await context.CloseAsync(WebSocketCloseStatus.NormalClosure, "Closing connection", _ct);
} }
} }
}) })
@@ -800,8 +800,7 @@ public class WebSocketIntegrationTests(ITestOutputHelper output, ITestContextAcc
) )
.RespondWith(Response.Create() .RespondWith(Response.Create()
.WithWebSocket(ws => ws .WithWebSocket(ws => ws
.WithCloseTimeout(TimeSpan.FromSeconds(10)) .WithCloseTimeout(TimeSpan.FromSeconds(5))
.WithBroadcast()
.WithMessageHandler(async (message, context) => .WithMessageHandler(async (message, context) =>
{ {
if (message.MessageType == WebSocketMessageType.Text) if (message.MessageType == WebSocketMessageType.Text)
@@ -885,8 +884,7 @@ public class WebSocketIntegrationTests(ITestOutputHelper output, ITestContextAcc
) )
.RespondWith(Response.Create() .RespondWith(Response.Create()
.WithWebSocket(ws => ws .WithWebSocket(ws => ws
.WithCloseTimeout(TimeSpan.FromSeconds(10)) .WithCloseTimeout(TimeSpan.FromSeconds(5))
.WithBroadcast()
.WithMessageHandler(async (message, context) => .WithMessageHandler(async (message, context) =>
{ {
if (message.MessageType == WebSocketMessageType.Text && message.Text == "register") if (message.MessageType == WebSocketMessageType.Text && message.Text == "register")
@@ -962,8 +960,7 @@ public class WebSocketIntegrationTests(ITestOutputHelper output, ITestContextAcc
) )
.RespondWith(Response.Create() .RespondWith(Response.Create()
.WithWebSocket(ws => ws .WithWebSocket(ws => ws
.WithBroadcast() .WithCloseTimeout(TimeSpan.FromSeconds(5))
.WithCloseTimeout(TimeSpan.FromSeconds(10))
.WithMessageHandler(async (message, context) => .WithMessageHandler(async (message, context) =>
{ {
if (message.MessageType == WebSocketMessageType.Text) if (message.MessageType == WebSocketMessageType.Text)
@@ -1020,8 +1017,7 @@ public class WebSocketIntegrationTests(ITestOutputHelper output, ITestContextAcc
) )
.RespondWith(Response.Create() .RespondWith(Response.Create()
.WithWebSocket(ws => ws .WithWebSocket(ws => ws
.WithBroadcast() .WithCloseTimeout(TimeSpan.FromSeconds(5))
.WithCloseTimeout(TimeSpan.FromSeconds(10))
.WithMessageHandler(async (message, context) => .WithMessageHandler(async (message, context) =>
{ {
if (message.MessageType == WebSocketMessageType.Text) if (message.MessageType == WebSocketMessageType.Text)
@@ -1087,8 +1083,7 @@ public class WebSocketIntegrationTests(ITestOutputHelper output, ITestContextAcc
) )
.RespondWith(Response.Create() .RespondWith(Response.Create()
.WithWebSocket(ws => ws .WithWebSocket(ws => ws
.WithCloseTimeout(TimeSpan.FromSeconds(10)) .WithCloseTimeout(TimeSpan.FromSeconds(5))
.WithBroadcast()
.WithMessageHandler(async (message, context) => .WithMessageHandler(async (message, context) =>
{ {
if (message.MessageType == WebSocketMessageType.Text) if (message.MessageType == WebSocketMessageType.Text)
@@ -1135,8 +1130,7 @@ public class WebSocketIntegrationTests(ITestOutputHelper output, ITestContextAcc
) )
.RespondWith(Response.Create() .RespondWith(Response.Create()
.WithWebSocket(ws => ws .WithWebSocket(ws => ws
.WithCloseTimeout(TimeSpan.FromSeconds(10)) .WithCloseTimeout(TimeSpan.FromSeconds(5))
.WithBroadcast()
.WithMessageHandler(async (message, context) => .WithMessageHandler(async (message, context) =>
{ {
if (message.MessageType == WebSocketMessageType.Text) if (message.MessageType == WebSocketMessageType.Text)
@@ -1188,8 +1182,7 @@ public class WebSocketIntegrationTests(ITestOutputHelper output, ITestContextAcc
) )
.RespondWith(Response.Create() .RespondWith(Response.Create()
.WithWebSocket(ws => ws .WithWebSocket(ws => ws
.WithCloseTimeout(TimeSpan.FromSeconds(10)) .WithCloseTimeout(TimeSpan.FromSeconds(5))
.WithBroadcast()
.WithMessageHandler(async (message, context) => .WithMessageHandler(async (message, context) =>
{ {
if (message.MessageType == WebSocketMessageType.Text) if (message.MessageType == WebSocketMessageType.Text)

View File

@@ -0,0 +1,571 @@
// Copyright © WireMock.Net
using System.Net.WebSockets;
using AwesomeAssertions;
using WireMock.Net.Xunit;
using WireMock.RequestBuilders;
using WireMock.ResponseBuilders;
using WireMock.Server;
using WireMock.Settings;
namespace WireMock.Net.Tests.WebSockets;
public class WireMockServerWebSocketIntegrationTests(ITestOutputHelper output, ITestContextAccessor testContext)
{
private readonly CancellationToken _ct = testContext.Current.CancellationToken;
[Fact]
public async Task GetWebSocketConnections_Should_Return_All_Active_Connections()
{
// Arrange
using var server = WireMockServer.Start(new WireMockServerSettings
{
Logger = new TestOutputHelperWireMockLogger(output),
Urls = ["ws://localhost:0"]
});
server
.Given(Request.Create()
.WithPath("/ws/test")
.WithWebSocketUpgrade()
)
.RespondWith(Response.Create()
.WithWebSocket(ws => ws
.WithCloseTimeout(TimeSpan.FromSeconds(5))
.WithEcho()
)
);
using var client1 = new ClientWebSocket();
using var client2 = new ClientWebSocket();
using var client3 = new ClientWebSocket();
var uri = new Uri($"{server.Url}/ws/test");
// Act
await client1.ConnectAsync(uri, _ct);
await client2.ConnectAsync(uri, _ct);
await client3.ConnectAsync(uri, _ct);
// Assert
var connections = server.GetWebSocketConnections();
connections.Should().HaveCount(3);
connections.Should().AllSatisfy(c => c.Should().NotBeNull());
await client1.CloseAsync(WebSocketCloseStatus.NormalClosure, "Test complete", _ct);
await client2.CloseAsync(WebSocketCloseStatus.NormalClosure, "Test complete", _ct);
await client3.CloseAsync(WebSocketCloseStatus.NormalClosure, "Test complete", _ct);
await Task.Delay(300, _ct);
}
[Fact]
public async Task GetWebSocketConnections_Should_Return_Empty_When_No_Connections()
{
// Arrange
using var server = WireMockServer.Start(new WireMockServerSettings
{
Logger = new TestOutputHelperWireMockLogger(output),
Urls = ["ws://localhost:0"]
});
server
.Given(Request.Create()
.WithPath("/ws/test")
.WithWebSocketUpgrade()
)
.RespondWith(Response.Create()
.WithWebSocket(ws => ws
.WithCloseTimeout(TimeSpan.FromSeconds(5))
.WithEcho()
)
);
// Act
var connections = server.GetWebSocketConnections();
// Assert
connections.Should().BeEmpty();
}
[Fact]
public async Task GetWebSocketConnections_Should_Return_Connections_For_Specific_Mapping()
{
// Arrange
using var server = WireMockServer.Start(new WireMockServerSettings
{
Logger = new TestOutputHelperWireMockLogger(output),
Urls = ["ws://localhost:0"]
});
var mapping1Guid = Guid.NewGuid();
var mapping2Guid = Guid.NewGuid();
server
.Given(Request.Create()
.WithPath("/ws/echo1")
.WithWebSocketUpgrade()
)
.WithGuid(mapping1Guid)
.RespondWith(Response.Create()
.WithWebSocket(ws => ws
.WithCloseTimeout(TimeSpan.FromSeconds(5))
.WithEcho()
)
);
server
.Given(Request.Create()
.WithPath("/ws/echo2")
.WithWebSocketUpgrade()
)
.WithGuid(mapping2Guid)
.RespondWith(Response.Create()
.WithWebSocket(ws => ws
.WithCloseTimeout(TimeSpan.FromSeconds(5))
.WithEcho()
)
);
using var client1 = new ClientWebSocket();
using var client2 = new ClientWebSocket();
using var client3 = new ClientWebSocket();
var uri1 = new Uri($"{server.Url}/ws/echo1");
var uri2 = new Uri($"{server.Url}/ws/echo2");
// Act
await client1.ConnectAsync(uri1, _ct);
await client2.ConnectAsync(uri1, _ct);
await client3.ConnectAsync(uri2, _ct);
// Assert
var allConnections = server.GetWebSocketConnections();
allConnections.Should().HaveCount(3);
var mapping1Connections = server.GetWebSocketConnections(mapping1Guid);
mapping1Connections.Should().HaveCount(2);
var mapping2Connections = server.GetWebSocketConnections(mapping2Guid);
mapping2Connections.Should().HaveCount(1);
await client1.CloseAsync(WebSocketCloseStatus.NormalClosure, "Test complete", _ct);
await client2.CloseAsync(WebSocketCloseStatus.NormalClosure, "Test complete", _ct);
await client3.CloseAsync(WebSocketCloseStatus.NormalClosure, "Test complete", _ct);
await Task.Delay(300, _ct);
}
[Fact]
public async Task AbortWebSocketConnectionAsync_Should_Close_Specific_Connection()
{
// Arrange
var server = WireMockServer.Start(new WireMockServerSettings
{
Logger = new TestOutputHelperWireMockLogger(output),
Urls = ["ws://localhost:0"]
});
server
.Given(Request.Create()
.WithPath("/ws/test")
.WithWebSocketUpgrade()
)
.RespondWith(Response.Create()
.WithWebSocket(ws => ws
.WithCloseTimeout(TimeSpan.FromSeconds(30))
.WithEcho()
)
);
using var client1 = new ClientWebSocket();
using var client2 = new ClientWebSocket();
var uri = new Uri($"{server.Url}/ws/test");
await client1.ConnectAsync(uri, _ct);
await client2.ConnectAsync(uri, _ct);
var connections = server.GetWebSocketConnections();
connections.Should().HaveCount(2);
var connectionIdToAbort = connections.First().ConnectionId;
// Act
await server.AbortWebSocketConnectionAsync(connectionIdToAbort, "Abort by test", _ct);
// Assert
var remainingConnections = server.GetWebSocketConnections();
remainingConnections.Should().HaveCount(1);
var remainingConnection = remainingConnections.First();
remainingConnection.ConnectionId.Should().NotBe(connectionIdToAbort);
await Task.Delay(200, _ct);
}
[Fact]
public async Task BroadcastToWebSocketsAsync_Should_Broadcast_Text_To_Specific_Mapping()
{
// Arrange
using var server = WireMockServer.Start(new WireMockServerSettings
{
Logger = new TestOutputHelperWireMockLogger(output),
Urls = ["ws://localhost:0"]
});
var broadcastMessage = "Server broadcast message";
var mappingGuid = Guid.NewGuid();
server
.Given(Request.Create()
.WithPath("/ws/broadcast")
.WithWebSocketUpgrade()
)
.WithGuid(mappingGuid)
.RespondWith(Response.Create()
.WithWebSocket(ws => ws
.WithCloseTimeout(TimeSpan.FromSeconds(5))
.WithMessageHandler(async (message, context) =>
{
if (message.MessageType == WebSocketMessageType.Text)
{
var text = message.Text ?? string.Empty;
if (text.StartsWith("ready"))
{
await context.SendAsync("ready!");
}
}
})
)
);
using var client1 = new ClientWebSocket();
using var client2 = new ClientWebSocket();
var uri = new Uri($"{server.Url}/ws/broadcast");
await client1.ConnectAsync(uri, _ct);
await client2.ConnectAsync(uri, _ct);
// Signal ready
await client1.SendAsync("ready", cancellationToken: _ct);
await client2.SendAsync("ready", cancellationToken: _ct);
var text1 = await client1.ReceiveAsTextAsync(cancellationToken: _ct);
var text2 = await client2.ReceiveAsTextAsync(cancellationToken: _ct);
text1.Should().Be("ready!");
text2.Should().Be("ready!");
// Act
await server.BroadcastToWebSocketsAsync(mappingGuid, broadcastMessage, cancellationToken: _ct);
// Assert
var received1 = await client1.ReceiveAsTextAsync(cancellationToken: _ct);
var received2 = await client2.ReceiveAsTextAsync(cancellationToken: _ct);
received1.Should().Be(broadcastMessage);
received2.Should().Be(broadcastMessage);
await client1.CloseAsync(WebSocketCloseStatus.NormalClosure, "Test complete", _ct);
await client2.CloseAsync(WebSocketCloseStatus.NormalClosure, "Test complete", _ct);
await Task.Delay(200, _ct);
}
[Fact]
public async Task BroadcastToWebSocketsAsync_Should_Broadcast_Binary_To_Specific_Mapping()
{
// Arrange
using var server = WireMockServer.Start(new WireMockServerSettings
{
Logger = new TestOutputHelperWireMockLogger(output),
Urls = ["ws://localhost:0"]
});
var broadcastData = new byte[] { 0x01, 0x02, 0x03, 0x04 };
var mappingGuid = Guid.NewGuid();
server
.Given(Request.Create()
.WithPath("/ws/broadcast-binary")
.WithWebSocketUpgrade()
)
.WithGuid(mappingGuid)
.RespondWith(Response.Create()
.WithWebSocket(ws => ws
.WithCloseTimeout(TimeSpan.FromSeconds(5))
.WithMessageHandler(async (message, context) =>
{
if (message.MessageType == WebSocketMessageType.Text)
{
var text = message.Text ?? string.Empty;
if (text.StartsWith("ready"))
{
await context.SendAsync("ready!");
}
}
})
)
);
using var client1 = new ClientWebSocket();
using var client2 = new ClientWebSocket();
var uri = new Uri($"{server.Url}/ws/broadcast-binary");
await client1.ConnectAsync(uri, _ct);
await client2.ConnectAsync(uri, _ct);
// Signal ready
await client1.SendAsync("ready", cancellationToken: _ct);
await client2.SendAsync("ready", cancellationToken: _ct);
var text1 = await client1.ReceiveAsTextAsync(cancellationToken: _ct);
var text2 = await client2.ReceiveAsTextAsync(cancellationToken: _ct);
text1.Should().Be("ready!");
text2.Should().Be("ready!");
// Act
await server.BroadcastToWebSocketsAsync(mappingGuid, broadcastData, cancellationToken: _ct);
// Assert
var received1 = await client1.ReceiveAsBytesAsync(cancellationToken: _ct);
var received2 = await client2.ReceiveAsBytesAsync(cancellationToken: _ct);
received1.Should().BeEquivalentTo(broadcastData);
received2.Should().BeEquivalentTo(broadcastData);
await client1.CloseAsync(WebSocketCloseStatus.NormalClosure, "Test complete", _ct);
await client2.CloseAsync(WebSocketCloseStatus.NormalClosure, "Test complete", _ct);
await Task.Delay(200, _ct);
}
[Fact]
public async Task BroadcastToAllWebSocketsAsync_Should_Broadcast_Text_To_All_Mappings()
{
// Arrange
using var server = WireMockServer.Start(new WireMockServerSettings
{
Logger = new TestOutputHelperWireMockLogger(output),
Urls = ["ws://localhost:0"]
});
var broadcastMessage = "Broadcast to all mappings";
server
.Given(Request.Create()
.WithPath("/ws/mapping1")
.WithWebSocketUpgrade()
)
.RespondWith(Response.Create()
.WithWebSocket(ws => ws
.WithCloseTimeout(TimeSpan.FromSeconds(5))
.WithMessageHandler(async (message, context) =>
{
if (message.MessageType == WebSocketMessageType.Text)
{
var text = message.Text ?? string.Empty;
if (text.StartsWith("ready"))
{
await context.SendAsync("ready!");
}
}
})
)
);
server
.Given(Request.Create()
.WithPath("/ws/mapping2")
.WithWebSocketUpgrade()
)
.RespondWith(Response.Create()
.WithWebSocket(ws => ws
.WithMessageHandler(async (message, context) =>
{
if (message.MessageType == WebSocketMessageType.Text)
{
var text = message.Text ?? string.Empty;
if (text.StartsWith("ready"))
{
await context.SendAsync("ready!");
}
}
})
)
);
using var client1 = new ClientWebSocket();
using var client2 = new ClientWebSocket();
var uri1 = new Uri($"{server.Url}/ws/mapping1");
var uri2 = new Uri($"{server.Url}/ws/mapping2");
await client1.ConnectAsync(uri1, _ct);
await client2.ConnectAsync(uri2, _ct);
// Signal ready
await client1.SendAsync("ready", cancellationToken: _ct);
await client2.SendAsync("ready", cancellationToken: _ct);
var text1 = await client1.ReceiveAsTextAsync(cancellationToken: _ct);
var text2 = await client2.ReceiveAsTextAsync(cancellationToken: _ct);
text1.Should().Be("ready!");
text2.Should().Be("ready!");
// Act
await server.BroadcastToAllWebSocketsAsync(broadcastMessage, cancellationToken: _ct);
// Assert - both clients from different mappings should receive the broadcast
var received1 = await client1.ReceiveAsTextAsync(cancellationToken: _ct);
var received2 = await client2.ReceiveAsTextAsync(cancellationToken: _ct);
received1.Should().Be(broadcastMessage);
received2.Should().Be(broadcastMessage);
await client1.CloseAsync(WebSocketCloseStatus.NormalClosure, "Test complete", _ct);
await client2.CloseAsync(WebSocketCloseStatus.NormalClosure, "Test complete", _ct);
await Task.Delay(200, _ct);
}
[Fact]
public async Task BroadcastToAllWebSocketsAsync_Should_Broadcast_Binary_To_All_Mappings()
{
// Arrange
using var server = WireMockServer.Start(new WireMockServerSettings
{
Logger = new TestOutputHelperWireMockLogger(output),
Urls = ["ws://localhost:0"]
});
var broadcastData = new byte[] { 0xAA, 0xBB, 0xCC };
server
.Given(Request.Create()
.WithPath("/ws/mapping1")
.WithWebSocketUpgrade()
)
.RespondWith(Response.Create()
.WithWebSocket(ws => ws
.WithCloseTimeout(TimeSpan.FromSeconds(5))
.WithMessageHandler(async (message, context) =>
{
if (message.MessageType == WebSocketMessageType.Text)
{
var text = message.Text ?? string.Empty;
if (text.StartsWith("ready"))
{
await context.SendAsync("ready!");
}
}
})
)
);
server
.Given(Request.Create()
.WithPath("/ws/mapping2")
.WithWebSocketUpgrade()
)
.RespondWith(Response.Create()
.WithWebSocket(ws => ws
.WithMessageHandler(async (message, context) =>
{
if (message.MessageType == WebSocketMessageType.Text)
{
var text = message.Text ?? string.Empty;
if (text.StartsWith("ready"))
{
await context.SendAsync("ready!");
}
}
})
)
);
using var client1 = new ClientWebSocket();
using var client2 = new ClientWebSocket();
var uri1 = new Uri($"{server.Url}/ws/mapping1");
var uri2 = new Uri($"{server.Url}/ws/mapping2");
await client1.ConnectAsync(uri1, _ct);
await client2.ConnectAsync(uri2, _ct);
// Signal ready
await client1.SendAsync("ready", cancellationToken: _ct);
await client2.SendAsync("ready", cancellationToken: _ct);
var text1 = await client1.ReceiveAsTextAsync(cancellationToken: _ct);
var text2 = await client2.ReceiveAsTextAsync(cancellationToken: _ct);
text1.Should().Be("ready!");
text2.Should().Be("ready!");
// Act
await server.BroadcastToAllWebSocketsAsync(broadcastData, cancellationToken: _ct);
// Assert
var received1 = await client1.ReceiveAsBytesAsync(cancellationToken: _ct);
var received2 = await client2.ReceiveAsBytesAsync(cancellationToken: _ct);
received1.Should().BeEquivalentTo(broadcastData);
received2.Should().BeEquivalentTo(broadcastData);
await client1.CloseAsync(WebSocketCloseStatus.NormalClosure, "Test complete", _ct);
await client2.CloseAsync(WebSocketCloseStatus.NormalClosure, "Test complete", _ct);
await Task.Delay(200, _ct);
}
[Fact]
public async Task GetWebSocketConnections_Should_Update_After_Client_Disconnect()
{
// Arrange
using var server = WireMockServer.Start(new WireMockServerSettings
{
Logger = new TestOutputHelperWireMockLogger(output),
Urls = ["ws://localhost:0"]
});
server
.Given(Request.Create()
.WithPath("/ws/test")
.WithWebSocketUpgrade()
)
.RespondWith(Response.Create()
.WithWebSocket(ws => ws
.WithCloseTimeout(TimeSpan.FromSeconds(5))
.WithEcho()
)
);
using var client1 = new ClientWebSocket();
using var client2 = new ClientWebSocket();
var uri = new Uri($"{server.Url}/ws/test");
await client1.ConnectAsync(uri, _ct);
await client2.ConnectAsync(uri, _ct);
var initialConnections = server.GetWebSocketConnections();
initialConnections.Should().HaveCount(2);
// Act
await client1.CloseAsync(WebSocketCloseStatus.NormalClosure, "Disconnect", _ct);
await Task.Delay(100, _ct);
// Assert
var remainingConnections = server.GetWebSocketConnections();
remainingConnections.Should().HaveCount(1);
await client2.CloseAsync(WebSocketCloseStatus.NormalClosure, "Test complete", _ct);
await Task.Delay(200, _ct);
}
}