@@ -16,6 +16,7 @@ public static void Execute(SourceProductionContext context, List<HandlerInfo> ha
1616 source . AddGeneratedFileHeader ( ) ;
1717
1818 source . AppendLine ( "using Microsoft.Extensions.DependencyInjection;" ) ;
19+ source . AppendLine ( "using Microsoft.Extensions.DependencyInjection.Extensions;" ) ;
1920 source . AppendLine ( "using System;" ) ;
2021 source . AppendLine ( "using System.Diagnostics;" ) ;
2122 source . AppendLine ( "using System.Diagnostics.CodeAnalysis;" ) ;
@@ -33,54 +34,59 @@ public static void Execute(SourceProductionContext context, List<HandlerInfo> ha
3334 source . AppendLine ( " public static void AddHandlers(this IServiceCollection services)" ) ;
3435 source . AppendLine ( " {" ) ;
3536 source . AppendLine ( " // Register HandlerRegistration instances keyed by message type name" ) ;
36- source . AppendLine ( " // Optionally register handler classes into DI based on MediatorHandlerLifetime setting" ) ;
3737 source . AppendLine ( ) ;
3838 source . IncrementIndent ( ) . IncrementIndent ( ) ;
3939
40- bool registerHandlers = ! string . Equals ( handlerLifetime , "None" , StringComparison . OrdinalIgnoreCase ) ;
40+ string lifetimeMethod ;
41+ if ( String . Equals ( handlerLifetime , "Transient" , StringComparison . OrdinalIgnoreCase ) )
42+ lifetimeMethod = "TryAddTransient" ;
43+ else if ( String . Equals ( handlerLifetime , "Scoped" , StringComparison . OrdinalIgnoreCase ) )
44+ lifetimeMethod = "TryAddScoped" ;
45+ else
46+ lifetimeMethod = "TryAddSingleton" ;
4147
4248 foreach ( var handler in handlers )
4349 {
4450 string handlerClassName = HandlerGenerator . GetHandlerClassName ( handler ) ;
4551
4652 // Register handler in DI for non-static handler classes when lifetime != Singleton
47- if ( registerHandlers && ! handler . IsStatic )
53+ if ( handler is { IsStatic : false , IsGenericHandlerClass : false } )
4854 {
49- string lifetimeMethod = "" ;
50- if ( String . Equals ( handlerLifetime , "Transient" , StringComparison . OrdinalIgnoreCase ) )
51- lifetimeMethod = "AddTransient" ;
52- if ( String . Equals ( handlerLifetime , "Scoped" , StringComparison . OrdinalIgnoreCase ) )
53- lifetimeMethod = "AddScoped" ;
54- if ( String . Equals ( handlerLifetime , "Singleton" , StringComparison . OrdinalIgnoreCase ) )
55- lifetimeMethod = "AddSingleton" ;
56-
57- if ( ! String . IsNullOrEmpty ( lifetimeMethod ) )
58- source . AppendLine ( $ "services.{ lifetimeMethod } <{ handler . FullName } >();") ;
55+ source . AppendLine ( $ "services.{ lifetimeMethod } <{ handler . FullName } >();") ;
5956 }
6057
6158 if ( handler . IsGenericHandlerClass )
6259 {
63- // open generic registration
6460 if ( handler is not { MessageGenericTypeDefinitionFullName : not null , GenericArity : > 0 } )
6561 continue ;
6662
67- // Build unbound generic typeof expressions
68- string wrapperTypeOf = handler . GenericArity switch
63+ string genericArity = handler . GenericArity switch
6964 {
70- 1 => $ "typeof( { handlerClassName } <>) ",
71- 2 => $ "typeof( { handlerClassName } <,>) ",
72- 3 => $ "typeof( { handlerClassName } <,,>) ",
73- 4 => $ "typeof( { handlerClassName } <,,,>) ",
74- 5 => $ "typeof( { handlerClassName } <,,,,>) ",
75- 6 => $ "typeof( { handlerClassName } <,,,,,>) ",
76- 7 => $ "typeof( { handlerClassName } <,,,,,,>) ",
77- 8 => $ "typeof( { handlerClassName } <,,,,,,,>) ",
78- 9 => $ "typeof( { handlerClassName } <,,,,,,,,>) ",
79- 10 => $ "typeof( { handlerClassName } <,,,,,,,,,>) ",
80- _ => $ "typeof( { handlerClassName } <>)" // fallback
65+ 1 => "<> ",
66+ 2 => " <,>",
67+ 3 => " <,,>",
68+ 4 => " <,,,>",
69+ 5 => " <,,,,>",
70+ 6 => " <,,,,,>",
71+ 7 => " <,,,,,,>",
72+ 8 => " <,,,,,,,>",
73+ 9 => " <,,,,,,,,>",
74+ 10 => " <,,,,,,,,,>",
75+ _ => " <>)" // fallback
8176 } ;
77+
78+ string wrapperTypeOf = $ "typeof({ handlerClassName } { genericArity } )";
8279 string msgTypeOf = $ "typeof({ handler . MessageGenericTypeDefinitionFullName } )";
83- source . AppendLine ( $ "// Open generic handler registration for { handler . MessageGenericTypeDefinitionFullName } ") ;
80+ if ( ! handler . IsStatic )
81+ {
82+ string handlerFullName = handler . FullName ;
83+ int index = handlerFullName . IndexOf ( '<' ) ;
84+ if ( index > 0 )
85+ handlerFullName = handlerFullName . Substring ( 0 , index ) ;
86+ source . AppendLine ( $ "services.{ lifetimeMethod } (typeof({ handlerFullName } { genericArity } ));") ;
87+
88+ }
89+
8490 source . AppendLine ( $ "services.AddSingleton(new OpenGenericHandlerDescriptor({ msgTypeOf } , { wrapperTypeOf } , { handler . IsAsync . ToString ( ) . ToLower ( ) } ));") ;
8591 }
8692 else
0 commit comments