@@ -23,21 +23,40 @@ private void InitializeUnionSourceGen(
2323 IncrementalGeneratorInitializationContext context ,
2424 IncrementalValueProvider < GeneratorOptions > options )
2525 {
26- InitializeUnionSourceGen ( context , options , Constants . Attributes . Union . FULL_NAME_2_TYPES ) ;
27- InitializeUnionSourceGen ( context , options , Constants . Attributes . Union . FULL_NAME_3_TYPES ) ;
28- InitializeUnionSourceGen ( context , options , Constants . Attributes . Union . FULL_NAME_4_TYPES ) ;
29- InitializeUnionSourceGen ( context , options , Constants . Attributes . Union . FULL_NAME_5_TYPES ) ;
26+ InitializeGenericUnionSourceGen ( context , options , Constants . Attributes . Union . FULL_NAME_2_TYPES ) ;
27+ InitializeGenericUnionSourceGen ( context , options , Constants . Attributes . Union . FULL_NAME_3_TYPES ) ;
28+ InitializeGenericUnionSourceGen ( context , options , Constants . Attributes . Union . FULL_NAME_4_TYPES ) ;
29+ InitializeGenericUnionSourceGen ( context , options , Constants . Attributes . Union . FULL_NAME_5_TYPES ) ;
30+ InitializeNonGenericUnionSourceGen ( context , options , Constants . Attributes . Union . FULL_NAME_AD_HOCH ) ;
3031 }
3132
32- private void InitializeUnionSourceGen (
33+ private void InitializeGenericUnionSourceGen (
34+ IncrementalGeneratorInitializationContext context ,
35+ IncrementalValueProvider < GeneratorOptions > options ,
36+ string fullyQualifiedMetadataName )
37+ {
38+ InitializeUnionSourceGen ( context , options , fullyQualifiedMetadataName , IsGenericCandidate , GetSourceGenContextOrNullForGeneric ) ;
39+ }
40+
41+ private void InitializeNonGenericUnionSourceGen (
3342 IncrementalGeneratorInitializationContext context ,
3443 IncrementalValueProvider < GeneratorOptions > options ,
3544 string fullyQualifiedMetadataName )
45+ {
46+ InitializeUnionSourceGen ( context , options , fullyQualifiedMetadataName , IsNonGenericCandidate , GetSourceGenContextOrNullForNonGeneric ) ;
47+ }
48+
49+ private void InitializeUnionSourceGen (
50+ IncrementalGeneratorInitializationContext context ,
51+ IncrementalValueProvider < GeneratorOptions > options ,
52+ string fullyQualifiedMetadataName ,
53+ Func < SyntaxNode , CancellationToken , bool > isCandate ,
54+ Func < GeneratorAttributeSyntaxContext , CancellationToken , SourceGenContext ? > getSourceGenContextOrNull )
3655 {
3756 var unionTypeOrError = context . SyntaxProvider
3857 . ForAttributeWithMetadataName ( fullyQualifiedMetadataName ,
39- IsCandidate ,
40- GetSourceGenContextOrNull )
58+ isCandate ,
59+ getSourceGenContextOrNull )
4160 . SelectMany ( static ( state , _ ) => state . HasValue
4261 ? [ state . Value ]
4362 : ImmutableArray < SourceGenContext > . Empty ) ;
@@ -52,7 +71,7 @@ private void InitializeUnionSourceGen(
5271 InitializeExceptionReporting ( context , unionTypeOrError ) ;
5372 }
5473
55- private bool IsCandidate ( SyntaxNode syntaxNode , CancellationToken cancellationToken )
74+ private bool IsGenericCandidate ( SyntaxNode syntaxNode , CancellationToken cancellationToken )
5675 {
5776 try
5877 {
@@ -70,6 +89,24 @@ StructDeclarationSyntax structDeclaration when IsUnionCandidate(structDeclaratio
7089 }
7190 }
7291
92+ private bool IsNonGenericCandidate ( SyntaxNode syntaxNode , CancellationToken cancellationToken )
93+ {
94+ try
95+ {
96+ return syntaxNode switch
97+ {
98+ ClassDeclarationSyntax => true ,
99+ StructDeclarationSyntax => true ,
100+ _ => false
101+ } ;
102+ }
103+ catch ( Exception ex )
104+ {
105+ Logger . LogError ( "Error during checking whether a syntax node is a discriminated union candidate" , exception : ex ) ;
106+ return false ;
107+ }
108+ }
109+
73110 private bool IsUnionCandidate ( TypeDeclarationSyntax typeDeclaration )
74111 {
75112 var isCandidate = ! typeDeclaration . IsGeneric ( ) ;
@@ -86,7 +123,85 @@ private bool IsUnionCandidate(TypeDeclarationSyntax typeDeclaration)
86123 return isCandidate ;
87124 }
88125
89- private SourceGenContext ? GetSourceGenContextOrNull ( GeneratorAttributeSyntaxContext context , CancellationToken cancellationToken )
126+ private SourceGenContext ? GetSourceGenContextOrNullForGeneric ( GeneratorAttributeSyntaxContext context , CancellationToken cancellationToken )
127+ {
128+ return GetSourceGenContextOrNull (
129+ context ,
130+ ( tds , data ) =>
131+ {
132+ var attributeType = data . AttributeClass ;
133+
134+ if ( attributeType is null )
135+ {
136+ Logger . LogDebug ( "The attribute type is null" , tds ) ;
137+ return null ;
138+ }
139+
140+ if ( attributeType . TypeKind == TypeKind . Error )
141+ {
142+ Logger . LogDebug ( "The attribute type is erroneous" , tds ) ;
143+ return null ;
144+ }
145+
146+ if ( attributeType . TypeArguments . IsDefaultOrEmpty )
147+ return null ;
148+
149+ return attributeType . TypeArguments ;
150+ } ,
151+ cancellationToken ) ;
152+ }
153+
154+ private SourceGenContext ? GetSourceGenContextOrNullForNonGeneric ( GeneratorAttributeSyntaxContext context , CancellationToken cancellationToken )
155+ {
156+ return GetSourceGenContextOrNull (
157+ context ,
158+ ( tds , data ) =>
159+ {
160+ var attributeType = data . AttributeClass ;
161+
162+ if ( attributeType is null )
163+ {
164+ Logger . LogDebug ( "The attribute type is null" , tds ) ;
165+ return null ;
166+ }
167+
168+ if ( attributeType . TypeKind == TypeKind . Error )
169+ {
170+ Logger . LogDebug ( "The attribute type is erroneous" , tds ) ;
171+ return null ;
172+ }
173+
174+ if ( data . ConstructorArguments . IsDefaultOrEmpty )
175+ return null ;
176+
177+ var types = new List < ITypeSymbol > ( ) ;
178+ var foundNull = false ;
179+
180+ for ( var i = 0 ; i < data . ConstructorArguments . Length ; i ++ )
181+ {
182+ var argument = data . ConstructorArguments [ i ] ;
183+
184+ if ( argument . IsNull )
185+ {
186+ foundNull = true ;
187+ continue ;
188+ }
189+
190+ if ( foundNull || argument . Value is not ITypeSymbol type || type . TypeKind == TypeKind . Error )
191+ return null ;
192+
193+ types . Add ( type ) ;
194+ }
195+
196+ return types ;
197+ } ,
198+ cancellationToken ) ;
199+ }
200+
201+ private SourceGenContext ? GetSourceGenContextOrNull (
202+ GeneratorAttributeSyntaxContext context ,
203+ Func < TypeDeclarationSyntax , AttributeData , IReadOnlyList < ITypeSymbol > ? > getMemberTypes ,
204+ CancellationToken cancellationToken )
90205 {
91206 var tds = ( TypeDeclarationSyntax ) context . TargetNode ;
92207
@@ -109,26 +224,17 @@ private bool IsUnionCandidate(TypeDeclarationSyntax typeDeclaration)
109224 return null ;
110225 }
111226
112- var attributeType = context . Attributes [ 0 ] . AttributeClass ;
113-
114- if ( attributeType is null )
115- {
116- Logger . LogDebug ( "The attribute type is null" , tds ) ;
117- return null ;
118- }
119-
120- if ( attributeType . TypeArguments . IsDefaultOrEmpty )
121- return null ;
227+ var attributeData = context . Attributes [ 0 ] ;
228+ var memberTypeSymbols = getMemberTypes ( tds , attributeData ) ;
122229
123- if ( attributeType . TypeKind == TypeKind . Error )
230+ if ( memberTypeSymbols is null )
124231 {
125- Logger . LogDebug ( "The attribute type is erroneous" , tds ) ;
126232 return null ;
127233 }
128234
129- if ( attributeType . Arity < 2 )
235+ if ( memberTypeSymbols . Count < 2 )
130236 {
131- Logger . LogDebug ( $ "Expected the attribute type to have at least 2 type arguments but found { attributeType . Arity . ToString ( ) } ", tds ) ;
237+ Logger . LogDebug ( $ "Expected the union to have at least 2 member types but found { memberTypeSymbols . Count } ", tds ) ;
132238 return null ;
133239 }
134240
@@ -146,12 +252,12 @@ private bool IsUnionCandidate(TypeDeclarationSyntax typeDeclaration)
146252 return new SourceGenContext ( new SourceGenError ( "Could not fetch type information for code generation of a discriminated union" , tds ) ) ;
147253
148254 var settings = new AdHocUnionSettings ( context . Attributes [ 0 ] ,
149- attributeType . Arity ) ;
150- var memberTypeStates = attributeType . Arity == 0 ? [ ] : new AdHocUnionMemberTypeState [ attributeType . Arity ] ;
255+ memberTypeSymbols . Count ) ;
256+ var memberTypeStates = new AdHocUnionMemberTypeState [ memberTypeSymbols . Count ] ;
151257
152- for ( var i = 0 ; i < attributeType . TypeArguments . Length ; i ++ )
258+ for ( var i = 0 ; i < memberTypeSymbols . Count ; i ++ )
153259 {
154- var memberType = attributeType . TypeArguments [ i ] ;
260+ var memberType = memberTypeSymbols [ i ] ;
155261
156262 if ( memberType . TypeKind == TypeKind . Error )
157263 {
@@ -165,7 +271,7 @@ private bool IsUnionCandidate(TypeDeclarationSyntax typeDeclaration)
165271
166272 var typeDuplicateCounter = 0 ;
167273
168- for ( var j = 0 ; j < attributeType . TypeArguments . Length ; j ++ )
274+ for ( var j = 0 ; j < memberTypeSymbols . Count ; j ++ )
169275 {
170276 if ( j == i )
171277 {
@@ -175,7 +281,7 @@ private bool IsUnionCandidate(TypeDeclarationSyntax typeDeclaration)
175281 continue ;
176282 }
177283
178- if ( ! SymbolEqualityComparer . Default . Equals ( memberType , attributeType . TypeArguments [ j ] ) )
284+ if ( ! SymbolEqualityComparer . Default . Equals ( memberType , memberTypeSymbols [ j ] ) )
179285 continue ;
180286
181287 if ( j > i && typeDuplicateCounter != 0 )
0 commit comments