// Copyright © WireMock.Net

using System;
using System.Collections.Generic;
using System.Linq;
using Stef.Validation;
using WireMock.Matchers;
using WireMock.Matchers.Request;
using WireMock.Types;

namespace WireMock.WebSockets.Matchers;

/// <summary>
/// Matcher for WebSocket upgrade requests.
/// </summary>
/// <remarks>
/// A request matches only when it carries the WebSocket upgrade headers (an <c>Upgrade</c> header containing the
/// token <c>websocket</c> and a <c>Connection</c> header containing the token <c>Upgrade</c>). When configured,
/// it must also satisfy the path matcher, the acceptable-subprotocol list and the custom predicate, in that order.
/// </remarks>
public class WebSocketRequestMatcher : IRequestMatcher
{
    private const string UpgradeHeaderName = "Upgrade";
    private const string ConnectionHeaderName = "Connection";
    private const string SecWebSocketProtocolHeaderName = "Sec-WebSocket-Protocol";

    private static string Name => nameof(WebSocketRequestMatcher);

    private readonly IStringMatcher? _pathMatcher;
    private readonly IList<string>? _subProtocols;
    private readonly Func<WebSocketConnectRequest, bool>? _customPredicate;

    /// <summary>
    /// Initializes a new instance of the <see cref="WebSocketRequestMatcher"/> class.
    /// </summary>
    /// <param name="pathMatcher">The optional path matcher. When set, the request path must be a perfect match.</param>
    /// <param name="subProtocols">
    /// The optional list of acceptable subprotocols. When non-empty, at least one subprotocol requested by the
    /// client (via <c>Sec-WebSocket-Protocol</c>) must be contained in this list.
    /// </param>
    /// <param name="customPredicate">The optional custom predicate for matching; evaluated after all other checks.</param>
    public WebSocketRequestMatcher(IStringMatcher? pathMatcher = null, IList<string>? subProtocols = null, Func<WebSocketConnectRequest, bool>? customPredicate = null)
    {
        _pathMatcher = pathMatcher;
        _subProtocols = subProtocols;
        _customPredicate = customPredicate;
    }

    /// <inheritdoc />
    public double GetMatchingScore(IRequestMessage requestMessage, IRequestMatchResult requestMatchResult)
    {
        var (score, exception) = GetMatchResult(requestMessage).Expand();
        return requestMatchResult.AddScore(GetType(), score, exception);
    }

    /// <summary>
    /// Evaluates all configured criteria against the request; any failing criterion yields a default (mismatch) result.
    /// </summary>
    private MatchResult GetMatchResult(IRequestMessage requestMessage)
    {
        Guard.NotNull(requestMessage);

        // Anything that is not a WebSocket upgrade request never matches.
        if (!IsWebSocketUpgradeRequest(requestMessage))
        {
            return MatchResult.From(Name);
        }

        var matchScore = MatchScores.Perfect;

        if (_pathMatcher != null)
        {
            var pathMatchResult = _pathMatcher.IsMatch(requestMessage.Path ?? string.Empty);

            // Only a perfect path match is accepted; partial scores are treated as a mismatch.
            if (pathMatchResult.Score < MatchScores.Perfect)
            {
                return MatchResult.From(Name);
            }

            matchScore *= pathMatchResult.Score;
        }

        if (_subProtocols is { Count: > 0 })
        {
            // At least one requested subprotocol must appear in the configured list (compared via the list's Contains).
            var requestSubProtocols = GetRequestedSubProtocols(requestMessage);
            if (!requestSubProtocols.Any(_subProtocols.Contains))
            {
                return MatchResult.From(Name);
            }
        }

        if (_customPredicate != null && !_customPredicate(CreateWebSocketConnectRequest(requestMessage)))
        {
            return MatchResult.From(Name);
        }

        return MatchResult.From(Name, matchScore);
    }

    /// <summary>
    /// Returns true when the request has an <c>Upgrade</c> header with the token <c>websocket</c> and a
    /// <c>Connection</c> header with the token <c>Upgrade</c>, both compared ASCII case-insensitively (RFC 6455 §4.2.1).
    /// </summary>
    private static bool IsWebSocketUpgradeRequest(IRequestMessage request)
    {
        var hasUpgradeHeader = SplitTokens(GetHeaderValues(request, UpgradeHeaderName))
            .Any(token => token.Equals("websocket", StringComparison.OrdinalIgnoreCase));

        // Connection is a comma-separated token list (e.g. "keep-alive, Upgrade"); match whole tokens, not substrings.
        var hasConnectionHeader = SplitTokens(GetHeaderValues(request, ConnectionHeaderName))
            .Any(token => token.Equals("Upgrade", StringComparison.OrdinalIgnoreCase));

        return hasUpgradeHeader && hasConnectionHeader;
    }

    /// <summary>
    /// Returns the non-empty, trimmed subprotocol tokens from all <c>Sec-WebSocket-Protocol</c> header values,
    /// or an empty array when the header is absent.
    /// </summary>
    private static string[] GetRequestedSubProtocols(IRequestMessage request)
    {
        return SplitTokens(GetHeaderValues(request, SecWebSocketProtocolHeaderName)).ToArray();
    }

    /// <summary>
    /// Collects the values of every header whose name equals <paramref name="headerName"/> ignoring case.
    /// </summary>
    /// <remarks>
    /// HTTP field names are case-insensitive (RFC 9110 §5.1), but the comparer of <c>request.Headers</c> is not
    /// visible here, so a plain <c>TryGetValue</c> could miss e.g. a lower-cased header name.
    /// </remarks>
    private static IEnumerable<string> GetHeaderValues(IRequestMessage request, string headerName)
    {
        if (request.Headers == null)
        {
            return Enumerable.Empty<string>();
        }

        return request.Headers
            .Where(header => header.Value != null && string.Equals(header.Key, headerName, StringComparison.OrdinalIgnoreCase))
            .SelectMany(header => header.Value);
    }

    /// <summary>
    /// Splits comma-separated header values into trimmed tokens, dropping empty entries (e.g. from "a,,b").
    /// </summary>
    private static IEnumerable<string> SplitTokens(IEnumerable<string> values)
    {
        return values
            .Where(value => !string.IsNullOrEmpty(value))
            .SelectMany(value => value.Split(','))
            .Select(token => token.Trim())
            .Where(token => token.Length > 0);
    }

    /// <summary>
    /// Builds the <see cref="WebSocketConnectRequest"/> passed to the custom predicate. A missing header collection
    /// becomes an empty dictionary and a missing client IP becomes an empty string.
    /// </summary>
    private static WebSocketConnectRequest CreateWebSocketConnectRequest(IRequestMessage request)
    {
        var headers = request.Headers ?? new Dictionary<string, WireMockList<string>>();
        var subProtocols = GetRequestedSubProtocols(request);
        var clientIP = request.ClientIP ?? string.Empty;

        return new WebSocketConnectRequest
        {
            Path = request.Path,
            Headers = headers,
            SubProtocols = subProtocols,
            RemoteAddress = clientIP
        };
    }
}