TypeLoader: implement Try methods (#1358)

* TypeLoader: implement Try methods

* fix
This commit is contained in:
Stef Heyenrath
2025-08-31 08:48:29 +02:00
committed by GitHub
parent 5c5e104f2c
commit 371bfdc160
10 changed files with 190 additions and 81 deletions

View File

@@ -1,7 +1,6 @@
// Copyright © WireMock.Net
using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Linq;
@@ -33,16 +32,25 @@ internal class MimeKitUtils : IMimeKitUtils
StartsWithMultiPart(contentTypeHeader)
)
{
var bytes = requestMessage.BodyData?.DetectedBodyType switch
byte[] bytes;
switch (requestMessage.BodyData?.DetectedBodyType)
{
// If the body is bytes, use the BodyAsBytes to match on.
BodyType.Bytes => requestMessage.BodyData.BodyAsBytes!,
case BodyType.Bytes:
bytes = requestMessage.BodyData.BodyAsBytes!;
break;
// If the body is a String or MultiPart, use the BodyAsString to match on.
BodyType.String or BodyType.MultiPart => Encoding.UTF8.GetBytes(requestMessage.BodyData.BodyAsString!),
case BodyType.String or BodyType.MultiPart:
bytes = Encoding.UTF8.GetBytes(requestMessage.BodyData.BodyAsString!);
break;
_ => throw new NotSupportedException()
};
// Else return false.
default:
mimeMessageData = null;
return false;
}
var fixedBytes = FixBytes(bytes, contentTypeHeader[0]);

View File

@@ -12,7 +12,7 @@ namespace WireMock.Matchers.Request;
/// </summary>
public class RequestMessageMultiPartMatcher : IRequestMatcher
{
private static readonly IMimeKitUtils MimeKitUtils = TypeLoader.LoadStaticInstance<IMimeKitUtils>();
private readonly IMimeKitUtils _mimeKitUtils = LoadMimeKitUtils();
/// <summary>
/// The matchers.
@@ -62,7 +62,7 @@ public class RequestMessageMultiPartMatcher : IRequestMatcher
return requestMatchResult.AddScore(GetType(), score, null);
}
if (!MimeKitUtils.TryGetMimeMessage(requestMessage, out var message))
if (!_mimeKitUtils.TryGetMimeMessage(requestMessage, out var message))
{
return requestMatchResult.AddScore(GetType(), score, null);
}
@@ -96,4 +96,14 @@ public class RequestMessageMultiPartMatcher : IRequestMatcher
return requestMatchResult.AddScore(GetType(), score, exception);
}
private static IMimeKitUtils LoadMimeKitUtils()
{
if (TypeLoader.TryLoadStaticInstance<IMimeKitUtils>(out var mimeKitUtils))
{
return mimeKitUtils;
}
throw new InvalidOperationException("MimeKit is required for RequestMessageMultiPartMatcher. Please install the WireMock.Net.MimePart package.");
}
}

View File

@@ -178,9 +178,12 @@ namespace WireMock.Owin.Mappers
return (bodyData.Encoding ?? _utf8NoBom).GetBytes(jsonBody);
case BodyType.ProtoBuf:
var protoDefinitions = bodyData.ProtoDefinition?.Invoke().Texts;
var protoBufUtils = TypeLoader.LoadStaticInstance<IProtoBufUtils>();
return await protoBufUtils.GetProtoBufMessageWithHeaderAsync(protoDefinitions, bodyData.ProtoBufMessageType, bodyData.BodyAsJson).ConfigureAwait(false);
if (TypeLoader.TryLoadStaticInstance<IProtoBufUtils>(out var protoBufUtils))
{
var protoDefinitions = bodyData.ProtoDefinition?.Invoke().Texts;
return await protoBufUtils.GetProtoBufMessageWithHeaderAsync(protoDefinitions, bodyData.ProtoBufMessageType, bodyData.BodyAsJson).ConfigureAwait(false);
}
break;
case BodyType.Bytes:
return bodyData.BodyAsBytes;

View File

@@ -183,16 +183,9 @@ public class RequestMessage : IRequestMessage
#endif
#if MIMEKIT
try
if (TypeLoader.TryLoadStaticInstance<IMimeKitUtils>(out var mimeKitUtils) && mimeKitUtils.TryGetMimeMessage(this, out var mimeMessage))
{
if (TypeLoader.LoadStaticInstance<IMimeKitUtils>().TryGetMimeMessage(this, out var mimeMessage))
{
BodyAsMimeMessage = mimeMessage;
}
}
catch
{
// Ignore exception from MimeMessage.Load
BodyAsMimeMessage = mimeMessage;
}
#endif
}

View File

@@ -55,7 +55,12 @@ internal class MatcherMapper
case "CSharpCodeMatcher":
if (_settings.AllowCSharpCodeMatcher == true)
{
return TypeLoader.LoadNewInstance<ICSharpCodeMatcher>(matchBehaviour, matchOperator, stringPatterns);
if (TypeLoader.TryLoadNewInstance<ICSharpCodeMatcher>(out var csharpCodeMatcher, matchBehaviour, matchOperator, stringPatterns))
{
return csharpCodeMatcher;
}
throw new InvalidOperationException("The 'CSharpCodeMatcher' cannot be loaded. Please install the WireMock.Net.Matchers.CSharpCode package.");
}
throw new NotSupportedException("It's not allowed to use the 'CSharpCodeMatcher' because WireMockServerSettings.AllowCSharpCodeMatcher is not set to 'true'.");
@@ -75,7 +80,12 @@ internal class MatcherMapper
case "GraphQLMatcher":
var patternAsString = stringPatterns[0].GetPattern();
var schema = new AnyOf<string, StringPattern, ISchemaData>(patternAsString);
return TypeLoader.LoadNewInstance<IGraphQLMatcher>(schema, matcherModel.CustomScalars, matchBehaviour, matchOperator);
if (TypeLoader.TryLoadNewInstance<IGraphQLMatcher>(out var graphQLMatcher, schema, matcherModel.CustomScalars, matchBehaviour, matchOperator))
{
return graphQLMatcher;
}
throw new InvalidOperationException("The 'GraphQLMatcher' cannot be loaded. Please install the WireMock.Net.GraphQL package.");
case "MimePartMatcher":
return CreateMimePartMatcher(matchBehaviour, matcherModel);
@@ -282,18 +292,34 @@ internal class MatcherMapper
var contentTransferEncodingMatcher = Map(matcher.ContentTransferEncodingMatcher) as IStringMatcher;
var contentMatcher = Map(matcher.ContentMatcher);
return TypeLoader.LoadNewInstance<IMimePartMatcher>(matchBehaviour, contentTypeMatcher, contentDispositionMatcher, contentTransferEncodingMatcher, contentMatcher);
if (TypeLoader.TryLoadNewInstance<IMimePartMatcher>(
out var mimePartMatcher,
matchBehaviour,
contentTypeMatcher,
contentDispositionMatcher,
contentTransferEncodingMatcher,
contentMatcher))
{
return mimePartMatcher;
}
throw new InvalidOperationException("The 'MimePartMatcher' cannot be loaded. Please install the WireMock.Net.MimePart package.");
}
private IProtoBufMatcher CreateProtoBufMatcher(MatchBehaviour? matchBehaviour, IReadOnlyList<string> protoDefinitions, MatcherModel matcher)
{
var objectMatcher = Map(matcher.ContentMatcher) as IObjectMatcher;
return TypeLoader.LoadNewInstance<IProtoBufMatcher>(
if (TypeLoader.TryLoadNewInstance<IProtoBufMatcher>(
out var protobufMatcher,
() => ProtoDefinitionUtils.GetIdOrTexts(_settings, protoDefinitions.ToArray()),
matcher.ProtoBufMessageType!,
matchBehaviour ?? MatchBehaviour.AcceptOnMatch,
objectMatcher
);
objectMatcher))
{
return protobufMatcher;
}
throw new InvalidOperationException("The 'ProtoBufMatcher' cannot be loaded. Please install the WireMock.Net.ProtoBuf package.");
}
}

View File

@@ -366,10 +366,8 @@ public partial class WireMockServer
}
else if (responseModel.BodyAsJson != null)
{
if (responseModel.ProtoBufMessageType != null)
if (responseModel.ProtoBufMessageType != null && TypeLoader.TryLoadStaticInstance<IProtoBufUtils>(out var protoBufUtils))
{
var protoBufUtils = TypeLoader.LoadStaticInstance<IProtoBufUtils>();
if (responseModel.ProtoDefinition != null)
{
responseBuilder = protoBufUtils.UpdateResponseBuilder(responseBuilder, responseModel.ProtoBufMessageType, responseModel.BodyAsJson, responseModel.ProtoDefinition);

View File

@@ -100,7 +100,10 @@ public class RequestMessageGraphQLMatcher : IRequestMatcher
IDictionary<string, Type>? customScalars
)
{
var graphQLMatcher = TypeLoader.LoadNewInstance<IGraphQLMatcher>(schema, customScalars, matchBehaviour, MatchOperator.Or);
return [graphQLMatcher];
if (TypeLoader.TryLoadNewInstance<IGraphQLMatcher>(out var graphQLMatcher, schema, customScalars, matchBehaviour, MatchOperator.Or))
{
return [graphQLMatcher];
}
return [];
}
}

View File

@@ -25,7 +25,10 @@ public class RequestMessageProtoBufMatcher : IRequestMatcher
/// <param name="matcher">The optional matcher to use to match the ProtoBuf as (json) object.</param>
public RequestMessageProtoBufMatcher(MatchBehaviour matchBehaviour, Func<IdOrTexts> protoDefinition, string messageType, IObjectMatcher? matcher = null)
{
Matcher = TypeLoader.LoadNewInstance<IProtoBufMatcher>(protoDefinition, messageType, matchBehaviour, matcher);
if (TypeLoader.TryLoadNewInstance<IProtoBufMatcher>(out var protoBufMatcher, protoDefinition, messageType, matchBehaviour, matcher))
{
Matcher = protoBufMatcher;
}
}
/// <inheritdoc />

View File

@@ -14,68 +14,130 @@ internal static class TypeLoader
{
private static readonly ConcurrentDictionary<string, Type> Assemblies = new();
private static readonly ConcurrentDictionary<Type, object> Instances = new();
private static readonly ConcurrentBag<(string FullName, Type Type)> InstancesWhichCannotBeFoundByFullName = [];
private static readonly ConcurrentBag<(string FullName, Type Type)> StaticInstancesWhichCannotBeFoundByFullName = [];
private static readonly ConcurrentBag<Type> InstancesWhichCannotBeFound = [];
private static readonly ConcurrentBag<Type> StaticInstancesWhichCannotBeFound = [];
public static TInterface LoadNewInstance<TInterface>(params object?[] args) where TInterface : class
public static bool TryLoadNewInstance<TInterface>([NotNullWhen(true)] out TInterface? instance, params object?[] args) where TInterface : class
{
var pluginType = GetPluginType<TInterface>();
var type = typeof(TInterface);
if (InstancesWhichCannotBeFound.Contains(type))
{
instance = null;
return false;
}
return (TInterface)Activator.CreateInstance(pluginType, args)!;
if (TryGetPluginType<TInterface>(out var pluginType))
{
instance = (TInterface)Activator.CreateInstance(pluginType, args)!;
return true;
}
InstancesWhichCannotBeFound.Add(type);
instance = null;
return false;
}
public static TInterface LoadStaticInstance<TInterface>(params object?[] args) where TInterface : class
public static bool TryLoadStaticInstance<TInterface>([NotNullWhen(true)] out TInterface? staticInstance, params object?[] args) where TInterface : class
{
var pluginType = GetPluginType<TInterface>();
var type = typeof(TInterface);
if (StaticInstancesWhichCannotBeFound.Contains(type))
{
staticInstance = null;
return false;
}
return (TInterface)Instances.GetOrAdd(pluginType, key => Activator.CreateInstance(key, args)!);
if (TryGetPluginType<TInterface>(out var pluginType))
{
staticInstance = (TInterface)Instances.GetOrAdd(pluginType, key => Activator.CreateInstance(key, args)!);
return true;
}
StaticInstancesWhichCannotBeFound.Add(type);
staticInstance = null;
return false;
}
public static TInterface LoadNewInstanceByFullName<TInterface>(string implementationTypeFullName, params object?[] args) where TInterface : class
public static bool TryLoadNewInstanceByFullName<TInterface>([NotNullWhen(true)] out TInterface? instance, string implementationTypeFullName, params object?[] args) where TInterface : class
{
Guard.NotNullOrEmpty(implementationTypeFullName);
var pluginType = GetPluginTypeByFullName<TInterface>(implementationTypeFullName);
var type = typeof(TInterface);
if (InstancesWhichCannotBeFoundByFullName.Contains((implementationTypeFullName, type)))
{
instance = null;
return false;
}
return (TInterface)Activator.CreateInstance(pluginType, args)!;
if (TryGetPluginTypeByFullName<TInterface>(implementationTypeFullName, out var pluginType))
{
instance = (TInterface)Activator.CreateInstance(pluginType, args)!;
return true;
}
InstancesWhichCannotBeFoundByFullName.Add((implementationTypeFullName, type));
instance = null;
return false;
}
public static TInterface LoadStaticInstanceByFullName<TInterface>(string implementationTypeFullName, params object?[] args) where TInterface : class
public static bool TryLoadStaticInstanceByFullName<TInterface>([NotNullWhen(true)] out TInterface? staticInstance, string implementationTypeFullName, params object?[] args) where TInterface : class
{
Guard.NotNullOrEmpty(implementationTypeFullName);
var pluginType = GetPluginTypeByFullName<TInterface>(implementationTypeFullName);
var type = typeof(TInterface);
if (StaticInstancesWhichCannotBeFoundByFullName.Contains((implementationTypeFullName, type)))
{
staticInstance = null;
return false;
}
return (TInterface)Instances.GetOrAdd(pluginType, key => Activator.CreateInstance(key, args)!);
if (TryGetPluginTypeByFullName<TInterface>(implementationTypeFullName, out var pluginType))
{
staticInstance = (TInterface)Instances.GetOrAdd(pluginType, key => Activator.CreateInstance(key, args)!);
return true;
}
StaticInstancesWhichCannotBeFoundByFullName.Add((implementationTypeFullName, type));
staticInstance = null;
return false;
}
private static Type GetPluginType<TInterface>() where TInterface : class
private static bool TryGetPluginType<TInterface>([NotNullWhen(true)] out Type? foundType) where TInterface : class
{
var key = typeof(TInterface).FullName!;
return Assemblies.GetOrAdd(key, _ =>
if (Assemblies.TryGetValue(key, out foundType))
{
if (TryFindTypeInDlls<TInterface>(null, out var foundType))
{
return foundType;
}
return true;
}
throw new DllNotFoundException($"No dll found which implements interface '{key}'.");
});
if (TryFindTypeInDlls<TInterface>(null, out foundType))
{
Assemblies.TryAdd(key, foundType);
return true;
}
return false;
}
private static Type GetPluginTypeByFullName<TInterface>(string implementationTypeFullName) where TInterface : class
private static bool TryGetPluginTypeByFullName<TInterface>(string implementationTypeFullName, [NotNullWhen(true)] out Type? foundType) where TInterface : class
{
var @interface = typeof(TInterface).FullName;
var key = $"{@interface}_{implementationTypeFullName}";
return Assemblies.GetOrAdd(key, _ =>
if (Assemblies.TryGetValue(key, out foundType))
{
if (TryFindTypeInDlls<TInterface>(implementationTypeFullName, out var foundType))
{
return foundType;
}
return true;
}
throw new DllNotFoundException($"No dll found which implements Interface '{@interface}' and has FullName '{implementationTypeFullName}'.");
});
if (TryFindTypeInDlls<TInterface>(implementationTypeFullName, out foundType))
{
Assemblies.TryAdd(key, foundType);
return true;
}
return false;
}
private static bool TryFindTypeInDlls<TInterface>(string? implementationTypeFullName, [NotNullWhen(true)] out Type? pluginType) where TInterface : class

View File

@@ -1,6 +1,5 @@
// Copyright © WireMock.Net
using System;
using System.IO;
using AnyOfTypes;
using FluentAssertions;
@@ -56,7 +55,7 @@ public class TypeLoaderTests
}
[Fact]
public void LoadNewInstance()
public void TryLoadNewInstance()
{
var current = Directory.GetCurrentDirectory();
try
@@ -65,10 +64,11 @@ public class TypeLoaderTests
// Act
AnyOf<string, StringPattern> pattern = "x";
var result = TypeLoader.LoadNewInstance<ICSharpCodeMatcher>(MatchBehaviour.AcceptOnMatch, MatchOperator.Or, pattern);
var result = TypeLoader.TryLoadNewInstance<ICSharpCodeMatcher>(out var instance, MatchBehaviour.AcceptOnMatch, MatchOperator.Or, pattern);
// Assert
result.Should().NotBeNull();
result.Should().BeTrue();
instance.Should().BeOfType<CSharpCodeMatcher>();
}
finally
{
@@ -77,63 +77,66 @@ public class TypeLoaderTests
}
[Fact]
public void LoadNewInstanceByFullName()
public void TryLoadNewInstanceByFullName()
{
// Act
var result = TypeLoader.LoadNewInstanceByFullName<IDummyInterfaceWithImplementation>(typeof(DummyClass).FullName!);
var result = TypeLoader.TryLoadNewInstanceByFullName<IDummyInterfaceWithImplementation>(out var instance, typeof(DummyClass).FullName!);
// Assert
result.Should().BeOfType<DummyClass>();
result.Should().BeTrue();
instance.Should().BeOfType<DummyClass>();
}
[Fact]
public void LoadStaticInstance_ShouldOnlyCreateInstanceOnce()
public void TryLoadStaticInstance_ShouldOnlyCreateInstanceOnce()
{
// Arrange
var counter = new Counter();
// Act
var result = TypeLoader.LoadStaticInstance<IDummyInterfaceWithImplementationUsedForStaticTest>(counter);
TypeLoader.LoadStaticInstance<IDummyInterfaceWithImplementationUsedForStaticTest>(counter);
var result = TypeLoader.TryLoadStaticInstance<IDummyInterfaceWithImplementationUsedForStaticTest>(out var staticInstance, counter);
TypeLoader.TryLoadStaticInstance(out staticInstance, counter);
// Assert
result.Should().BeOfType<DummyClass1UsedForStaticTest>();
result.Should().BeTrue();
staticInstance.Should().BeOfType<DummyClass1UsedForStaticTest>();
counter.Value.Should().Be(1);
}
[Fact]
public void LoadStaticInstanceByFullName_ShouldOnlyCreateInstanceOnce()
public void TryLoadStaticInstanceByFullName_ShouldOnlyCreateInstanceOnce()
{
// Arrange
var counter = new Counter();
var fullName = typeof(DummyClass2UsedForStaticTest).FullName!;
// Act
var result = TypeLoader.LoadStaticInstanceByFullName<IDummyInterfaceWithImplementationUsedForStaticTest>(fullName, counter);
TypeLoader.LoadStaticInstanceByFullName<IDummyInterfaceWithImplementationUsedForStaticTest>(fullName, counter);
var result = TypeLoader.TryLoadStaticInstanceByFullName<IDummyInterfaceWithImplementationUsedForStaticTest>(out var staticInstance, fullName, counter);
TypeLoader.TryLoadStaticInstanceByFullName(out staticInstance, fullName, counter);
// Assert
result.Should().BeOfType<DummyClass2UsedForStaticTest>();
result.Should().BeTrue();
staticInstance.Should().BeOfType<DummyClass2UsedForStaticTest>();
counter.Value.Should().Be(1);
}
[Fact]
public void LoadNewInstance_ButNoImplementationFoundForInterface_ThrowsException()
public void TryLoadNewInstance_ButNoImplementationFoundForInterface_ReturnsFalse()
{
// Act
Action a = () => TypeLoader.LoadNewInstance<IDummyInterfaceNoImplementation>();
var result = TypeLoader.TryLoadNewInstance<IDummyInterfaceNoImplementation>(out _);
// Assert
a.Should().Throw<DllNotFoundException>().WithMessage("No dll found which implements Interface 'WireMock.Net.Tests.Util.TypeLoaderTests+IDummyInterfaceNoImplementation'.");
result.Should().BeFalse();
}
[Fact]
public void LoadNewInstanceByFullName_ButNoImplementationFoundForInterface_ThrowsException()
public void TryLoadNewInstanceByFullName_ButNoImplementationFoundForInterface_ReturnsFalse()
{
// Act
Action a = () => TypeLoader.LoadNewInstanceByFullName<IDummyInterfaceWithImplementation>("xyz");
var result = TypeLoader.TryLoadNewInstanceByFullName<IDummyInterfaceWithImplementation>(out _, "xyz");
// Assert
a.Should().Throw<DllNotFoundException>().WithMessage("No dll found which implements Interface 'WireMock.Net.Tests.Util.TypeLoaderTests+IDummyInterfaceWithImplementation' and has FullName 'xyz'.");
result.Should().BeFalse();
}
}