Add WebSockets (#1423)

* Add WebSockets

* Add tests

* fix

* more tests

* Add tests

* ...

* remove IOwin

* -

* tests

* fluent

* ok

* match

* .

* byte[]

* x

* func

* func

* byte

* trans

* ...

* frameworks.........

* jmes

* xxx

* sc
This commit is contained in:
Stef Heyenrath
2026-02-14 08:42:40 +01:00
committed by GitHub
parent dff55e175b
commit 8b27da95a8
103 changed files with 72659 additions and 398 deletions
@@ -0,0 +1,291 @@
// Copyright © WireMock.Net
using System;
using System.Net.WebSockets;
using Stef.Validation;
using WireMock.Matchers;
using WireMock.ResponseBuilders;
using WireMock.Settings;
using WireMock.Transformers;
namespace WireMock.WebSockets;
internal class WebSocketBuilder(Response response) : IWebSocketBuilder
{
private readonly List<(IMatcher matcher, List<WebSocketMessageBuilder> messages)> _conditionalMessages = [];
/// <inheritdoc />
public string? AcceptProtocol { get; private set; }
/// <inheritdoc />
public bool IsEcho { get; private set; }
/// <inheritdoc />
public bool IsBroadcast { get; private set; }
/// <inheritdoc />
public Func<WebSocketMessage, IWebSocketContext, Task>? MessageHandler { get; private set; }
/// <inheritdoc />
public ProxyAndRecordSettings? ProxySettings { get; private set; }
/// <inheritdoc />
public TimeSpan? CloseTimeout { get; private set; }
/// <inheritdoc />
public int? MaxMessageSize { get; private set; }
/// <inheritdoc />
public int? ReceiveBufferSize { get; private set; }
/// <inheritdoc />
public TimeSpan? KeepAliveIntervalSeconds { get; private set; }
/// <inheritdoc />
public IWebSocketBuilder WithAcceptProtocol(string protocol)
{
AcceptProtocol = Guard.NotNull(protocol);
return this;
}
public IWebSocketBuilder WithEcho()
{
IsEcho = true;
return this;
}
public IWebSocketBuilder SendMessage(Action<IWebSocketMessageBuilder> configure)
{
Guard.NotNull(configure);
var messageBuilder = new WebSocketMessageBuilder();
configure(messageBuilder);
return WithMessageHandler(async (message, context) =>
{
if (messageBuilder.Delay.HasValue)
{
await Task.Delay(messageBuilder.Delay.Value);
}
await SendMessageAsync(context, messageBuilder, message);
});
}
public IWebSocketBuilder SendMessages(Action<IWebSocketMessagesBuilder> configure)
{
Guard.NotNull(configure);
var messagesBuilder = new WebSocketMessagesBuilder();
configure(messagesBuilder);
return WithMessageHandler(async (message, context) =>
{
foreach (var messageBuilder in messagesBuilder.Messages)
{
if (messageBuilder.Delay.HasValue)
{
await Task.Delay(messageBuilder.Delay.Value);
}
await SendMessageAsync(context, messageBuilder, message);
}
});
}
public IWebSocketMessageConditionBuilder WhenMessage(string wildcardPattern)
{
Guard.NotNull(wildcardPattern);
var matcher = new WildcardMatcher(MatchBehaviour.AcceptOnMatch, wildcardPattern);
return new WebSocketMessageConditionBuilder(this, matcher);
}
public IWebSocketMessageConditionBuilder WhenMessage(byte[] exactPattern)
{
Guard.NotNull(exactPattern);
var matcher = new ExactObjectMatcher(MatchBehaviour.AcceptOnMatch, exactPattern);
return new WebSocketMessageConditionBuilder(this, matcher);
}
public IWebSocketMessageConditionBuilder WhenMessage(IMatcher matcher)
{
Guard.NotNull(matcher);
return new WebSocketMessageConditionBuilder(this, matcher);
}
public IWebSocketBuilder WithMessageHandler(Func<WebSocketMessage, IWebSocketContext, Task> handler)
{
MessageHandler = Guard.NotNull(handler);
IsEcho = false;
return this;
}
public IWebSocketBuilder WithBroadcast()
{
IsBroadcast = true;
return this;
}
public IWebSocketBuilder WithProxy(ProxyAndRecordSettings settings)
{
ProxySettings = Guard.NotNull(settings);
IsEcho = false;
return this;
}
public IWebSocketBuilder WithCloseTimeout(TimeSpan timeout)
{
CloseTimeout = timeout;
return this;
}
public IWebSocketBuilder WithMaxMessageSize(int sizeInBytes)
{
MaxMessageSize = Guard.Condition(sizeInBytes, s => s > 0);
return this;
}
public IWebSocketBuilder WithReceiveBufferSize(int sizeInBytes)
{
ReceiveBufferSize = Guard.Condition(sizeInBytes, s => s > 0);
return this;
}
public IWebSocketBuilder WithKeepAliveInterval(TimeSpan interval)
{
KeepAliveIntervalSeconds = interval;
return this;
}
internal IWebSocketBuilder AddConditionalMessage(IMatcher matcher, WebSocketMessageBuilder messageBuilder)
{
_conditionalMessages.Add((matcher, new List<WebSocketMessageBuilder> { messageBuilder }));
SetupConditionalHandler();
return this;
}
internal IWebSocketBuilder AddConditionalMessages(IMatcher matcher, List<WebSocketMessageBuilder> messages)
{
_conditionalMessages.Add((matcher, messages));
SetupConditionalHandler();
return this;
}
private void SetupConditionalHandler()
{
if (_conditionalMessages.Count == 0)
{
return;
}
WithMessageHandler(async (message, context) =>
{
// Check each condition in order
foreach (var (matcher, messages) in _conditionalMessages)
{
// Try to match the message
if (await MatchMessageAsync(message, matcher))
{
// Execute the corresponding messages
foreach (var messageBuilder in messages)
{
if (messageBuilder.Delay.HasValue)
{
await Task.Delay(messageBuilder.Delay.Value);
}
await SendMessageAsync(context, messageBuilder, message);
// If this message should close the connection, do it after sending
if (messageBuilder.ShouldClose)
{
try
{
await Task.Delay(100); // Small delay to ensure message is sent
await context.CloseAsync(WebSocketCloseStatus.NormalClosure, "Closed by handler");
}
catch
{
// Ignore errors during close
}
}
}
return; // Stop after first match
}
}
});
}
private async Task SendMessageAsync(IWebSocketContext context, WebSocketMessageBuilder messageBuilder, WebSocketMessage incomingMessage)
{
switch (messageBuilder.Type)
{
case WebSocketMessageType.Text:
var text = messageBuilder.MessageText!;
if (response.UseTransformer)
{
text = ApplyTransformer(context, incomingMessage, text);
}
await context.SendAsync(text);
break;
case WebSocketMessageType.Binary:
await context.SendAsync(messageBuilder.MessageBytes!);
break;
}
}
private string ApplyTransformer(IWebSocketContext context, WebSocketMessage incomingMessage, string text)
{
try
{
if (incomingMessage == null)
{
// No incoming message, can't apply transformer
return text;
}
var transformer = TransformerFactory.Create(response.TransformerType, context.Mapping.Settings);
var model = new WebSocketTransformModel
{
Mapping = context.Mapping,
Request = context.RequestMessage,
Message = incomingMessage,
Data = incomingMessage.MessageType == WebSocketMessageType.Text ? incomingMessage.Text : null
};
return transformer.Transform(text, model);
}
catch
{
// If transformation fails, return original text
return text;
}
}
private static async Task<bool> MatchMessageAsync(WebSocketMessage message, IMatcher matcher)
{
if (message.MessageType == WebSocketMessageType.Text)
{
if (matcher is IStringMatcher stringMatcher)
{
var result = stringMatcher.IsMatch(message.Text);
return result.IsPerfect();
}
if (matcher is IFuncMatcher funcMatcher)
{
var result = funcMatcher.IsMatch(message.Text);
return result.IsPerfect();
}
}
if (message.MessageType == WebSocketMessageType.Binary && matcher is IBytesMatcher bytesMatcher && message.Bytes != null)
{
var result = await bytesMatcher.IsMatchAsync(message.Bytes);
return result.IsPerfect();
}
return false;
}
}
@@ -0,0 +1,60 @@
// Copyright © WireMock.Net
using System.Collections.Concurrent;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Net.WebSockets;
namespace WireMock.WebSockets;
/// <summary>
/// Registry for managing WebSocket connections per mapping
/// </summary>
internal class WebSocketConnectionRegistry
{
private readonly ConcurrentDictionary<Guid, WireMockWebSocketContext> _connections = new();
/// <summary>
/// Add a connection to the registry
/// </summary>
public void AddConnection(WireMockWebSocketContext context)
{
_connections.TryAdd(context.ConnectionId, context);
}
/// <summary>
/// Remove a connection from the registry
/// </summary>
public void RemoveConnection(Guid connectionId)
{
_connections.TryRemove(connectionId, out _);
}
/// <summary>
/// Get all connections
/// </summary>
public IReadOnlyCollection<WireMockWebSocketContext> GetConnections()
{
return _connections.Values.ToList();
}
/// <summary>
/// Try to get a specific connection
/// </summary>
public bool TryGetConnection(Guid connectionId, [NotNullWhen(true)] out WireMockWebSocketContext? connection)
{
return _connections.TryGetValue(connectionId, out connection);
}
/// <summary>
/// Broadcast text to all connections
/// </summary>
public async Task BroadcastTextAsync(string text, CancellationToken cancellationToken = default)
{
var tasks = _connections.Values
.Where(c => c.WebSocket.State == WebSocketState.Open)
.Select(c => c.SendAsync(text, cancellationToken));
await Task.WhenAll(tasks);
}
}
@@ -0,0 +1,56 @@
// Copyright © WireMock.Net
using System.Net.WebSockets;
using Stef.Validation;
namespace WireMock.WebSockets;
internal class WebSocketMessageBuilder : IWebSocketMessageBuilder
{
public string? MessageText { get; private set; }
public byte[]? MessageBytes { get; private set; }
public object? MessageData { get; private set; }
public TimeSpan? Delay { get; private set; }
public WebSocketMessageType Type { get; private set; }
public bool ShouldClose { get; private set; }
public IWebSocketMessageBuilder WithText(string text)
{
MessageText = Guard.NotNull(text);
Type = WebSocketMessageType.Text;
return this;
}
public IWebSocketMessageBuilder WithBinary(byte[] bytes)
{
MessageBytes = Guard.NotNull(bytes);
Type = WebSocketMessageType.Binary;
return this;
}
public IWebSocketMessageBuilder WithDelay(TimeSpan delay)
{
Delay = delay;
return this;
}
public IWebSocketMessageBuilder WithDelay(int delayInMilliseconds)
{
Guard.Condition(delayInMilliseconds, d => d >= 0, nameof(delayInMilliseconds));
Delay = TimeSpan.FromMilliseconds(delayInMilliseconds);
return this;
}
public IWebSocketMessageBuilder Close()
{
ShouldClose = true;
return this;
}
public IWebSocketMessageBuilder AndClose() => Close();
}
@@ -0,0 +1,36 @@
// Copyright © WireMock.Net
using WireMock.Matchers;
using Stef.Validation;
namespace WireMock.WebSockets;
internal class WebSocketMessageConditionBuilder : IWebSocketMessageConditionBuilder
{
private readonly WebSocketBuilder _parent;
private readonly IMatcher _matcher;
public WebSocketMessageConditionBuilder(WebSocketBuilder parent, IMatcher matcher)
{
_parent = Guard.NotNull(parent);
_matcher = Guard.NotNull(matcher);
}
public IWebSocketBuilder SendMessage(Action<IWebSocketMessageBuilder> configure)
{
Guard.NotNull(configure);
var messageBuilder = new WebSocketMessageBuilder();
configure(messageBuilder);
return _parent.AddConditionalMessage(_matcher, messageBuilder);
}
public IWebSocketBuilder SendMessages(Action<IWebSocketMessagesBuilder> configure)
{
Guard.NotNull(configure);
var messagesBuilder = new WebSocketMessagesBuilder();
configure(messagesBuilder);
return _parent.AddConditionalMessages(_matcher, messagesBuilder.Messages);
}
}
@@ -0,0 +1,16 @@
// Copyright © WireMock.Net
namespace WireMock.WebSockets;
internal class WebSocketMessagesBuilder : IWebSocketMessagesBuilder
{
internal List<WebSocketMessageBuilder> Messages { get; } = [];
public IWebSocketMessagesBuilder AddMessage(Action<IWebSocketMessageBuilder> configure)
{
var messageBuilder = new WebSocketMessageBuilder();
configure(messageBuilder);
Messages.Add(messageBuilder);
return this;
}
}
@@ -0,0 +1,29 @@
// Copyright © WireMock.Net
namespace WireMock.WebSockets;
/// <summary>
/// Model for WebSocket message transformation
/// </summary>
internal struct WebSocketTransformModel
{
/// <summary>
/// The mapping that matched this WebSocket request
/// </summary>
public IMapping Mapping { get; set; }
/// <summary>
/// The original request that initiated the WebSocket connection
/// </summary>
public IRequestMessage Request { get; set; }
/// <summary>
/// The incoming WebSocket message
/// </summary>
public WebSocketMessage Message { get; set; }
/// <summary>
/// The message data as string
/// </summary>
public string? Data { get; set; }
}
@@ -0,0 +1,178 @@
// Copyright © WireMock.Net
using System.Net.WebSockets;
using System.Text;
using Microsoft.AspNetCore.Http;
using Newtonsoft.Json;
using Stef.Validation;
using WireMock.Extensions;
using WireMock.Owin;
namespace WireMock.WebSockets;
/// <summary>
/// WebSocket context implementation
/// </summary>
public class WireMockWebSocketContext : IWebSocketContext
{
private readonly IWireMockMiddlewareOptions _options;
/// <inheritdoc />
public Guid ConnectionId { get; } = Guid.NewGuid();
/// <inheritdoc />
public HttpContext HttpContext { get; }
/// <inheritdoc />
public WebSocket WebSocket { get; }
/// <inheritdoc />
public IRequestMessage RequestMessage { get; }
/// <inheritdoc />
public IMapping Mapping { get; }
internal WebSocketConnectionRegistry? Registry { get; }
internal WebSocketBuilder Builder { get; }
/// <summary>
/// Creates a new WebSocketContext
/// </summary>
internal WireMockWebSocketContext(
HttpContext httpContext,
WebSocket webSocket,
IRequestMessage requestMessage,
IMapping mapping,
WebSocketConnectionRegistry? registry,
WebSocketBuilder builder)
{
HttpContext = Guard.NotNull(httpContext);
WebSocket = Guard.NotNull(webSocket);
RequestMessage = Guard.NotNull(requestMessage);
Mapping = Guard.NotNull(mapping);
Registry = registry;
Builder = Guard.NotNull(builder);
// Get options from HttpContext
if (httpContext.Items.TryGetValue<IWireMockMiddlewareOptions>(nameof(WireMockMiddlewareOptions), out var options))
{
_options = options;
}
else
{
throw new InvalidOperationException("WireMockMiddlewareOptions not found in HttpContext.Items");
}
}
/// <inheritdoc />
public Task SendAsync(string text, CancellationToken cancellationToken = default)
{
var bytes = Encoding.UTF8.GetBytes(text);
return WebSocket.SendAsync(
new ArraySegment<byte>(bytes),
WebSocketMessageType.Text,
true,
cancellationToken
);
}
/// <inheritdoc />
public Task SendAsync(byte[] bytes, CancellationToken cancellationToken = default)
{
return WebSocket.SendAsync(
new ArraySegment<byte>(bytes),
WebSocketMessageType.Binary,
true,
cancellationToken
);
}
/// <inheritdoc />
public Task CloseAsync(WebSocketCloseStatus closeStatus, string statusDescription)
{
return WebSocket.CloseAsync(closeStatus, statusDescription, CancellationToken.None);
}
/// <inheritdoc />
public void SetScenarioState(string nextState)
{
SetScenarioState(nextState, null);
}
/// <inheritdoc />
public void SetScenarioState(string nextState, string? description)
{
if (Mapping.Scenario == null)
{
return;
}
// Use the same logic as WireMockMiddleware
if (_options.Scenarios.TryGetValue(Mapping.Scenario, out var scenarioState))
{
// Directly set the next state (bypass counter logic for manual WebSocket state changes)
scenarioState.NextState = nextState;
scenarioState.Started = true;
scenarioState.Finished = nextState == null;
// Reset counter when manually setting state
scenarioState.Counter = 0;
}
else
{
// Create new scenario state if it doesn't exist
_options.Scenarios.TryAdd(Mapping.Scenario, new ScenarioState
{
Name = Mapping.Scenario,
NextState = nextState,
Started = true,
Finished = nextState == null,
Counter = 0
});
}
}
/// <summary>
/// Update scenario state following the same pattern as WireMockMiddleware.UpdateScenarioState
/// This is called automatically when the WebSocket connection is established.
/// </summary>
internal void UpdateScenarioState()
{
if (Mapping.Scenario == null)
{
return;
}
// Ensure scenario exists
if (!_options.Scenarios.TryGetValue(Mapping.Scenario, out var scenario))
{
return;
}
// Follow exact same logic as WireMockMiddleware.UpdateScenarioState
// Increase the number of times this state has been executed
scenario.Counter++;
// Only if the number of times this state is executed equals the required StateTimes,
// proceed to next state and reset the counter to 0
if (scenario.Counter == (Mapping.TimesInSameState ?? 1))
{
scenario.NextState = Mapping.NextState;
scenario.Counter = 0;
}
// Else just update Started and Finished
scenario.Started = true;
scenario.Finished = Mapping.NextState == null;
}
/// <inheritdoc />
public async Task BroadcastTextAsync(string text, CancellationToken cancellationToken = default)
{
if (Registry != null)
{
await Registry.BroadcastTextAsync(text, cancellationToken);
}
}
}