@@ -20,31 +20,82 @@ internal interface ICodeGenerationVisitor : INodeVisitor
2020 string GenerateContainerFile ( ) ;
2121 void VisitICreateFunctionNodeBase ( ICreateFunctionNodeBase element ) ;
2222 void VisitIElementNode ( IElementNode elementNode ) ;
23+ CodeGenerationFunctionVisitor CreateNestedFunctionVisitor (
24+ ReturnTypeStatus returnTypeStatus ,
25+ AsyncAwaitStatus asyncAwaitStatus ) ;
2326}
2427
25- internal sealed class CodeGenerationVisitor : ICodeGenerationVisitor
28+ internal sealed class CodeGenerationVisitor : CodeGenerationVisitorBase
2629{
27- private readonly StringBuilder _code = new ( ) ;
30+ internal CodeGenerationVisitor (
31+ WellKnownTypes wellKnownTypes ,
32+ WellKnownTypesCollections wellKnownTypesCollections ,
33+ Func < StringBuilder , ReturnTypeStatus , AsyncAwaitStatus , CodeGenerationFunctionVisitor > codeGenerationFunctionVisitorFactory )
34+ : base ( new ( ) , ReturnTypeStatus . Ordinary , AsyncAwaitStatus . No , wellKnownTypes , wellKnownTypesCollections , codeGenerationFunctionVisitorFactory )
35+ {
36+ }
37+ }
38+
39+ internal sealed class CodeGenerationFunctionVisitor : CodeGenerationVisitorBase
40+ {
41+ internal CodeGenerationFunctionVisitor (
42+ // parameters
43+ StringBuilder code ,
44+ ReturnTypeStatus returnTypeStatus ,
45+ AsyncAwaitStatus asyncAwaitStatus ,
46+
47+ // dependencies
48+ WellKnownTypes wellKnownTypes ,
49+ WellKnownTypesCollections wellKnownTypesCollections ,
50+ Func < StringBuilder , ReturnTypeStatus , AsyncAwaitStatus , CodeGenerationFunctionVisitor > codeGenerationFunctionVisitorFactory )
51+ : base ( code , returnTypeStatus , asyncAwaitStatus , wellKnownTypes , wellKnownTypesCollections , codeGenerationFunctionVisitorFactory )
52+ {
53+ }
54+ }
55+
56+ internal class CodeGenerationVisitorBase : ICodeGenerationVisitor
57+ {
58+ private readonly StringBuilder _code ;
2859 private readonly WellKnownTypes _wellKnownTypes ;
2960 private readonly WellKnownTypesCollections _wellKnownTypesCollections ;
30-
31- internal CodeGenerationVisitor (
61+ private readonly Func < StringBuilder , ReturnTypeStatus , AsyncAwaitStatus , CodeGenerationFunctionVisitor > _codeGenerationFunctionVisitorFactory ;
62+ private readonly ReturnTypeStatus _returnTypeStatus ;
63+ private readonly AsyncAwaitStatus _asyncAwaitStatus ;
64+
65+ internal CodeGenerationVisitorBase (
66+ // parameters
67+ StringBuilder code ,
68+ ReturnTypeStatus returnTypeStatus ,
69+ AsyncAwaitStatus asyncAwaitStatus ,
70+
71+ // dependencies
3272 WellKnownTypes wellKnownTypes ,
33- WellKnownTypesCollections wellKnownTypesCollections )
73+ WellKnownTypesCollections wellKnownTypesCollections ,
74+ Func < StringBuilder , ReturnTypeStatus , AsyncAwaitStatus , CodeGenerationFunctionVisitor > codeGenerationFunctionVisitorFactory )
3475 {
76+ _code = code ;
3577 _wellKnownTypes = wellKnownTypes ;
3678 _wellKnownTypesCollections = wellKnownTypesCollections ;
79+ _codeGenerationFunctionVisitorFactory = codeGenerationFunctionVisitorFactory ;
80+ _returnTypeStatus = returnTypeStatus ;
81+ _asyncAwaitStatus = asyncAwaitStatus ;
3782 }
3883
84+ private bool CurrentFunctionAsyncAwait =>
85+ ( _returnTypeStatus . HasFlag ( ReturnTypeStatus . ValueTask ) || _returnTypeStatus . HasFlag ( ReturnTypeStatus . Task ) )
86+ && _asyncAwaitStatus is AsyncAwaitStatus . Yes ;
87+
3988 public void VisitIContainerNode ( IContainerNode container ) =>
4089 container . GetGenerator ( ) . Generate ( _code , this ) ;
4190
4291 public void VisitICreateContainerFunctionNode ( ICreateContainerFunctionNode createContainerFunction )
4392 {
44- var asyncPrefix = createContainerFunction . InitializationAwaited
93+ var isAsyncAwait = createContainerFunction . ReturnTypeStatus . HasFlag ( ReturnTypeStatus . Task ) ||
94+ createContainerFunction . ReturnTypeStatus . HasFlag ( ReturnTypeStatus . ValueTask ) ;
95+ var asyncPrefix = isAsyncAwait
4596 ? "async "
4697 : "" ;
47- var awaitPrefix = createContainerFunction . InitializationAwaited
98+ var awaitPrefix = isAsyncAwait
4899 ? "await "
49100 : "" ;
50101
@@ -129,12 +180,12 @@ private void GenerateInitialization(IFunctionCallNode? maybeInitialization, stri
129180 {
130181 if ( maybeInitialization is { } initialization )
131182 {
132- var asyncPrefix = initialization . Awaited
183+ var asyncPrefix = CurrentFunctionAsyncAwait
133184 ? "await "
134185 : "" ;
135186
136187 _code . AppendLine (
137- $ "{ asyncPrefix } { ownerReference } .{ initialization . FunctionName } ({ string . Join ( ", " , initialization . Parameters . Select ( p => $ "{ p . Item1 . Reference . PrefixAtIfKeyword ( ) } : { p . Item2 . Reference } ") ) } );") ;
188+ $ "{ asyncPrefix } { ownerReference } .{ initialization . FunctionName ( _returnTypeStatus ) } ({ string . Join ( ", " , initialization . Parameters . Select ( p => $ "{ p . Item1 . Reference . PrefixAtIfKeyword ( ) } : { p . Item2 . Reference } ") ) } );") ;
138189 }
139190 }
140191
@@ -178,7 +229,11 @@ public void VisitIRangedInstanceInterfaceFunctionNode(IRangedInstanceInterfaceFu
178229 . AppendIf (
179230 $ "{ rangedInstanceInterfaceFunctionNode . TransientScopeDisposalNode . TypeFullName } { rangedInstanceInterfaceFunctionNode . TransientScopeDisposalNode . Reference } ",
180231 rangedInstanceInterfaceFunctionNode . IsTransientScopeDisposalAsParameter ) ) ;
181- _code . AppendLine ( $ "{ rangedInstanceInterfaceFunctionNode . ReturnedTypeFullName } { rangedInstanceInterfaceFunctionNode . Name } ({ parameter } );") ;
232+ var consideredStatuses = Enum . GetValues ( typeof ( ReturnTypeStatus ) )
233+ . OfType < ReturnTypeStatus > ( )
234+ . Where ( r => rangedInstanceInterfaceFunctionNode . ReturnTypeStatus . HasFlag ( r ) ) ;
235+ foreach ( var status in consideredStatuses )
236+ _code . AppendLine ( $ "{ rangedInstanceInterfaceFunctionNode . ReturnedTypeFullName ( status ) } { rangedInstanceInterfaceFunctionNode . Name ( status ) } ({ parameter } );") ;
182237 }
183238
184239 public void VisitIRangedInstanceFunctionGroupNode ( IRangedInstanceFunctionGroupNode rangedInstanceFunctionGroupNode )
@@ -227,19 +282,23 @@ public void VisitIWrappedAsyncFunctionCallNode(IWrappedAsyncFunctionCallNode fun
227282 var typeParameters = functionCallNode . TypeParameters . Any ( )
228283 ? $ "<{ string . Join ( ", " , functionCallNode . TypeParameters . Select ( p => p . Name ) ) } >"
229284 : "" ;
230- var call = $ "{ owner } { functionCallNode . FunctionName } { typeParameters } ({ parameters } )";
285+ string functionName ;
286+ if ( functionCallNode . CalledFunction . ReturnTypeStatus . HasFlag ( ReturnTypeStatus . ValueTask ) )
287+ functionName = functionCallNode . FunctionName ( ReturnTypeStatus . ValueTask ) ;
288+ else if ( functionCallNode . CalledFunction . ReturnTypeStatus . HasFlag ( ReturnTypeStatus . Task ) )
289+ functionName = functionCallNode . FunctionName ( ReturnTypeStatus . Task ) ;
290+ else
291+ functionName = functionCallNode . FunctionName ( ReturnTypeStatus . Ordinary ) ;
292+ var call = $ "{ owner } { functionName } { typeParameters } ({ parameters } )";
231293 call = functionCallNode . Transformation switch
232294 {
233- AsyncFunctionCallTransformation . ValueTaskFromValueTask => $ "new { typeFullName } ({ call } )",
234- AsyncFunctionCallTransformation . ValueTaskFromForcedValueTask => call ,
235- AsyncFunctionCallTransformation . ValueTaskFromTask => $ "new { typeFullName } ({ call } )",
236- AsyncFunctionCallTransformation . ValueTaskFromForcedTask => $ "new { typeFullName } ({ call } )",
237- AsyncFunctionCallTransformation . ValueTaskFromSync => $ "new { typeFullName } ({ call } )",
238- AsyncFunctionCallTransformation . TaskFromValueTask => $ "{ _wellKnownTypes . Task . FullName ( ) } .FromResult({ call } )",
239- AsyncFunctionCallTransformation . TaskFromForcedValueTask => $ "{ call } .AsTask()",
240- AsyncFunctionCallTransformation . TaskFromTask => $ "{ _wellKnownTypes . Task . FullName ( ) } .FromResult({ call } )",
241- AsyncFunctionCallTransformation . TaskFromForcedTask => call ,
242- AsyncFunctionCallTransformation . TaskFromSync => $ "{ _wellKnownTypes . Task . FullName ( ) } .FromResult({ call } )",
295+ WrappedAsyncTransformation . SameSame => call ,
296+ WrappedAsyncTransformation . ValueTaskFromTask => $ "new { typeFullName } ({ call } )",
297+ WrappedAsyncTransformation . ValueTaskFromSync => $ "new { typeFullName } ({ _wellKnownTypes . Task . FullName ( ) } .FromResult({ call } ))",
298+ WrappedAsyncTransformation . ValueTaskDeeper => $ "new { typeFullName } ({ call } .AsTask())",
299+ WrappedAsyncTransformation . TaskFromValueTask => $ "{ call } .AsTask()",
300+ WrappedAsyncTransformation . TaskFromSync => $ "{ _wellKnownTypes . Task . FullName ( ) } .FromResult({ call } )",
301+ WrappedAsyncTransformation . TaskDeeper => $ "{ _wellKnownTypes . Task . FullName ( ) } .FromResult({ call } )",
243302 _ => throw new ArgumentOutOfRangeException ( nameof ( functionCallNode ) , $ "Switch in DIE type { nameof ( CodeGenerationVisitor ) } is not exhaustive.")
244303 } ;
245304 _code . AppendLine ( $ "{ typeFullName } { functionCallNode . Reference } = ({ typeFullName } ){ call } ;") ;
@@ -261,8 +320,15 @@ private void VisitIFunctionCallNode(IFunctionCallNode functionCallNode)
261320 var typeParameters = functionCallNode . TypeParameters . Any ( )
262321 ? $ "<{ string . Join ( ", " , functionCallNode . TypeParameters . Select ( p => p . Name ) ) } >"
263322 : "" ;
264- var call = $ "{ owner } { functionCallNode . FunctionName } { typeParameters } ({ parameters } )";
265- call = functionCallNode . Awaited ? $ "(await { call } )" : call ;
323+ var functionName = _returnTypeStatus . HasFlag ( ReturnTypeStatus . Ordinary )
324+ ? functionCallNode . FunctionName ( ReturnTypeStatus . Ordinary )
325+ : functionCallNode . CalledFunction . ReturnTypeStatus . HasFlag ( ReturnTypeStatus . ValueTask )
326+ ? functionCallNode . FunctionName ( ReturnTypeStatus . ValueTask )
327+ : functionCallNode . FunctionName ( ReturnTypeStatus . Task ) ;
328+ var call = $ "{ owner } { functionName } { typeParameters } ({ parameters } )";
329+ call = CurrentFunctionAsyncAwait && functionCallNode . CalledFunction is not IMultiFunctionNodeBase { IsAsyncEnumerable : true }
330+ ? $ "(await { call } )"
331+ : call ;
266332 _code . AppendLine ( $ "{ typeFullName } { functionCallNode . Reference } = ({ typeFullName } ){ call } ;") ;
267333 }
268334
@@ -431,6 +497,10 @@ public void VisitIElementNode(IElementNode elementNode)
431497 }
432498 }
433499
500+ public CodeGenerationFunctionVisitor CreateNestedFunctionVisitor ( ReturnTypeStatus returnTypeStatus ,
501+ AsyncAwaitStatus asyncAwaitStatus ) =>
502+ _codeGenerationFunctionVisitorFactory ( _code , returnTypeStatus , asyncAwaitStatus ) ;
503+
434504 public void VisitIImplementationNode ( IImplementationNode implementationNode )
435505 {
436506 if ( implementationNode . UserDefinedInjectionConstructor is not null )
@@ -464,7 +534,7 @@ public void VisitIImplementationNode(IImplementationNode implementationNode)
464534 var initializerParameters =
465535 string . Join ( ", " , init . Parameters . Select ( d => $ "{ d . Name . PrefixAtIfKeyword ( ) } : { d . Element . Reference } ") ) ;
466536
467- var prefix = implementationNode . Awaited
537+ var prefix = implementationNode . InitializerReturnsSomeTask
468538 ? "await "
469539 : implementationNode is { AsyncReference : { } asyncReference , AsyncTypeFullName : { } asyncTypeFullName }
470540 ? $ "{ asyncTypeFullName } { asyncReference } = "
0 commit comments