From ba3d1d758c90118f2608ea37a473341dacd9189e Mon Sep 17 00:00:00 2001 From: Stef Heyenrath Date: Sun, 22 Feb 2026 21:48:33 +0100 Subject: [PATCH] ok --- .../WebSocketResponseProvider.cs | 30 +- .../Server/WireMockServer.WebSocket.cs | 30 +- .../WebSockets/WebSocketBuilder.cs | 18 - .../WebSockets/WebSocketConnectionRegistry.cs | 5 +- .../WebSockets/WireMockWebSocketContext.cs | 29 +- .../WebSockets/IWebSocketBuilder.cs | 6 - .../WebSockets/IWebSocketContext.cs | 5 + .../WebSockets/WebSocketIntegrationTests.cs | 33 +- ...WireMockServerWebSocketIntegrationTests.cs | 571 ++++++++++++++++++ 9 files changed, 629 insertions(+), 98 deletions(-) create mode 100644 test/WireMock.Net.Tests/WebSockets/WireMockServerWebSocketIntegrationTests.cs diff --git a/src/WireMock.Net.Minimal/ResponseProviders/WebSocketResponseProvider.cs b/src/WireMock.Net.Minimal/ResponseProviders/WebSocketResponseProvider.cs index 286ada00..a73a0cf9 100644 --- a/src/WireMock.Net.Minimal/ResponseProviders/WebSocketResponseProvider.cs +++ b/src/WireMock.Net.Minimal/ResponseProviders/WebSocketResponseProvider.cs @@ -48,7 +48,6 @@ internal class WebSocketResponseProvider(WebSocketBuilder builder, IGuidUtils gu { SubProtocol = builder.AcceptProtocol, KeepAliveInterval = builder.KeepAliveIntervalSeconds ?? TimeSpan.FromSeconds(WebSocketConstants.DefaultKeepAliveIntervalSeconds) - }; var webSocket = await context.WebSockets.AcceptWebSocketAsync(acceptContext).ConfigureAwait(false); #else @@ -56,9 +55,7 @@ internal class WebSocketResponseProvider(WebSocketBuilder builder, IGuidUtils gu #endif // Get or create registry from options - var registry = builder.IsBroadcast - ? options.WebSocketRegistries.GetOrAdd(mapping.Guid, _ => new WebSocketConnectionRegistry()) - : null; + var registry = options.WebSocketRegistries.GetOrAdd(mapping.Guid, _ => new WebSocketConnectionRegistry()); // Create WebSocket context var wsContext = new WireMockWebSocketContext( @@ -73,8 +70,8 @@ internal class WebSocketResponseProvider(WebSocketBuilder builder, IGuidUtils gu guidUtils ); - // Add to registry if broadcast is enabled - registry?.AddConnection(wsContext); + // Add to registry + registry.AddConnection(wsContext); try { @@ -100,7 +97,7 @@ internal class WebSocketResponseProvider(WebSocketBuilder builder, IGuidUtils gu finally { // Remove from registry - registry?.RemoveConnection(wsContext.ConnectionId); + registry.RemoveConnection(wsContext.ConnectionId); } // Return special marker to indicate WebSocket was handled @@ -213,9 +210,7 @@ internal class WebSocketResponseProvider(WebSocketBuilder builder, IGuidUtils gu } } - private static async Task HandleCustomAsync( - WireMockWebSocketContext context, - Func handler) + private static async Task HandleCustomAsync(WireMockWebSocketContext context, Func handler) { var bufferSize = context.Builder.MaxMessageSize ?? WebSocketConstants.DefaultReceiveBufferSize; using var buffer = ArrayPool.Shared.Lease(bufferSize); @@ -236,10 +231,7 @@ internal class WebSocketResponseProvider(WebSocketBuilder builder, IGuidUtils gu try { - var result = await context.WebSocket.ReceiveAsync( - new ArraySegment(buffer), - cts.Token - ).ConfigureAwait(false); + var result = await context.WebSocket.ReceiveAsync(new ArraySegment(buffer), cts.Token).ConfigureAwait(false); if (result.MessageType == WebSocketMessageType.Close) { @@ -257,10 +249,7 @@ internal class WebSocketResponseProvider(WebSocketBuilder builder, IGuidUtils gu context.LogWebSocketMessage(WebSocketMessageDirection.Receive, result.MessageType, null, receiveActivity); - await context.CloseAsync( - WebSocketCloseStatus.NormalClosure, - "Closed by client" - ).ConfigureAwait(false); + await context.CloseAsync(WebSocketCloseStatus.NormalClosure, "Closed by client").ConfigureAwait(false); break; } @@ -331,10 +320,7 @@ internal class WebSocketResponseProvider(WebSocketBuilder builder, IGuidUtils gu } } - private static async Task ForwardMessagesAsync( - WireMockWebSocketContext context, - ClientWebSocket clientWebSocket, - WebSocketMessageDirection direction) + private static async Task ForwardMessagesAsync(WireMockWebSocketContext context, ClientWebSocket clientWebSocket, WebSocketMessageDirection direction) { using var buffer = ArrayPool.Shared.Lease(WebSocketConstants.ProxyForwardBufferSize); diff --git a/src/WireMock.Net.Minimal/Server/WireMockServer.WebSocket.cs b/src/WireMock.Net.Minimal/Server/WireMockServer.WebSocket.cs index cf079fee..fd202079 100644 --- a/src/WireMock.Net.Minimal/Server/WireMockServer.WebSocket.cs +++ b/src/WireMock.Net.Minimal/Server/WireMockServer.WebSocket.cs @@ -1,6 +1,5 @@ // Copyright © WireMock.Net -using System.Net.WebSockets; using JetBrains.Annotations; using WireMock.WebSockets; @@ -32,17 +31,16 @@ public partial class WireMockServer /// Close a specific WebSocket connection /// [PublicAPI] - public async Task CloseWebSocketConnectionAsync( - Guid connectionId, - WebSocketCloseStatus closeStatus = WebSocketCloseStatus.NormalClosure, - string statusDescription = "Closed by server", - CancellationToken cancellationToken = default) + public async Task AbortWebSocketConnectionAsync(Guid connectionId, string statusDescription = "Closed by server", CancellationToken cancellationToken = default) { foreach (var registry in _options.WebSocketRegistries.Values) { - if (registry.TryGetConnection(connectionId, out var connection) && !cancellationToken.IsCancellationRequested) + if (registry.TryGetConnection(connectionId, out var connection)) { - await connection.CloseAsync(closeStatus, statusDescription, cancellationToken); + connection.Abort(statusDescription); + registry.RemoveConnection(connectionId); + + await Task.Delay(100, cancellationToken); // Give the connection some time to close gracefully return; } } @@ -52,11 +50,11 @@ public partial class WireMockServer /// Broadcast a text message to all WebSocket connections in a specific mapping /// [PublicAPI] - public async Task BroadcastToWebSocketsAsync(Guid mappingGuid, string text, Guid? excludeConnectionId = null, CancellationToken cancellationToken = default) + public async Task BroadcastToWebSocketsAsync(Guid mappingGuid, string text, CancellationToken cancellationToken = default) { if (_options.WebSocketRegistries.TryGetValue(mappingGuid, out var registry)) { - await registry.BroadcastAsync(text, excludeConnectionId, cancellationToken); + await registry.BroadcastAsync(text, null, cancellationToken); } } @@ -64,11 +62,11 @@ public partial class WireMockServer /// Broadcast a text message to all WebSocket connections /// [PublicAPI] - public async Task BroadcastToAllWebSocketsAsync(string text, Guid? excludeConnectionId = null, CancellationToken cancellationToken = default) + public async Task BroadcastToAllWebSocketsAsync(string text, CancellationToken cancellationToken = default) { foreach (var registry in _options.WebSocketRegistries.Values) { - await registry.BroadcastAsync(text, excludeConnectionId, cancellationToken); + await registry.BroadcastAsync(text, null, cancellationToken); } } @@ -76,11 +74,11 @@ public partial class WireMockServer /// Broadcast a binary message to all WebSocket connections in a specific mapping /// [PublicAPI] - public async Task BroadcastToWebSocketsAsync(Guid mappingGuid, byte[] bytes, Guid? excludeConnectionId = null, CancellationToken cancellationToken = default) + public async Task BroadcastToWebSocketsAsync(Guid mappingGuid, byte[] bytes, CancellationToken cancellationToken = default) { if (_options.WebSocketRegistries.TryGetValue(mappingGuid, out var registry)) { - await registry.BroadcastAsync(bytes, excludeConnectionId, cancellationToken); + await registry.BroadcastAsync(bytes, null, cancellationToken); } } @@ -88,11 +86,11 @@ public partial class WireMockServer /// Broadcast a binary message to all WebSocket connections /// [PublicAPI] - public async Task BroadcastToAllWebSocketsAsync(byte[] bytes, Guid? excludeConnectionId = null, CancellationToken cancellationToken = default) + public async Task BroadcastToAllWebSocketsAsync(byte[] bytes, CancellationToken cancellationToken = default) { foreach (var registry in _options.WebSocketRegistries.Values) { - await registry.BroadcastAsync(bytes, excludeConnectionId, cancellationToken); + await registry.BroadcastAsync(bytes, null, cancellationToken); } } } \ No newline at end of file diff --git a/src/WireMock.Net.Minimal/WebSockets/WebSocketBuilder.cs b/src/WireMock.Net.Minimal/WebSockets/WebSocketBuilder.cs index c51794d6..341cbb5a 100644 --- a/src/WireMock.Net.Minimal/WebSockets/WebSocketBuilder.cs +++ b/src/WireMock.Net.Minimal/WebSockets/WebSocketBuilder.cs @@ -13,34 +13,22 @@ internal class WebSocketBuilder(Response response) : IWebSocketBuilder { private readonly List<(IMatcher matcher, List messages)> _conditionalMessages = []; - /// public string? AcceptProtocol { get; private set; } - /// public bool IsEcho { get; private set; } - /// - public bool IsBroadcast { get; private set; } - - /// public Func? MessageHandler { get; private set; } - /// public ProxyAndRecordSettings? ProxySettings { get; private set; } - /// public TimeSpan? CloseTimeout { get; private set; } - /// public int? MaxMessageSize { get; private set; } - /// public int? ReceiveBufferSize { get; private set; } - /// public TimeSpan? KeepAliveIntervalSeconds { get; private set; } - /// public IWebSocketBuilder WithAcceptProtocol(string protocol) { AcceptProtocol = Guard.NotNull(protocol); @@ -117,12 +105,6 @@ internal class WebSocketBuilder(Response response) : IWebSocketBuilder return this; } - public IWebSocketBuilder WithBroadcast() - { - IsBroadcast = true; - return this; - } - public IWebSocketBuilder WithProxy(ProxyAndRecordSettings settings) { ProxySettings = Guard.NotNull(settings); diff --git a/src/WireMock.Net.Minimal/WebSockets/WebSocketConnectionRegistry.cs b/src/WireMock.Net.Minimal/WebSockets/WebSocketConnectionRegistry.cs index 0eac1867..ca7ba781 100644 --- a/src/WireMock.Net.Minimal/WebSockets/WebSocketConnectionRegistry.cs +++ b/src/WireMock.Net.Minimal/WebSockets/WebSocketConnectionRegistry.cs @@ -3,7 +3,6 @@ using System.Collections.Concurrent; using System.Diagnostics.CodeAnalysis; using System.Net.WebSockets; -using static System.Net.Mime.MediaTypeNames; namespace WireMock.WebSockets; @@ -27,7 +26,7 @@ internal class WebSocketConnectionRegistry /// public void RemoveConnection(Guid connectionId) { - _connections.TryRemove(connectionId, out _); + _ = _connections.TryRemove(connectionId, out _); } /// @@ -67,6 +66,6 @@ internal class WebSocketConnectionRegistry private IEnumerable Filter(Guid? excludeConnectionId) { return _connections.Values - .Where(c =>c.WebSocket.State == WebSocketState.Open && (!excludeConnectionId.HasValue || c.ConnectionId != excludeConnectionId)); + .Where(c => c.WebSocket.State == WebSocketState.Open && (!excludeConnectionId.HasValue || c.ConnectionId != excludeConnectionId)); } } \ No newline at end of file diff --git a/src/WireMock.Net.Minimal/WebSockets/WireMockWebSocketContext.cs b/src/WireMock.Net.Minimal/WebSockets/WireMockWebSocketContext.cs index a0260ff4..7ff6d144 100644 --- a/src/WireMock.Net.Minimal/WebSockets/WireMockWebSocketContext.cs +++ b/src/WireMock.Net.Minimal/WebSockets/WireMockWebSocketContext.cs @@ -1,5 +1,6 @@ // Copyright © WireMock.Net +using System.Buffers; using System.Diagnostics; using System.Net.WebSockets; using System.Text; @@ -33,7 +34,7 @@ public class WireMockWebSocketContext : IWebSocketContext /// public IMapping Mapping { get; } - internal WebSocketConnectionRegistry? Registry { get; } + internal WebSocketConnectionRegistry Registry { get; } internal WebSocketBuilder Builder { get; } @@ -49,7 +50,7 @@ public class WireMockWebSocketContext : IWebSocketContext WebSocket webSocket, IRequestMessage requestMessage, IMapping mapping, - WebSocketConnectionRegistry? registry, + WebSocketConnectionRegistry registry, WebSocketBuilder builder, IWireMockMiddlewareOptions options, IWireMockMiddlewareLogger logger, @@ -96,29 +97,31 @@ public class WireMockWebSocketContext : IWebSocketContext /// public async Task CloseAsync(WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken = default) { - await WebSocket.CloseAsync(closeStatus, statusDescription, cancellationToken); + await WebSocket.CloseAsync(closeStatus, statusDescription, cancellationToken).ConfigureAwait(false); LogWebSocketMessage(WebSocketMessageDirection.Send, WebSocketMessageType.Close, $"CloseStatus: {closeStatus}, Description: {statusDescription}", null); } + /// + public void Abort(string? statusDescription = null) + { + WebSocket.Abort(); + + LogWebSocketMessage(WebSocketMessageDirection.Send, WebSocketMessageType.Close, $"CloseStatus: Abort, Description: {statusDescription}", null); + } + /// public async Task BroadcastAsync(string text, bool excludeSender = false, CancellationToken cancellationToken = default) { - if (Registry != null) - { - Guid? excludeConnectionId = excludeSender ? ConnectionId : null; - await Registry.BroadcastAsync(text, excludeConnectionId, cancellationToken); - } + Guid? excludeConnectionId = excludeSender ? ConnectionId : null; + await Registry.BroadcastAsync(text, excludeConnectionId, cancellationToken); } /// public async Task BroadcastAsync(byte[] bytes, bool excludeSender = false, CancellationToken cancellationToken = default) { - if (Registry != null) - { - Guid? excludeConnectionId = excludeSender ? ConnectionId : null; - await Registry.BroadcastAsync(bytes, excludeConnectionId, cancellationToken); - } + Guid? excludeConnectionId = excludeSender ? ConnectionId : null; + await Registry.BroadcastAsync(bytes, excludeConnectionId, cancellationToken); } internal void LogWebSocketMessage( diff --git a/src/WireMock.Net.Shared/WebSockets/IWebSocketBuilder.cs b/src/WireMock.Net.Shared/WebSockets/IWebSocketBuilder.cs index 3683c686..c830fb48 100644 --- a/src/WireMock.Net.Shared/WebSockets/IWebSocketBuilder.cs +++ b/src/WireMock.Net.Shared/WebSockets/IWebSocketBuilder.cs @@ -64,12 +64,6 @@ public interface IWebSocketBuilder [PublicAPI] IWebSocketBuilder WithMessageHandler(Func handler); - /// - /// Enable broadcast mode for this mapping - /// - [PublicAPI] - IWebSocketBuilder WithBroadcast(); - /// /// Proxy to another WebSocket server /// diff --git a/src/WireMock.Net.Shared/WebSockets/IWebSocketContext.cs b/src/WireMock.Net.Shared/WebSockets/IWebSocketContext.cs index 3df9d1c3..f4aa913b 100644 --- a/src/WireMock.Net.Shared/WebSockets/IWebSocketContext.cs +++ b/src/WireMock.Net.Shared/WebSockets/IWebSocketContext.cs @@ -50,6 +50,11 @@ public interface IWebSocketContext /// Task CloseAsync(WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken = default); + /// + /// Abort the WebSocket connection to immediately close the connection without waiting for the close handshake + /// + void Abort(string? statusDescription = null); + /// /// Broadcast text message to all connections in this mapping /// diff --git a/test/WireMock.Net.Tests/WebSockets/WebSocketIntegrationTests.cs b/test/WireMock.Net.Tests/WebSockets/WebSocketIntegrationTests.cs index 2a588179..1785d30e 100644 --- a/test/WireMock.Net.Tests/WebSockets/WebSocketIntegrationTests.cs +++ b/test/WireMock.Net.Tests/WebSockets/WebSocketIntegrationTests.cs @@ -400,29 +400,29 @@ public class WebSocketIntegrationTests(ITestOutputHelper output, ITestContextAcc if (text.StartsWith("/help")) { - await context.SendAsync("Available commands: /help, /time, /echo , /upper , /reverse "); + await context.SendAsync("Available commands: /help, /time, /echo , /upper , /reverse ", _ct); } else if (text.StartsWith("/time")) { - await context.SendAsync($"Server time: {DateTime.UtcNow:yyyy-MM-dd HH:mm:ss} UTC"); + await context.SendAsync($"Server time: {DateTime.UtcNow:yyyy-MM-dd HH:mm:ss} UTC", _ct); } else if (text.StartsWith("/echo ")) { - await context.SendAsync(text.Substring(6)); + await context.SendAsync(text.Substring(6), _ct); } else if (text.StartsWith("/upper ")) { - await context.SendAsync(text.Substring(7).ToUpper()); + await context.SendAsync(text.Substring(7).ToUpper(), _ct); } else if (text.StartsWith("/reverse ")) { var toReverse = text.Substring(9); var reversed = new string(toReverse.Reverse().ToArray()); - await context.SendAsync(reversed); + await context.SendAsync(reversed, _ct); } else if (text.StartsWith("/close")) { - await context.CloseAsync(WebSocketCloseStatus.NormalClosure, "Closing connection"); + await context.CloseAsync(WebSocketCloseStatus.NormalClosure, "Closing connection", _ct); } } }) @@ -800,8 +800,7 @@ public class WebSocketIntegrationTests(ITestOutputHelper output, ITestContextAcc ) .RespondWith(Response.Create() .WithWebSocket(ws => ws - .WithCloseTimeout(TimeSpan.FromSeconds(10)) - .WithBroadcast() + .WithCloseTimeout(TimeSpan.FromSeconds(5)) .WithMessageHandler(async (message, context) => { if (message.MessageType == WebSocketMessageType.Text) @@ -885,8 +884,7 @@ public class WebSocketIntegrationTests(ITestOutputHelper output, ITestContextAcc ) .RespondWith(Response.Create() .WithWebSocket(ws => ws - .WithCloseTimeout(TimeSpan.FromSeconds(10)) - .WithBroadcast() + .WithCloseTimeout(TimeSpan.FromSeconds(5)) .WithMessageHandler(async (message, context) => { if (message.MessageType == WebSocketMessageType.Text && message.Text == "register") @@ -962,8 +960,7 @@ public class WebSocketIntegrationTests(ITestOutputHelper output, ITestContextAcc ) .RespondWith(Response.Create() .WithWebSocket(ws => ws - .WithBroadcast() - .WithCloseTimeout(TimeSpan.FromSeconds(10)) + .WithCloseTimeout(TimeSpan.FromSeconds(5)) .WithMessageHandler(async (message, context) => { if (message.MessageType == WebSocketMessageType.Text) @@ -1020,8 +1017,7 @@ public class WebSocketIntegrationTests(ITestOutputHelper output, ITestContextAcc ) .RespondWith(Response.Create() .WithWebSocket(ws => ws - .WithBroadcast() - .WithCloseTimeout(TimeSpan.FromSeconds(10)) + .WithCloseTimeout(TimeSpan.FromSeconds(5)) .WithMessageHandler(async (message, context) => { if (message.MessageType == WebSocketMessageType.Text) @@ -1087,8 +1083,7 @@ public class WebSocketIntegrationTests(ITestOutputHelper output, ITestContextAcc ) .RespondWith(Response.Create() .WithWebSocket(ws => ws - .WithCloseTimeout(TimeSpan.FromSeconds(10)) - .WithBroadcast() + .WithCloseTimeout(TimeSpan.FromSeconds(5)) .WithMessageHandler(async (message, context) => { if (message.MessageType == WebSocketMessageType.Text) @@ -1135,8 +1130,7 @@ public class WebSocketIntegrationTests(ITestOutputHelper output, ITestContextAcc ) .RespondWith(Response.Create() .WithWebSocket(ws => ws - .WithCloseTimeout(TimeSpan.FromSeconds(10)) - .WithBroadcast() + .WithCloseTimeout(TimeSpan.FromSeconds(5)) .WithMessageHandler(async (message, context) => { if (message.MessageType == WebSocketMessageType.Text) @@ -1188,8 +1182,7 @@ public class WebSocketIntegrationTests(ITestOutputHelper output, ITestContextAcc ) .RespondWith(Response.Create() .WithWebSocket(ws => ws - .WithCloseTimeout(TimeSpan.FromSeconds(10)) - .WithBroadcast() + .WithCloseTimeout(TimeSpan.FromSeconds(5)) .WithMessageHandler(async (message, context) => { if (message.MessageType == WebSocketMessageType.Text) diff --git a/test/WireMock.Net.Tests/WebSockets/WireMockServerWebSocketIntegrationTests.cs b/test/WireMock.Net.Tests/WebSockets/WireMockServerWebSocketIntegrationTests.cs new file mode 100644 index 00000000..232f423b --- /dev/null +++ b/test/WireMock.Net.Tests/WebSockets/WireMockServerWebSocketIntegrationTests.cs @@ -0,0 +1,571 @@ +// Copyright © WireMock.Net + +using System.Net.WebSockets; +using AwesomeAssertions; +using WireMock.Net.Xunit; +using WireMock.RequestBuilders; +using WireMock.ResponseBuilders; +using WireMock.Server; +using WireMock.Settings; + +namespace WireMock.Net.Tests.WebSockets; + +public class WireMockServerWebSocketIntegrationTests(ITestOutputHelper output, ITestContextAccessor testContext) +{ + private readonly CancellationToken _ct = testContext.Current.CancellationToken; + + [Fact] + public async Task GetWebSocketConnections_Should_Return_All_Active_Connections() + { + // Arrange + using var server = WireMockServer.Start(new WireMockServerSettings + { + Logger = new TestOutputHelperWireMockLogger(output), + Urls = ["ws://localhost:0"] + }); + + server + .Given(Request.Create() + .WithPath("/ws/test") + .WithWebSocketUpgrade() + ) + .RespondWith(Response.Create() + .WithWebSocket(ws => ws + .WithCloseTimeout(TimeSpan.FromSeconds(5)) + .WithEcho() + ) + ); + + using var client1 = new ClientWebSocket(); + using var client2 = new ClientWebSocket(); + using var client3 = new ClientWebSocket(); + + var uri = new Uri($"{server.Url}/ws/test"); + + // Act + await client1.ConnectAsync(uri, _ct); + await client2.ConnectAsync(uri, _ct); + await client3.ConnectAsync(uri, _ct); + + // Assert + var connections = server.GetWebSocketConnections(); + connections.Should().HaveCount(3); + connections.Should().AllSatisfy(c => c.Should().NotBeNull()); + + await client1.CloseAsync(WebSocketCloseStatus.NormalClosure, "Test complete", _ct); + await client2.CloseAsync(WebSocketCloseStatus.NormalClosure, "Test complete", _ct); + await client3.CloseAsync(WebSocketCloseStatus.NormalClosure, "Test complete", _ct); + + await Task.Delay(300, _ct); + } + + [Fact] + public async Task GetWebSocketConnections_Should_Return_Empty_When_No_Connections() + { + // Arrange + using var server = WireMockServer.Start(new WireMockServerSettings + { + Logger = new TestOutputHelperWireMockLogger(output), + Urls = ["ws://localhost:0"] + }); + + server + .Given(Request.Create() + .WithPath("/ws/test") + .WithWebSocketUpgrade() + ) + .RespondWith(Response.Create() + .WithWebSocket(ws => ws + .WithCloseTimeout(TimeSpan.FromSeconds(5)) + .WithEcho() + ) + ); + + // Act + var connections = server.GetWebSocketConnections(); + + // Assert + connections.Should().BeEmpty(); + } + + [Fact] + public async Task GetWebSocketConnections_Should_Return_Connections_For_Specific_Mapping() + { + // Arrange + using var server = WireMockServer.Start(new WireMockServerSettings + { + Logger = new TestOutputHelperWireMockLogger(output), + Urls = ["ws://localhost:0"] + }); + + var mapping1Guid = Guid.NewGuid(); + var mapping2Guid = Guid.NewGuid(); + + server + .Given(Request.Create() + .WithPath("/ws/echo1") + .WithWebSocketUpgrade() + ) + .WithGuid(mapping1Guid) + .RespondWith(Response.Create() + .WithWebSocket(ws => ws + .WithCloseTimeout(TimeSpan.FromSeconds(5)) + .WithEcho() + ) + ); + + server + .Given(Request.Create() + .WithPath("/ws/echo2") + .WithWebSocketUpgrade() + ) + .WithGuid(mapping2Guid) + .RespondWith(Response.Create() + .WithWebSocket(ws => ws + .WithCloseTimeout(TimeSpan.FromSeconds(5)) + .WithEcho() + ) + ); + + using var client1 = new ClientWebSocket(); + using var client2 = new ClientWebSocket(); + using var client3 = new ClientWebSocket(); + + var uri1 = new Uri($"{server.Url}/ws/echo1"); + var uri2 = new Uri($"{server.Url}/ws/echo2"); + + // Act + await client1.ConnectAsync(uri1, _ct); + await client2.ConnectAsync(uri1, _ct); + await client3.ConnectAsync(uri2, _ct); + + // Assert + var allConnections = server.GetWebSocketConnections(); + allConnections.Should().HaveCount(3); + + var mapping1Connections = server.GetWebSocketConnections(mapping1Guid); + mapping1Connections.Should().HaveCount(2); + + var mapping2Connections = server.GetWebSocketConnections(mapping2Guid); + mapping2Connections.Should().HaveCount(1); + + await client1.CloseAsync(WebSocketCloseStatus.NormalClosure, "Test complete", _ct); + await client2.CloseAsync(WebSocketCloseStatus.NormalClosure, "Test complete", _ct); + await client3.CloseAsync(WebSocketCloseStatus.NormalClosure, "Test complete", _ct); + + await Task.Delay(300, _ct); + } + + [Fact] + public async Task AbortWebSocketConnectionAsync_Should_Close_Specific_Connection() + { + // Arrange + var server = WireMockServer.Start(new WireMockServerSettings + { + Logger = new TestOutputHelperWireMockLogger(output), + Urls = ["ws://localhost:0"] + }); + + server + .Given(Request.Create() + .WithPath("/ws/test") + .WithWebSocketUpgrade() + ) + .RespondWith(Response.Create() + .WithWebSocket(ws => ws + .WithCloseTimeout(TimeSpan.FromSeconds(30)) + .WithEcho() + ) + ); + + using var client1 = new ClientWebSocket(); + using var client2 = new ClientWebSocket(); + + var uri = new Uri($"{server.Url}/ws/test"); + + await client1.ConnectAsync(uri, _ct); + await client2.ConnectAsync(uri, _ct); + + var connections = server.GetWebSocketConnections(); + connections.Should().HaveCount(2); + + var connectionIdToAbort = connections.First().ConnectionId; + + // Act + await server.AbortWebSocketConnectionAsync(connectionIdToAbort, "Abort by test", _ct); + + // Assert + var remainingConnections = server.GetWebSocketConnections(); + remainingConnections.Should().HaveCount(1); + var remainingConnection = remainingConnections.First(); + remainingConnection.ConnectionId.Should().NotBe(connectionIdToAbort); + + await Task.Delay(200, _ct); + } + + [Fact] + public async Task BroadcastToWebSocketsAsync_Should_Broadcast_Text_To_Specific_Mapping() + { + // Arrange + using var server = WireMockServer.Start(new WireMockServerSettings + { + Logger = new TestOutputHelperWireMockLogger(output), + Urls = ["ws://localhost:0"] + }); + + var broadcastMessage = "Server broadcast message"; + var mappingGuid = Guid.NewGuid(); + + server + .Given(Request.Create() + .WithPath("/ws/broadcast") + .WithWebSocketUpgrade() + ) + .WithGuid(mappingGuid) + .RespondWith(Response.Create() + .WithWebSocket(ws => ws + .WithCloseTimeout(TimeSpan.FromSeconds(5)) + .WithMessageHandler(async (message, context) => + { + if (message.MessageType == WebSocketMessageType.Text) + { + var text = message.Text ?? string.Empty; + if (text.StartsWith("ready")) + { + await context.SendAsync("ready!"); + } + } + }) + ) + ); + + using var client1 = new ClientWebSocket(); + using var client2 = new ClientWebSocket(); + + var uri = new Uri($"{server.Url}/ws/broadcast"); + + await client1.ConnectAsync(uri, _ct); + await client2.ConnectAsync(uri, _ct); + + // Signal ready + await client1.SendAsync("ready", cancellationToken: _ct); + await client2.SendAsync("ready", cancellationToken: _ct); + + var text1 = await client1.ReceiveAsTextAsync(cancellationToken: _ct); + var text2 = await client2.ReceiveAsTextAsync(cancellationToken: _ct); + + text1.Should().Be("ready!"); + text2.Should().Be("ready!"); + + // Act + await server.BroadcastToWebSocketsAsync(mappingGuid, broadcastMessage, cancellationToken: _ct); + + // Assert + var received1 = await client1.ReceiveAsTextAsync(cancellationToken: _ct); + var received2 = await client2.ReceiveAsTextAsync(cancellationToken: _ct); + + received1.Should().Be(broadcastMessage); + received2.Should().Be(broadcastMessage); + + await client1.CloseAsync(WebSocketCloseStatus.NormalClosure, "Test complete", _ct); + await client2.CloseAsync(WebSocketCloseStatus.NormalClosure, "Test complete", _ct); + + await Task.Delay(200, _ct); + } + + [Fact] + public async Task BroadcastToWebSocketsAsync_Should_Broadcast_Binary_To_Specific_Mapping() + { + // Arrange + using var server = WireMockServer.Start(new WireMockServerSettings + { + Logger = new TestOutputHelperWireMockLogger(output), + Urls = ["ws://localhost:0"] + }); + + var broadcastData = new byte[] { 0x01, 0x02, 0x03, 0x04 }; + var mappingGuid = Guid.NewGuid(); + + server + .Given(Request.Create() + .WithPath("/ws/broadcast-binary") + .WithWebSocketUpgrade() + ) + .WithGuid(mappingGuid) + .RespondWith(Response.Create() + .WithWebSocket(ws => ws + .WithCloseTimeout(TimeSpan.FromSeconds(5)) + .WithMessageHandler(async (message, context) => + { + if (message.MessageType == WebSocketMessageType.Text) + { + var text = message.Text ?? string.Empty; + if (text.StartsWith("ready")) + { + await context.SendAsync("ready!"); + } + } + }) + ) + ); + + using var client1 = new ClientWebSocket(); + using var client2 = new ClientWebSocket(); + + var uri = new Uri($"{server.Url}/ws/broadcast-binary"); + + await client1.ConnectAsync(uri, _ct); + await client2.ConnectAsync(uri, _ct); + + // Signal ready + await client1.SendAsync("ready", cancellationToken: _ct); + await client2.SendAsync("ready", cancellationToken: _ct); + + var text1 = await client1.ReceiveAsTextAsync(cancellationToken: _ct); + var text2 = await client2.ReceiveAsTextAsync(cancellationToken: _ct); + + text1.Should().Be("ready!"); + text2.Should().Be("ready!"); + + // Act + await server.BroadcastToWebSocketsAsync(mappingGuid, broadcastData, cancellationToken: _ct); + + // Assert + var received1 = await client1.ReceiveAsBytesAsync(cancellationToken: _ct); + var received2 = await client2.ReceiveAsBytesAsync(cancellationToken: _ct); + + received1.Should().BeEquivalentTo(broadcastData); + received2.Should().BeEquivalentTo(broadcastData); + + await client1.CloseAsync(WebSocketCloseStatus.NormalClosure, "Test complete", _ct); + await client2.CloseAsync(WebSocketCloseStatus.NormalClosure, "Test complete", _ct); + + await Task.Delay(200, _ct); + } + + [Fact] + public async Task BroadcastToAllWebSocketsAsync_Should_Broadcast_Text_To_All_Mappings() + { + // Arrange + using var server = WireMockServer.Start(new WireMockServerSettings + { + Logger = new TestOutputHelperWireMockLogger(output), + Urls = ["ws://localhost:0"] + }); + + var broadcastMessage = "Broadcast to all mappings"; + + server + .Given(Request.Create() + .WithPath("/ws/mapping1") + .WithWebSocketUpgrade() + ) + .RespondWith(Response.Create() + .WithWebSocket(ws => ws + .WithCloseTimeout(TimeSpan.FromSeconds(5)) + .WithMessageHandler(async (message, context) => + { + if (message.MessageType == WebSocketMessageType.Text) + { + var text = message.Text ?? string.Empty; + if (text.StartsWith("ready")) + { + await context.SendAsync("ready!"); + } + } + }) + ) + ); + + server + .Given(Request.Create() + .WithPath("/ws/mapping2") + .WithWebSocketUpgrade() + ) + .RespondWith(Response.Create() + .WithWebSocket(ws => ws + .WithMessageHandler(async (message, context) => + { + if (message.MessageType == WebSocketMessageType.Text) + { + var text = message.Text ?? string.Empty; + if (text.StartsWith("ready")) + { + await context.SendAsync("ready!"); + } + } + }) + ) + ); + + using var client1 = new ClientWebSocket(); + using var client2 = new ClientWebSocket(); + + var uri1 = new Uri($"{server.Url}/ws/mapping1"); + var uri2 = new Uri($"{server.Url}/ws/mapping2"); + + await client1.ConnectAsync(uri1, _ct); + await client2.ConnectAsync(uri2, _ct); + + // Signal ready + await client1.SendAsync("ready", cancellationToken: _ct); + await client2.SendAsync("ready", cancellationToken: _ct); + + var text1 = await client1.ReceiveAsTextAsync(cancellationToken: _ct); + var text2 = await client2.ReceiveAsTextAsync(cancellationToken: _ct); + + text1.Should().Be("ready!"); + text2.Should().Be("ready!"); + + // Act + await server.BroadcastToAllWebSocketsAsync(broadcastMessage, cancellationToken: _ct); + + // Assert - both clients from different mappings should receive the broadcast + var received1 = await client1.ReceiveAsTextAsync(cancellationToken: _ct); + var received2 = await client2.ReceiveAsTextAsync(cancellationToken: _ct); + + received1.Should().Be(broadcastMessage); + received2.Should().Be(broadcastMessage); + + await client1.CloseAsync(WebSocketCloseStatus.NormalClosure, "Test complete", _ct); + await client2.CloseAsync(WebSocketCloseStatus.NormalClosure, "Test complete", _ct); + + await Task.Delay(200, _ct); + } + + [Fact] + public async Task BroadcastToAllWebSocketsAsync_Should_Broadcast_Binary_To_All_Mappings() + { + // Arrange + using var server = WireMockServer.Start(new WireMockServerSettings + { + Logger = new TestOutputHelperWireMockLogger(output), + Urls = ["ws://localhost:0"] + }); + + var broadcastData = new byte[] { 0xAA, 0xBB, 0xCC }; + + server + .Given(Request.Create() + .WithPath("/ws/mapping1") + .WithWebSocketUpgrade() + ) + .RespondWith(Response.Create() + .WithWebSocket(ws => ws + .WithCloseTimeout(TimeSpan.FromSeconds(5)) + .WithMessageHandler(async (message, context) => + { + if (message.MessageType == WebSocketMessageType.Text) + { + var text = message.Text ?? string.Empty; + if (text.StartsWith("ready")) + { + await context.SendAsync("ready!"); + } + } + }) + ) + ); + + server + .Given(Request.Create() + .WithPath("/ws/mapping2") + .WithWebSocketUpgrade() + ) + .RespondWith(Response.Create() + .WithWebSocket(ws => ws + .WithMessageHandler(async (message, context) => + { + if (message.MessageType == WebSocketMessageType.Text) + { + var text = message.Text ?? string.Empty; + if (text.StartsWith("ready")) + { + await context.SendAsync("ready!"); + } + } + }) + ) + ); + + using var client1 = new ClientWebSocket(); + using var client2 = new ClientWebSocket(); + + var uri1 = new Uri($"{server.Url}/ws/mapping1"); + var uri2 = new Uri($"{server.Url}/ws/mapping2"); + + await client1.ConnectAsync(uri1, _ct); + await client2.ConnectAsync(uri2, _ct); + + // Signal ready + await client1.SendAsync("ready", cancellationToken: _ct); + await client2.SendAsync("ready", cancellationToken: _ct); + + var text1 = await client1.ReceiveAsTextAsync(cancellationToken: _ct); + var text2 = await client2.ReceiveAsTextAsync(cancellationToken: _ct); + + text1.Should().Be("ready!"); + text2.Should().Be("ready!"); + + // Act + await server.BroadcastToAllWebSocketsAsync(broadcastData, cancellationToken: _ct); + + // Assert + var received1 = await client1.ReceiveAsBytesAsync(cancellationToken: _ct); + var received2 = await client2.ReceiveAsBytesAsync(cancellationToken: _ct); + + received1.Should().BeEquivalentTo(broadcastData); + received2.Should().BeEquivalentTo(broadcastData); + + await client1.CloseAsync(WebSocketCloseStatus.NormalClosure, "Test complete", _ct); + await client2.CloseAsync(WebSocketCloseStatus.NormalClosure, "Test complete", _ct); + + await Task.Delay(200, _ct); + } + + [Fact] + public async Task GetWebSocketConnections_Should_Update_After_Client_Disconnect() + { + // Arrange + using var server = WireMockServer.Start(new WireMockServerSettings + { + Logger = new TestOutputHelperWireMockLogger(output), + Urls = ["ws://localhost:0"] + }); + + server + .Given(Request.Create() + .WithPath("/ws/test") + .WithWebSocketUpgrade() + ) + .RespondWith(Response.Create() + .WithWebSocket(ws => ws + .WithCloseTimeout(TimeSpan.FromSeconds(5)) + .WithEcho() + ) + ); + + using var client1 = new ClientWebSocket(); + using var client2 = new ClientWebSocket(); + + var uri = new Uri($"{server.Url}/ws/test"); + + await client1.ConnectAsync(uri, _ct); + await client2.ConnectAsync(uri, _ct); + + var initialConnections = server.GetWebSocketConnections(); + initialConnections.Should().HaveCount(2); + + // Act + await client1.CloseAsync(WebSocketCloseStatus.NormalClosure, "Disconnect", _ct); + await Task.Delay(100, _ct); + + // Assert + var remainingConnections = server.GetWebSocketConnections(); + remainingConnections.Should().HaveCount(1); + + await client2.CloseAsync(WebSocketCloseStatus.NormalClosure, "Test complete", _ct); + + await Task.Delay(200, _ct); + } +} \ No newline at end of file