Skip to content

Commit 7163d0c

Browse files
committed
Support returning object from invoke and fix lifetime issues
1 parent 918a524 commit 7163d0c

9 files changed

Lines changed: 163 additions & 164 deletions

File tree

src/Foundatio.Mediator.Abstractions/Mediator.cs

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,22 @@
33
using System.Text;
44
using Microsoft.Extensions.DependencyInjection;
55
using Microsoft.Extensions.Logging;
6+
using Microsoft.Extensions.Logging.Abstractions;
67

78
namespace Foundatio.Mediator;
89

910
public class Mediator : IMediator, IServiceProvider
1011
{
1112
private readonly IServiceProvider _serviceProvider;
1213
private readonly MediatorConfiguration _configuration;
14+
private readonly ILogger<Mediator> _logger;
1315

1416
[DebuggerStepThrough]
1517
public Mediator(IServiceProvider serviceProvider, MediatorConfiguration? configuration = null)
1618
{
1719
_serviceProvider = serviceProvider;
1820
_configuration = configuration ?? new MediatorConfiguration();
21+
_logger = _serviceProvider.GetService<ILogger<Mediator>>() ?? NullLogger<Mediator>.Instance;
1922
}
2023

2124
public IServiceProvider ServiceProvider => _serviceProvider;
@@ -36,14 +39,14 @@ public void Invoke(object message, CancellationToken cancellationToken = default
3639
public async ValueTask<TResponse> InvokeAsync<TResponse>(object message, CancellationToken cancellationToken = default)
3740
{
3841
var handlerFunc = GetInvokeAsyncResponseDelegate(message.GetType(), typeof(TResponse));
39-
var result = await handlerFunc(this, message, cancellationToken);
42+
object? result = await handlerFunc(this, message, cancellationToken);
4043
return (TResponse)result!;
4144
}
4245

4346
public TResponse Invoke<TResponse>(object message, CancellationToken cancellationToken = default)
4447
{
4548
var handlerFunc = GetInvokeResponseDelegate(message.GetType(), typeof(TResponse));
46-
var result = handlerFunc(this, message, cancellationToken);
49+
object? result = handlerFunc(this, message, cancellationToken);
4750
return (TResponse)result!;
4851
}
4952

@@ -55,11 +58,10 @@ public ValueTask PublishAsync(object message, CancellationToken cancellationToke
5558

5659
public void ShowRegisteredHandlers()
5760
{
58-
var logger = _serviceProvider.GetRequiredService<ILogger<Mediator>>();
5961
var registrations = _serviceProvider.GetServices<HandlerRegistration>().ToArray();
6062
if (!registrations.Any())
6163
{
62-
logger.LogInformation("No handlers registered.");
64+
_logger.LogInformation("No handlers registered.");
6365
return;
6466
}
6567

@@ -70,7 +72,7 @@ public void ShowRegisteredHandlers()
7072
sb.AppendLine($"- Message: {registration.MessageTypeName}, Handler: {registration.HandlerClassName}, IsAsync: {registration.IsAsync}");
7173
}
7274

73-
logger.LogInformation(sb.ToString());
75+
_logger.LogInformation(sb.ToString());
7476
}
7577

7678
[DebuggerStepThrough]
@@ -225,13 +227,13 @@ private IEnumerable<HandlerRegistration> GetHandlersForType(Type type)
225227
public static T GetOrCreateMiddleware<T>(IServiceProvider serviceProvider) where T : class
226228
{
227229
// Check cache first - if it's there, it means it's not registered in DI
228-
if (_middlewareCache.TryGetValue(typeof(T), out var cachedInstance))
230+
if (_middlewareCache.TryGetValue(typeof(T), out object? cachedInstance))
229231
return (T)cachedInstance;
230232

231233
// Try to get from DI - if registered, always use DI (respects service lifetime)
232-
var middlewareFromDI = serviceProvider.GetService<T>();
233-
if (middlewareFromDI != null)
234-
return middlewareFromDI;
234+
var middleware = serviceProvider.GetService<T>();
235+
if (middleware != null)
236+
return middleware;
235237

236238
// Not in DI, create and cache our own instance
237239
return (T)_middlewareCache.GetOrAdd(typeof(T), type =>
@@ -263,9 +265,9 @@ public static T GetOrCreateMiddleware<T>(IServiceProvider serviceProvider) where
263265
if (asyncMethod == null)
264266
return null;
265267

266-
HandleAsyncDelegate asyncDelegate = (IMediator mediator, object message, CancellationToken ct, Type? returnType) =>
268+
HandleAsyncDelegate asyncDelegate = (mediator, message, ct, returnType) =>
267269
{
268-
var taskObj = asyncMethod.Invoke(null, new object?[] { mediator, message, ct, returnType });
270+
object? taskObj = asyncMethod.Invoke(null, [mediator, message, ct, returnType]);
269271
return taskObj is ValueTask<object?> vt ? vt : (ValueTask<object?>)taskObj!;
270272
};
271273

@@ -275,10 +277,7 @@ public static T GetOrCreateMiddleware<T>(IServiceProvider serviceProvider) where
275277
var syncMethod = wrapperClosed.GetMethod("UntypedHandle", System.Reflection.BindingFlags.Public | System.Reflection.BindingFlags.Static);
276278
if (syncMethod != null)
277279
{
278-
syncDelegate = (IMediator mediator, object message, CancellationToken ct, Type? returnType) =>
279-
{
280-
return syncMethod.Invoke(null, new object?[] { mediator, message, ct, returnType });
281-
};
280+
syncDelegate = (mediator, message, ct, returnType) => syncMethod.Invoke(null, [mediator, message, ct, returnType]);
282281
}
283282
}
284283

src/Foundatio.Mediator/DIRegistrationGenerator.cs

Lines changed: 34 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ public static void Execute(SourceProductionContext context, List<HandlerInfo> ha
1616
source.AddGeneratedFileHeader();
1717

1818
source.AppendLine("using Microsoft.Extensions.DependencyInjection;");
19+
source.AppendLine("using Microsoft.Extensions.DependencyInjection.Extensions;");
1920
source.AppendLine("using System;");
2021
source.AppendLine("using System.Diagnostics;");
2122
source.AppendLine("using System.Diagnostics.CodeAnalysis;");
@@ -33,54 +34,59 @@ public static void Execute(SourceProductionContext context, List<HandlerInfo> ha
3334
source.AppendLine(" public static void AddHandlers(this IServiceCollection services)");
3435
source.AppendLine(" {");
3536
source.AppendLine(" // Register HandlerRegistration instances keyed by message type name");
36-
source.AppendLine(" // Optionally register handler classes into DI based on MediatorHandlerLifetime setting");
3737
source.AppendLine();
3838
source.IncrementIndent().IncrementIndent();
3939

40-
bool registerHandlers = !string.Equals(handlerLifetime, "None", StringComparison.OrdinalIgnoreCase);
40+
string lifetimeMethod;
41+
if (String.Equals(handlerLifetime, "Transient", StringComparison.OrdinalIgnoreCase))
42+
lifetimeMethod = "TryAddTransient";
43+
else if (String.Equals(handlerLifetime, "Scoped", StringComparison.OrdinalIgnoreCase))
44+
lifetimeMethod = "TryAddScoped";
45+
else
46+
lifetimeMethod = "TryAddSingleton";
4147

4248
foreach (var handler in handlers)
4349
{
4450
string handlerClassName = HandlerGenerator.GetHandlerClassName(handler);
4551

4652
// Register handler in DI for non-static handler classes when lifetime != Singleton
47-
if (registerHandlers && !handler.IsStatic)
53+
if (handler is { IsStatic: false, IsGenericHandlerClass: false })
4854
{
49-
string lifetimeMethod = "";
50-
if (String.Equals(handlerLifetime, "Transient", StringComparison.OrdinalIgnoreCase))
51-
lifetimeMethod = "AddTransient";
52-
if (String.Equals(handlerLifetime, "Scoped", StringComparison.OrdinalIgnoreCase))
53-
lifetimeMethod = "AddScoped";
54-
if (String.Equals(handlerLifetime, "Singleton", StringComparison.OrdinalIgnoreCase))
55-
lifetimeMethod = "AddSingleton";
56-
57-
if (!String.IsNullOrEmpty(lifetimeMethod))
58-
source.AppendLine($"services.{lifetimeMethod}<{handler.FullName}>();");
55+
source.AppendLine($"services.{lifetimeMethod}<{handler.FullName}>();");
5956
}
6057

6158
if (handler.IsGenericHandlerClass)
6259
{
63-
// open generic registration
6460
if (handler is not { MessageGenericTypeDefinitionFullName: not null, GenericArity: > 0 })
6561
continue;
6662

67-
// Build unbound generic typeof expressions
68-
string wrapperTypeOf = handler.GenericArity switch
63+
string genericArity = handler.GenericArity switch
6964
{
70-
1 => $"typeof({handlerClassName}<>)",
71-
2 => $"typeof({handlerClassName}<,>)",
72-
3 => $"typeof({handlerClassName}<,,>)",
73-
4 => $"typeof({handlerClassName}<,,,>)",
74-
5 => $"typeof({handlerClassName}<,,,,>)",
75-
6 => $"typeof({handlerClassName}<,,,,,>)",
76-
7 => $"typeof({handlerClassName}<,,,,,,>)",
77-
8 => $"typeof({handlerClassName}<,,,,,,,>)",
78-
9 => $"typeof({handlerClassName}<,,,,,,,,>)",
79-
10 => $"typeof({handlerClassName}<,,,,,,,,,>)",
80-
_ => $"typeof({handlerClassName}<>)" // fallback
65+
1 => "<>",
66+
2 => "<,>",
67+
3 => "<,,>",
68+
4 => "<,,,>",
69+
5 => "<,,,,>",
70+
6 => "<,,,,,>",
71+
7 => "<,,,,,,>",
72+
8 => "<,,,,,,,>",
73+
9 => "<,,,,,,,,>",
74+
10 => "<,,,,,,,,,>",
75+
_ => "<>)" // fallback
8176
};
77+
78+
string wrapperTypeOf = $"typeof({handlerClassName}{genericArity})";
8279
string msgTypeOf = $"typeof({handler.MessageGenericTypeDefinitionFullName})";
83-
source.AppendLine($"// Open generic handler registration for {handler.MessageGenericTypeDefinitionFullName}");
80+
if (!handler.IsStatic)
81+
{
82+
string handlerFullName = handler.FullName;
83+
int index = handlerFullName.IndexOf('<');
84+
if (index > 0)
85+
handlerFullName = handlerFullName.Substring(0, index);
86+
source.AppendLine($"services.{lifetimeMethod}(typeof({handlerFullName}{genericArity}));");
87+
88+
}
89+
8490
source.AppendLine($"services.AddSingleton(new OpenGenericHandlerDescriptor({msgTypeOf}, {wrapperTypeOf}, {handler.IsAsync.ToString().ToLower()}));");
8591
}
8692
else

src/Foundatio.Mediator/HandlerAnalyzer.cs

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -131,12 +131,12 @@ public static List<HandlerInfo> GetHandlers(GeneratorSyntaxContext context)
131131
if (classSymbol.IsGenericType)
132132
{
133133
genericParamNames = classSymbol.TypeParameters.Select(tp => tp.Name).ToArray();
134-
genericConstraints = classSymbol.TypeParameters.Select(tp => BuildConstraintClause(tp)).Where(s => s.Length > 0).ToArray();
134+
genericConstraints = classSymbol.TypeParameters.Select(BuildConstraintClause).Where(s => s.Length > 0).ToArray();
135135
}
136136
else
137137
{
138-
genericParamNames = Array.Empty<string>();
139-
genericConstraints = Array.Empty<string>();
138+
genericParamNames = [];
139+
genericConstraints = [];
140140
}
141141

142142
handlers.Add(new HandlerInfo
@@ -198,7 +198,7 @@ private static IEnumerable<IMethodSymbol> GetMethods(INamedTypeSymbol targetSymb
198198

199199
foreach (var methodSymbol in methodSymbols)
200200
{
201-
var signature = BuildMethodSignature(methodSymbol);
201+
string signature = BuildMethodSignature(methodSymbol);
202202

203203
if (!methods.ContainsKey(signature))
204204
methods.Add(signature, methodSymbol);
@@ -218,11 +218,11 @@ private static string BuildMethodSignature(IMethodSymbol method)
218218
if (method.Parameters.Length == 0)
219219
return method.Name + "()";
220220

221-
var parts = new string[method.Parameters.Length];
221+
string[] parts = new string[method.Parameters.Length];
222222
for (int i = 0; i < method.Parameters.Length; i++)
223223
parts[i] = method.Parameters[i].Type.ToDisplayString();
224224

225-
return method.Name + "(" + string.Join(",", parts) + ")";
225+
return method.Name + "(" + String.Join(",", parts) + ")";
226226
}
227227

228228
private static readonly string[] ValidHandlerMethodNames = [
@@ -234,7 +234,6 @@ private static string BuildMethodSignature(IMethodSymbol method)
234234

235235
private static string BuildConstraintClause(ITypeParameterSymbol tp)
236236
{
237-
// Order: class/struct/unmanaged first, then specific constraint types, then notnull, then new()
238237
var ordered = new List<string>();
239238

240239
if (tp.HasReferenceTypeConstraint)
@@ -246,8 +245,7 @@ private static string BuildConstraintClause(ITypeParameterSymbol tp)
246245

247246
foreach (var c in tp.ConstraintTypes)
248247
{
249-
var display = c.ToDisplayString();
250-
// Avoid duplicating if already implicitly covered (rare, but defensive)
248+
string display = c.ToDisplayString();
251249
if (!ordered.Contains(display))
252250
ordered.Add(display);
253251
}
@@ -258,8 +256,8 @@ private static string BuildConstraintClause(ITypeParameterSymbol tp)
258256
ordered.Add("new()");
259257

260258
if (ordered.Count == 0)
261-
return string.Empty;
259+
return String.Empty;
262260

263-
return $"where {tp.Name} : {string.Join(", ", ordered)}";
261+
return $"where {tp.Name} : {String.Join(", ", ordered)}";
264262
}
265263
}

src/Foundatio.Mediator/HandlerGenerator.cs

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -569,27 +569,10 @@ private static void GenerateGetOrCreateHandler(IndentedStringBuilder source, Han
569569
{
570570
source.AppendLine()
571571
.AppendLines($$"""
572-
private static {{handler.FullName}}? _handler;
573-
private static readonly global::System.Threading.Lock _lock = new();
574-
575572
[DebuggerStepThrough]
576573
private static {{handler.FullName}} GetOrCreateHandler(IServiceProvider serviceProvider)
577574
{
578-
if (_handler != null)
579-
return _handler;
580-
581-
var handlerFromDI = serviceProvider.GetService<{{handler.FullName}}>();
582-
if (handlerFromDI != null)
583-
return handlerFromDI;
584-
585-
lock (_lock)
586-
{
587-
if (_handler != null)
588-
return _handler;
589-
590-
_handler = ActivatorUtilities.CreateInstance<{{handler.FullName}}>(serviceProvider);
591-
return _handler;
592-
}
575+
return serviceProvider.GetRequiredService<{{handler.FullName}}>();
593576
}
594577
""");
595578
}

src/Foundatio.Mediator/Utility/Helpers.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,17 +37,17 @@ private static string GetToolVersion()
3737

3838
public static string ToIdentifier(this string name)
3939
{
40-
if (string.IsNullOrEmpty(name))
40+
if (String.IsNullOrEmpty(name))
4141
return String.Empty;
4242

43-
return new String(name.Select(c => char.IsLetterOrDigit(c) || c == '_' ? c : '_').ToArray());
43+
return new string(name.Select(c => Char.IsLetterOrDigit(c) || c == '_' ? c : '_').ToArray());
4444
}
4545

4646
public static string ToCamelCase(this string name)
4747
{
48-
if (string.IsNullOrEmpty(name))
48+
if (String.IsNullOrEmpty(name))
4949
return String.Empty;
5050

51-
return char.ToLower(name[0]) + name.Substring(1);
51+
return Char.ToLower(name[0]) + name.Substring(1);
5252
}
5353
}

tests/Foundatio.Mediator.Tests/BasicHandlerGenerationTests.GeneratesWrapperForSimpleHandler.verified.txt

Lines changed: 3 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -110,27 +110,10 @@ internal static class PingHandler_Ping_Handler
110110
return HandlerScope.GetOrCreate(mediator, cancellationToken);
111111
}
112112

113-
private static PingHandler? _handler;
114-
private static readonly global::System.Threading.Lock _lock = new();
115-
116113
[DebuggerStepThrough]
117114
private static PingHandler GetOrCreateHandler(IServiceProvider serviceProvider)
118115
{
119-
if (_handler != null)
120-
return _handler;
121-
122-
var handlerFromDI = serviceProvider.GetService<PingHandler>();
123-
if (handlerFromDI != null)
124-
return handlerFromDI;
125-
126-
lock (_lock)
127-
{
128-
if (_handler != null)
129-
return _handler;
130-
131-
_handler = ActivatorUtilities.CreateInstance<PingHandler>(serviceProvider);
132-
return _handler;
133-
}
116+
return serviceProvider.GetRequiredService<PingHandler>();
134117
}
135118
}
136119

@@ -145,6 +128,7 @@ internal static class PingHandler_Ping_Handler
145128
#nullable enable
146129

147130
using Microsoft.Extensions.DependencyInjection;
131+
using Microsoft.Extensions.DependencyInjection.Extensions;
148132
using System;
149133
using System.Diagnostics;
150134
using System.Diagnostics.CodeAnalysis;
@@ -162,8 +146,8 @@ public static class Tests_MediatorHandlers
162146
public static void AddHandlers(this IServiceCollection services)
163147
{
164148
// Register HandlerRegistration instances keyed by message type name
165-
// Optionally register handler classes into DI based on MediatorHandlerLifetime setting
166149

150+
services.TryAddSingleton<PingHandler>();
167151
services.AddHandler(new HandlerRegistration(
168152
MessageTypeKey.Get(typeof(Ping)),
169153
"PingHandler_Ping_Handler",

tests/Foundatio.Mediator.Tests/GenericResultTests.cs

Whitespace-only changes.

0 commit comments

Comments
 (0)