diff --git a/examples/WireMock.Net.WebSocketExamples/Program.cs b/examples/WireMock.Net.WebSocketExamples/Program.cs index c53f558d..b89cf943 100644 --- a/examples/WireMock.Net.WebSocketExamples/Program.cs +++ b/examples/WireMock.Net.WebSocketExamples/Program.cs @@ -211,7 +211,7 @@ public static class Program var broadcastMessage = $"[{timestamp}] Broadcast: {text}"; // Broadcast to all connected clients - await context.BroadcastTextAsync(broadcastMessage); + await context.BroadcastAsync(broadcastMessage); Console.WriteLine($"Broadcasted to {server.GetWebSocketConnections(broadcastMappingGuid).Count} clients: {text}"); } @@ -428,7 +428,7 @@ public static class Program { if (message.MessageType == WebSocketMessageType.Text) { - await context.BroadcastTextAsync($"[Broadcast] {message.Text}"); + await context.BroadcastAsync($"[Broadcast] {message.Text}"); } }) ) diff --git a/src/WireMock.Net.Minimal/Server/WireMockServer.WebSocket.cs b/src/WireMock.Net.Minimal/Server/WireMockServer.WebSocket.cs index 2bcc7686..cf079fee 100644 --- a/src/WireMock.Net.Minimal/Server/WireMockServer.WebSocket.cs +++ b/src/WireMock.Net.Minimal/Server/WireMockServer.WebSocket.cs @@ -35,39 +35,64 @@ public partial class WireMockServer public async Task CloseWebSocketConnectionAsync( Guid connectionId, WebSocketCloseStatus closeStatus = WebSocketCloseStatus.NormalClosure, - string statusDescription = "Closed by server") + string statusDescription = "Closed by server", + CancellationToken cancellationToken = default) { foreach (var registry in _options.WebSocketRegistries.Values) { - if (registry.TryGetConnection(connectionId, out var connection)) + if (registry.TryGetConnection(connectionId, out var connection) && !cancellationToken.IsCancellationRequested) { - await connection.CloseAsync(closeStatus, statusDescription); + await connection.CloseAsync(closeStatus, statusDescription, cancellationToken); return; } } } /// - /// Broadcast a message to all WebSocket connections in a specific mapping + /// Broadcast a text message to all WebSocket connections in a specific mapping /// [PublicAPI] - public async Task BroadcastToWebSocketsAsync(Guid mappingGuid, string text) + public async Task BroadcastToWebSocketsAsync(Guid mappingGuid, string text, Guid? excludeConnectionId = null, CancellationToken cancellationToken = default) { if (_options.WebSocketRegistries.TryGetValue(mappingGuid, out var registry)) { - await registry.BroadcastTextAsync(text); + await registry.BroadcastAsync(text, excludeConnectionId, cancellationToken); } } /// - /// Broadcast a message to all WebSocket connections + /// Broadcast a text message to all WebSocket connections /// [PublicAPI] - public async Task BroadcastToAllWebSocketsAsync(string text) + public async Task BroadcastToAllWebSocketsAsync(string text, Guid? excludeConnectionId = null, CancellationToken cancellationToken = default) { foreach (var registry in _options.WebSocketRegistries.Values) { - await registry.BroadcastTextAsync(text); + await registry.BroadcastAsync(text, excludeConnectionId, cancellationToken); + } + } + + /// + /// Broadcast a binary message to all WebSocket connections in a specific mapping + /// + [PublicAPI] + public async Task BroadcastToWebSocketsAsync(Guid mappingGuid, byte[] bytes, Guid? excludeConnectionId = null, CancellationToken cancellationToken = default) + { + if (_options.WebSocketRegistries.TryGetValue(mappingGuid, out var registry)) + { + await registry.BroadcastAsync(bytes, excludeConnectionId, cancellationToken); + } + } + + /// + /// Broadcast a binary message to all WebSocket connections + /// + [PublicAPI] + public async Task BroadcastToAllWebSocketsAsync(byte[] bytes, Guid? excludeConnectionId = null, CancellationToken cancellationToken = default) + { + foreach (var registry in _options.WebSocketRegistries.Values) + { + await registry.BroadcastAsync(bytes, excludeConnectionId, cancellationToken); } } } \ No newline at end of file diff --git a/src/WireMock.Net.Minimal/WebSockets/WebSocketConnectionRegistry.cs b/src/WireMock.Net.Minimal/WebSockets/WebSocketConnectionRegistry.cs index e8ec025c..0eac1867 100644 --- a/src/WireMock.Net.Minimal/WebSockets/WebSocketConnectionRegistry.cs +++ b/src/WireMock.Net.Minimal/WebSockets/WebSocketConnectionRegistry.cs @@ -3,6 +3,7 @@ using System.Collections.Concurrent; using System.Diagnostics.CodeAnalysis; using System.Net.WebSockets; +using static System.Net.Mime.MediaTypeNames; namespace WireMock.WebSockets; @@ -48,12 +49,24 @@ internal class WebSocketConnectionRegistry /// /// Broadcast text to all connections /// - public async Task BroadcastTextAsync(string text, CancellationToken cancellationToken = default) + public async Task BroadcastAsync(string text, Guid? excludeConnectionId, CancellationToken cancellationToken = default) { - var tasks = _connections.Values - .Where(c => c.WebSocket.State == WebSocketState.Open) - .Select(c => c.SendAsync(text, cancellationToken)); - + var tasks = Filter(excludeConnectionId).Select(c => c.SendAsync(text, cancellationToken)); await Task.WhenAll(tasks); } + + /// + /// Broadcast binary to all connections + /// + public async Task BroadcastAsync(byte[] bytes, Guid? excludeConnectionId, CancellationToken cancellationToken = default) + { + var tasks = Filter(excludeConnectionId).Select(c => c.SendAsync(bytes, cancellationToken)); + await Task.WhenAll(tasks); + } + + private IEnumerable Filter(Guid? excludeConnectionId) + { + return _connections.Values + .Where(c =>c.WebSocket.State == WebSocketState.Open && (!excludeConnectionId.HasValue || c.ConnectionId != excludeConnectionId)); + } } \ No newline at end of file diff --git a/src/WireMock.Net.Minimal/WebSockets/WireMockWebSocketContext.cs b/src/WireMock.Net.Minimal/WebSockets/WireMockWebSocketContext.cs index 2a5944b5..a0260ff4 100644 --- a/src/WireMock.Net.Minimal/WebSockets/WireMockWebSocketContext.cs +++ b/src/WireMock.Net.Minimal/WebSockets/WireMockWebSocketContext.cs @@ -94,19 +94,30 @@ public class WireMockWebSocketContext : IWebSocketContext } /// - public async Task CloseAsync(WebSocketCloseStatus closeStatus, string statusDescription) + public async Task CloseAsync(WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken = default) { - await WebSocket.CloseAsync(closeStatus, statusDescription, CancellationToken.None); + await WebSocket.CloseAsync(closeStatus, statusDescription, cancellationToken); LogWebSocketMessage(WebSocketMessageDirection.Send, WebSocketMessageType.Close, $"CloseStatus: {closeStatus}, Description: {statusDescription}", null); } /// - public async Task BroadcastTextAsync(string text, CancellationToken cancellationToken = default) + public async Task BroadcastAsync(string text, bool excludeSender = false, CancellationToken cancellationToken = default) { if (Registry != null) { - await Registry.BroadcastTextAsync(text, cancellationToken); + Guid? excludeConnectionId = excludeSender ? ConnectionId : null; + await Registry.BroadcastAsync(text, excludeConnectionId, cancellationToken); + } + } + + /// + public async Task BroadcastAsync(byte[] bytes, bool excludeSender = false, CancellationToken cancellationToken = default) + { + if (Registry != null) + { + Guid? excludeConnectionId = excludeSender ? ConnectionId : null; + await Registry.BroadcastAsync(bytes, excludeConnectionId, cancellationToken); } } diff --git a/src/WireMock.Net.Shared/WebSockets/IWebSocketContext.cs b/src/WireMock.Net.Shared/WebSockets/IWebSocketContext.cs index 3b5f950b..3df9d1c3 100644 --- a/src/WireMock.Net.Shared/WebSockets/IWebSocketContext.cs +++ b/src/WireMock.Net.Shared/WebSockets/IWebSocketContext.cs @@ -48,10 +48,15 @@ public interface IWebSocketContext /// /// Close the WebSocket connection /// - Task CloseAsync(WebSocketCloseStatus closeStatus, string statusDescription); + Task CloseAsync(WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken = default); /// /// Broadcast text message to all connections in this mapping /// - Task BroadcastTextAsync(string text, CancellationToken cancellationToken = default); + Task BroadcastAsync(string text, bool excludeSender = false, CancellationToken cancellationToken = default); + + /// + /// Broadcast binary message to all connections in this mapping + /// + Task BroadcastAsync(byte[] bytes, bool excludeSender = false, CancellationToken cancellationToken = default); } \ No newline at end of file diff --git a/test/WireMock.Net.Tests/WebSockets/WebSocketIntegrationTests.cs b/test/WireMock.Net.Tests/WebSockets/WebSocketIntegrationTests.cs index 33a7778e..2a588179 100644 --- a/test/WireMock.Net.Tests/WebSockets/WebSocketIntegrationTests.cs +++ b/test/WireMock.Net.Tests/WebSockets/WebSocketIntegrationTests.cs @@ -780,4 +780,458 @@ public class WebSocketIntegrationTests(ITestOutputHelper output, ITestContextAcc await Task.Delay(100, _ct); } + + [Fact] + public async Task Broadcast_Should_Send_TextMessage_To_Multiple_Connected_Clients() + { + // Arrange + using var server = WireMockServer.Start(new WireMockServerSettings + { + Logger = new TestOutputHelperWireMockLogger(output), + Urls = ["ws://localhost:0"] + }); + + var broadcastMessage = "Broadcast to all clients"; + + server + .Given(Request.Create() + .WithPath("/ws/broadcast") + .WithWebSocketUpgrade() + ) + .RespondWith(Response.Create() + .WithWebSocket(ws => ws + .WithCloseTimeout(TimeSpan.FromSeconds(10)) + .WithBroadcast() + .WithMessageHandler(async (message, context) => + { + if (message.MessageType == WebSocketMessageType.Text) + { + var text = message.Text ?? string.Empty; + + if (text == "register") + { + await context.SendAsync($"Registered: {context.ConnectionId}"); + } + else if (text.StartsWith("broadcast:")) + { + var broadcastText = text.Substring(10); + await context.BroadcastAsync(broadcastText); + } + } + }) + ) + ); + + using var client1 = new ClientWebSocket(); + using var client2 = new ClientWebSocket(); + using var client3 = new ClientWebSocket(); + + var uri = new Uri($"{server.Url}/ws/broadcast"); + + // Act + await client1.ConnectAsync(uri, _ct); + await client2.ConnectAsync(uri, _ct); + await client3.ConnectAsync(uri, _ct); + + await client1.SendAsync("register", cancellationToken: _ct); + await client2.SendAsync("register", cancellationToken: _ct); + await client3.SendAsync("register", cancellationToken: _ct); + + // Receive registration confirmations + var reg1 = await client1.ReceiveAsTextAsync(cancellationToken: _ct); + var reg2 = await client2.ReceiveAsTextAsync(cancellationToken: _ct); + var reg3 = await client3.ReceiveAsTextAsync(cancellationToken: _ct); + + reg1.Should().StartWith("Registered: "); + reg2.Should().StartWith("Registered: "); + reg3.Should().StartWith("Registered: "); + + // Send broadcast from client1 + await client1.SendAsync($"broadcast:{broadcastMessage}", cancellationToken: _ct); + + // Assert - all clients should receive the broadcast + var received1 = await client1.ReceiveAsTextAsync(cancellationToken: _ct); + var received2 = await client2.ReceiveAsTextAsync(cancellationToken: _ct); + var received3 = await client3.ReceiveAsTextAsync(cancellationToken: _ct); + + received1.Should().Be(broadcastMessage); + received2.Should().Be(broadcastMessage); + received3.Should().Be(broadcastMessage); + + await client1.CloseAsync(WebSocketCloseStatus.NormalClosure, "Test complete", _ct); + await client2.CloseAsync(WebSocketCloseStatus.NormalClosure, "Test complete", _ct); + await client3.CloseAsync(WebSocketCloseStatus.NormalClosure, "Test complete", _ct); + + await Task.Delay(300, _ct); + } + + [Fact] + public async Task Broadcast_Should_Send_BinaryMessage_To_Multiple_Connected_Clients() + { + // Arrange + using var server = WireMockServer.Start(new WireMockServerSettings + { + Logger = new TestOutputHelperWireMockLogger(output), + Urls = ["ws://localhost:0"] + }); + + var message = new byte[] { 0x00, 0x01, 0x02, 0x03 }; + var broadcastMessageFromWireMock = new byte[] { 0x01, 0x02, 0x03, 0x04 }; + + server + .Given(Request.Create() + .WithPath("/ws/broadcast") + .WithWebSocketUpgrade() + ) + .RespondWith(Response.Create() + .WithWebSocket(ws => ws + .WithCloseTimeout(TimeSpan.FromSeconds(10)) + .WithBroadcast() + .WithMessageHandler(async (message, context) => + { + if (message.MessageType == WebSocketMessageType.Text && message.Text == "register") + { + await context.SendAsync($"Registered: {context.ConnectionId}"); + } + + if (message.MessageType == WebSocketMessageType.Binary) + { + await context.BroadcastAsync(broadcastMessageFromWireMock); + } + }) + ) + ); + + using var client1 = new ClientWebSocket(); + using var client2 = new ClientWebSocket(); + using var client3 = new ClientWebSocket(); + + var uri = new Uri($"{server.Url}/ws/broadcast"); + + // Act + await client1.ConnectAsync(uri, _ct); + await client2.ConnectAsync(uri, _ct); + await client3.ConnectAsync(uri, _ct); + + await client1.SendAsync("register", cancellationToken: _ct); + await client2.SendAsync("register", cancellationToken: _ct); + await client3.SendAsync("register", cancellationToken: _ct); + + // Receive registration confirmations + var reg1 = await client1.ReceiveAsTextAsync(cancellationToken: _ct); + var reg2 = await client2.ReceiveAsTextAsync(cancellationToken: _ct); + var reg3 = await client3.ReceiveAsTextAsync(cancellationToken: _ct); + + reg1.Should().StartWith("Registered: "); + reg2.Should().StartWith("Registered: "); + reg3.Should().StartWith("Registered: "); + + // Send broadcast from client1 + await client1.SendAsync(new ArraySegment(message), WebSocketMessageType.Binary, true, cancellationToken: _ct); + + // Assert - all clients should receive the broadcast + var received1 = await client1.ReceiveAsBytesAsync(cancellationToken: _ct); + var received2 = await client2.ReceiveAsBytesAsync(cancellationToken: _ct); + var received3 = await client3.ReceiveAsBytesAsync(cancellationToken: _ct); + + received1.Should().BeEquivalentTo(broadcastMessageFromWireMock); + received2.Should().BeEquivalentTo(broadcastMessageFromWireMock); + received3.Should().BeEquivalentTo(broadcastMessageFromWireMock); + + await client1.CloseAsync(WebSocketCloseStatus.NormalClosure, "Test complete", _ct); + await client2.CloseAsync(WebSocketCloseStatus.NormalClosure, "Test complete", _ct); + await client3.CloseAsync(WebSocketCloseStatus.NormalClosure, "Test complete", _ct); + + await Task.Delay(300, _ct); + } + + [Fact] + public async Task Broadcast_Should_Handle_Multiple_Broadcast_Messages() + { + // Arrange + using var server = WireMockServer.Start(new WireMockServerSettings + { + Logger = new TestOutputHelperWireMockLogger(output), + Urls = ["ws://localhost:0"] + }); + + server + .Given(Request.Create() + .WithPath("/ws/broadcast-multi") + .WithWebSocketUpgrade() + ) + .RespondWith(Response.Create() + .WithWebSocket(ws => ws + .WithBroadcast() + .WithCloseTimeout(TimeSpan.FromSeconds(10)) + .WithMessageHandler(async (message, context) => + { + if (message.MessageType == WebSocketMessageType.Text) + { + var text = message.Text ?? string.Empty; + await context.BroadcastAsync(text); + } + }) + ) + ); + + using var client1 = new ClientWebSocket(); + using var client2 = new ClientWebSocket(); + + var uri = new Uri($"{server.Url}/ws/broadcast-multi"); + + await client1.ConnectAsync(uri, _ct); + await client2.ConnectAsync(uri, _ct); + + var messages = new[] { "Message 1", "Message 2", "Message 3" }; + + // Act & Assert + foreach (var message in messages) + { + await client1.SendAsync(message, cancellationToken: _ct); + + var received1 = await client1.ReceiveAsTextAsync(cancellationToken: _ct); + var received2 = await client2.ReceiveAsTextAsync(cancellationToken: _ct); + + received1.Should().Be(message); + received2.Should().Be(message); + } + + await client1.CloseAsync(WebSocketCloseStatus.NormalClosure, "Test complete", _ct); + await client2.CloseAsync(WebSocketCloseStatus.NormalClosure, "Test complete", _ct); + + await Task.Delay(300, _ct); + } + + [Fact] + public async Task Broadcast_Should_Exclude_Sender_When_ExcludeSender_Is_True() + { + // Arrange + using var server = WireMockServer.Start(new WireMockServerSettings + { + Logger = new TestOutputHelperWireMockLogger(output), + Urls = ["ws://localhost:0"] + }); + + server + .Given(Request.Create() + .WithPath("/ws/broadcast-exclude") + .WithWebSocketUpgrade() + ) + .RespondWith(Response.Create() + .WithWebSocket(ws => ws + .WithBroadcast() + .WithCloseTimeout(TimeSpan.FromSeconds(10)) + .WithMessageHandler(async (message, context) => + { + if (message.MessageType == WebSocketMessageType.Text) + { + var text = message.Text ?? string.Empty; + + if (text.StartsWith("send:")) + { + var broadcastText = text.Substring(5); + await context.BroadcastAsync(broadcastText, excludeSender: true); + } + } + }) + ) + ); + + using var client1 = new ClientWebSocket(); + using var client2 = new ClientWebSocket(); + + var uri = new Uri($"{server.Url}/ws/broadcast-exclude"); + + await client1.ConnectAsync(uri, _ct); + await client2.ConnectAsync(uri, _ct); + + var broadcastMessage = "Exclusive broadcast"; + + // Act + await client1.SendAsync($"send:{broadcastMessage}", cancellationToken: _ct); + + // Assert - only client2 should receive the message + var received2 = await client2.ReceiveAsTextAsync(cancellationToken: _ct); + received2.Should().Be(broadcastMessage); + + // client1 should not receive anything (or should timeout) + var receiveTask1 = client1.ReceiveAsTextAsync(cancellationToken: _ct); + var delayTask = Task.Delay(500, _ct); + + var completedTask = await Task.WhenAny(receiveTask1, delayTask); + completedTask.Should().Be(delayTask, "client1 should not receive the exclusive broadcast"); + + await client1.CloseAsync(WebSocketCloseStatus.NormalClosure, "Test complete", _ct); + await client2.CloseAsync(WebSocketCloseStatus.NormalClosure, "Test complete", _ct); + + await Task.Delay(200, _ct); + } + + [Fact] + public async Task Broadcast_Should_Work_With_Single_Client() + { + // Arrange + using var server = WireMockServer.Start(new WireMockServerSettings + { + Logger = new TestOutputHelperWireMockLogger(output), + Urls = ["ws://localhost:0"] + }); + + var broadcastMessage = "Single client broadcast"; + + server + .Given(Request.Create() + .WithPath("/ws/broadcast-single") + .WithWebSocketUpgrade() + ) + .RespondWith(Response.Create() + .WithWebSocket(ws => ws + .WithCloseTimeout(TimeSpan.FromSeconds(10)) + .WithBroadcast() + .WithMessageHandler(async (message, context) => + { + if (message.MessageType == WebSocketMessageType.Text) + { + var text = message.Text ?? string.Empty; + await context.BroadcastAsync(text); + } + }) + ) + ); + + using var client = new ClientWebSocket(); + var uri = new Uri($"{server.Url}/ws/broadcast-single"); + + // Act + await client.ConnectAsync(uri, _ct); + await client.SendAsync(broadcastMessage, cancellationToken: _ct); + + // Assert + var received = await client.ReceiveAsTextAsync(cancellationToken: _ct); + received.Should().Be(broadcastMessage); + + await client.CloseAsync(WebSocketCloseStatus.NormalClosure, "Test complete", _ct); + + await Task.Delay(100, _ct); + } + + [Fact] + public async Task Broadcast_Should_Handle_Client_Disconnect_During_Broadcast() + { + // Arrange + using var server = WireMockServer.Start(new WireMockServerSettings + { + Logger = new TestOutputHelperWireMockLogger(output), + Urls = ["ws://localhost:0"] + }); + + var broadcastMessage = "Message after disconnect"; + + server + .Given(Request.Create() + .WithPath("/ws/broadcast-disconnect") + .WithWebSocketUpgrade() + ) + .RespondWith(Response.Create() + .WithWebSocket(ws => ws + .WithCloseTimeout(TimeSpan.FromSeconds(10)) + .WithBroadcast() + .WithMessageHandler(async (message, context) => + { + if (message.MessageType == WebSocketMessageType.Text) + { + var text = message.Text ?? string.Empty; + await context.BroadcastAsync(text); + } + }) + ) + ); + + using var client1 = new ClientWebSocket(); + using var client2 = new ClientWebSocket(); + + var uri = new Uri($"{server.Url}/ws/broadcast-disconnect"); + + await client1.ConnectAsync(uri, _ct); + await client2.ConnectAsync(uri, _ct); + + // Act - disconnect client1 + await client1.CloseAsync(WebSocketCloseStatus.NormalClosure, "Disconnecting", _ct); + + // Send broadcast from client2 - should handle disconnected client gracefully + await client2.SendAsync(broadcastMessage, cancellationToken: _ct); + + // Assert - client2 should still receive the broadcast + var received2 = await client2.ReceiveAsTextAsync(cancellationToken: _ct); + received2.Should().Be(broadcastMessage); + + await client2.CloseAsync(WebSocketCloseStatus.NormalClosure, "Test complete", _ct); + + await Task.Delay(200, _ct); + } + + [Fact] + public async Task Broadcast_Should_Support_Targeted_Broadcasting_Based_On_Condition() + { + // Arrange + using var server = WireMockServer.Start(new WireMockServerSettings + { + Logger = new TestOutputHelperWireMockLogger(output), + Urls = ["ws://localhost:0"] + }); + + server + .Given(Request.Create() + .WithPath("/ws/broadcast-conditional") + .WithWebSocketUpgrade() + ) + .RespondWith(Response.Create() + .WithWebSocket(ws => ws + .WithCloseTimeout(TimeSpan.FromSeconds(10)) + .WithBroadcast() + .WithMessageHandler(async (message, context) => + { + if (message.MessageType == WebSocketMessageType.Text) + { + var text = message.Text ?? string.Empty; + + if (text.StartsWith("to-admins:")) + { + var adminMessage = text.Substring(10); + await context.SendAsync($"Admin broadcast: {adminMessage}"); + } + else if (text.StartsWith("to-all:")) + { + var allMessage = text.Substring(7); + await context.BroadcastAsync(allMessage); + } + } + }) + ) + ); + + using var client1 = new ClientWebSocket(); + using var client2 = new ClientWebSocket(); + + var uri = new Uri($"{server.Url}/ws/broadcast-conditional"); + + await client1.ConnectAsync(uri, _ct); + await client2.ConnectAsync(uri, _ct); + + // Act + await client1.SendAsync("to-all:General message", cancellationToken: _ct); + + // Assert - both clients receive the broadcast + var received1 = await client1.ReceiveAsTextAsync(cancellationToken: _ct); + var received2 = await client2.ReceiveAsTextAsync(cancellationToken: _ct); + + received1.Should().Be("General message"); + received2.Should().Be("General message"); + + await client1.CloseAsync(WebSocketCloseStatus.NormalClosure, "Test complete", _ct); + await client2.CloseAsync(WebSocketCloseStatus.NormalClosure, "Test complete", _ct); + + await Task.Delay(200, _ct); + } } \ No newline at end of file