@@ -47,7 +47,7 @@ private void Replace(SmartStringBuilder smartStringBuilder, MethodInfo? method,
4747 userImpl = string . Empty ;
4848 needsExplicitReturn = true ;
4949 }
50- else if ( method . ReturnType . IsAwaitable ( ) )
50+ else if ( method . ReturnType . IsAwaitable ( out _ ) )
5151 {
5252 modifier = "async" ;
5353 userImpl = $ "await { GetMethodPrefix ( method ) } .{ method . Name } ();";
@@ -97,7 +97,7 @@ protected string GetPassArgumentsDirect()
9797 ) ;
9898 }
9999
100- internal class SyncDeclarationsProvider ( BenchmarkCase benchmark ) : DeclarationsProvider ( benchmark )
100+ internal sealed class SyncDeclarationsProvider ( BenchmarkCase benchmark ) : DeclarationsProvider ( benchmark )
101101 {
102102 public override string [ ] GetExtraFields ( ) => [ ] ;
103103
@@ -175,8 +175,13 @@ private string GetPassArguments()
175175 ) ;
176176 }
177177
178- internal class AsyncDeclarationsProvider ( BenchmarkCase benchmark ) : DeclarationsProvider ( benchmark )
178+ internal abstract class AsyncDeclarationsProviderBase ( BenchmarkCase benchmark ) : DeclarationsProvider ( benchmark )
179179 {
180+ // Type used to drive the WorkloadCore builder selection. For ordinary awaitables it's the workload
181+ // method's own return type, but `IAsyncEnumerable<T>` has no GetAwaiter, so AsyncEnumerableDeclarationsProvider
182+ // overrides this to expose the MoveNextAsync awaitable as a proxy.
183+ protected virtual Type WorkloadAwaitableReturnType => Descriptor . WorkloadMethod . ReturnType ;
184+
180185 public override string [ ] GetExtraFields ( ) =>
181186 [
182187 $ "public { typeof ( WorkloadValueTaskSource ) . GetCorrectCSharpTypeName ( ) } workloadContinuerAndValueTaskSource;",
@@ -193,87 +198,6 @@ protected override string GetExtraGlobalSetupImpl()
193198 protected override string GetExtraGlobalCleanupImpl ( )
194199 => "this.__fieldsContainer.workloadContinuerAndValueTaskSource.Complete();" ;
195200
196- protected override SmartStringBuilder ReplaceCore ( SmartStringBuilder smartStringBuilder )
197- {
198- // Unlike sync calls, async calls suffer from unrolling, so we multiply the invokeCount by the unroll factor and delegate the implementation to *NoUnroll methods.
199- int unrollFactor = Benchmark . Job . ResolveValue ( RunMode . UnrollFactorCharacteristic , EnvironmentResolver . Instance ) ;
200- string passArguments = GetPassArgumentsDirect ( ) ;
201- string workloadMethodCall = GetWorkloadMethodCall ( passArguments ) ;
202- bool hasAsyncMethodBuilderAttribute = TryGetAsyncMethodBuilderAttribute ( out var asyncMethodBuilderAttribute ) ;
203- Type workloadCoreReturnType = GetWorkloadCoreReturnType ( hasAsyncMethodBuilderAttribute , Descriptor . WorkloadMethod . ReturnType ) ;
204- string finalReturn = GetFinalReturn ( workloadCoreReturnType ) ;
205- string coreImpl = $$ """
206- private {{ CoreReturnType }} OverheadActionUnroll({{ CoreParameters }} )
207- {
208- return this.OverheadActionNoUnroll(invokeCount * {{ unrollFactor }} , clock);
209- }
210-
211- private {{ CoreReturnType }} OverheadActionNoUnroll({{ CoreParameters }} )
212- {
213- {{ StartClockSyncCode }}
214- while (--invokeCount >= 0)
215- {
216- this.__Overhead({{ passArguments }} );
217- }
218- {{ ReturnSyncCode }}
219- }
220-
221- private {{ CoreReturnType }} WorkloadActionUnroll({{ CoreParameters }} )
222- {
223- return this.WorkloadActionNoUnroll(invokeCount * {{ unrollFactor }} , clock);
224- }
225-
226- private {{ CoreReturnType }} WorkloadActionNoUnroll({{ CoreParameters }} )
227- {
228- this.__fieldsContainer.invokeCount = invokeCount;
229- this.__fieldsContainer.clock = clock;
230- // The source is allocated and the workload loop started in __GlobalSetup,
231- // so this hot path is branchless and allocation-free.
232- return this.__fieldsContainer.workloadContinuerAndValueTaskSource.Continue();
233- }
234-
235- private async void __StartWorkload()
236- {
237- await __WorkloadCore();
238- }
239-
240- {{ asyncMethodBuilderAttribute }}
241- private async {{ workloadCoreReturnType . GetCorrectCSharpTypeName ( ) }} __WorkloadCore()
242- {
243- try
244- {
245- if (await this.__fieldsContainer.workloadContinuerAndValueTaskSource.GetIsComplete())
246- {
247- {{ finalReturn }}
248- }
249- while (true)
250- {
251- {{ typeof ( StartedClock ) . GetCorrectCSharpTypeName ( ) }} startedClock = {{ typeof ( ClockExtensions ) . GetCorrectCSharpTypeName ( ) }} .Start(this.__fieldsContainer.clock);
252- while (--this.__fieldsContainer.invokeCount >= 0)
253- {
254- // Necessary because of error CS4004: Cannot await in an unsafe context
255- {{ Descriptor . WorkloadMethod . ReturnType . GetCorrectCSharpTypeName ( ) }} awaitable;
256- unsafe { awaitable = {{ workloadMethodCall }} }
257- await awaitable;
258- }
259- if (await this.__fieldsContainer.workloadContinuerAndValueTaskSource.SetResultAndGetIsComplete(startedClock.GetElapsed()))
260- {
261- {{ finalReturn }}
262- }
263- }
264- }
265- catch (global::System.Exception e)
266- {
267- __fieldsContainer.workloadContinuerAndValueTaskSource.SetException(e);
268- {{ finalReturn }}
269- }
270- }
271- """ ;
272-
273- return smartStringBuilder
274- . Replace ( "$CoreImpl$" , coreImpl ) ;
275- }
276-
277201 protected bool TryGetAsyncMethodBuilderAttribute ( out string asyncMethodBuilderAttribute )
278202 {
279203 asyncMethodBuilderAttribute = string . Empty ;
@@ -322,32 +246,15 @@ protected static string GetFinalReturn(Type workloadCoreReturnType)
322246 ? "return;"
323247 : $ "return default({ finalReturnType . GetCorrectCSharpTypeName ( ) } );";
324248 }
325- }
326249
327- internal class AsyncEnumerableDeclarationsProvider ( BenchmarkCase benchmark , Type itemType , Type enumeratorType , Type moveNextAwaitableType ) : AsyncDeclarationsProvider ( benchmark )
328- {
329250 protected override SmartStringBuilder ReplaceCore ( SmartStringBuilder smartStringBuilder )
330251 {
252+ // Unlike sync calls, async calls suffer from unrolling, so we multiply the invokeCount by the unroll factor and delegate the implementation to *NoUnroll methods.
331253 int unrollFactor = Benchmark . Job . ResolveValue ( RunMode . UnrollFactorCharacteristic , EnvironmentResolver . Instance ) ;
332254 string passArguments = GetPassArgumentsDirect ( ) ;
333255 string workloadMethodCall = GetWorkloadMethodCall ( passArguments ) ;
334- string itemTypeName = itemType . GetCorrectCSharpTypeName ( ) ;
335- string enumerableTypeName = Descriptor . WorkloadMethod . ReturnType . GetCorrectCSharpTypeName ( ) ;
336- string enumeratorTypeName = enumeratorType . GetCorrectCSharpTypeName ( ) ;
337- // We hand-roll the `await foreach` desugaring (explicit GetAsyncEnumerator + while-loop) instead
338- // of using the C# `await foreach` keyword: that keeps the IL byte-for-byte aligned with
339- // AsyncEnumerableCoreEmitter, which doesn't wrap the iteration in the try/catch + wrap field
340- // pattern Roslyn emits for the keyword form.
341- string disposeAsyncCall = ResolveDisposeAsync ( ) is { } disposeAsyncMethod
342- ? $ "await enumerator.{ disposeAsyncMethod . Name } ();"
343- : string . Empty ;
344- // IAsyncEnumerable<T> has no GetAwaiter, so its own return type can't drive the WorkloadCore
345- // builder. Use MoveNextAsync's return type as the proxy and feed it through the same resolver
346- // as the awaitable path: any result the awaitable produces (typically `bool`) is discarded by
347- // `__StartWorkload`'s `await`, and `[AsyncCallerType]` / `[AsyncMethodBuilder]` overrides on
348- // the workload method still apply.
349256 bool hasAsyncMethodBuilderAttribute = TryGetAsyncMethodBuilderAttribute ( out var asyncMethodBuilderAttribute ) ;
350- Type workloadCoreReturnType = GetWorkloadCoreReturnType ( hasAsyncMethodBuilderAttribute , moveNextAwaitableType ) ;
257+ Type workloadCoreReturnType = GetWorkloadCoreReturnType ( hasAsyncMethodBuilderAttribute , WorkloadAwaitableReturnType ) ;
351258 string finalReturn = GetFinalReturn ( workloadCoreReturnType ) ;
352259 string coreImpl = $$ """
353260 private {{ CoreReturnType }} OverheadActionUnroll({{ CoreParameters }} )
@@ -383,7 +290,7 @@ private async void __StartWorkload()
383290 {
384291 await __WorkloadCore();
385292 }
386-
293+
387294 {{ asyncMethodBuilderAttribute }}
388295 private async {{ workloadCoreReturnType . GetCorrectCSharpTypeName ( ) }} __WorkloadCore()
389296 {
@@ -395,23 +302,12 @@ private async void __StartWorkload()
395302 }
396303 while (true)
397304 {
398- {{ itemTypeName }} lastItem = default({{ itemTypeName }} );
399305 {{ typeof ( StartedClock ) . GetCorrectCSharpTypeName ( ) }} startedClock = {{ typeof ( ClockExtensions ) . GetCorrectCSharpTypeName ( ) }} .Start(this.__fieldsContainer.clock);
400306 while (--this.__fieldsContainer.invokeCount >= 0)
401307 {
402- // Necessary because of error CS4004: Cannot await in an unsafe context
403- {{ enumerableTypeName }} enumerable;
404- unsafe { enumerable = {{ workloadMethodCall }} }
405- {{ enumeratorTypeName }} enumerator = enumerable.GetAsyncEnumerator();
406- while (await enumerator.MoveNextAsync())
407- {
408- lastItem = enumerator.Current;
409- }
410- {{ disposeAsyncCall }}
308+ {{ GetCallAndConsumeImpl ( workloadMethodCall ) }}
411309 }
412- {{ typeof ( ClockSpan ) . GetCorrectCSharpTypeName ( ) }} elapsed = startedClock.GetElapsed();
413- {{ typeof ( DeadCodeEliminationHelper ) . GetCorrectCSharpTypeName ( ) }} .KeepAliveWithoutBoxing(lastItem);
414- if (await this.__fieldsContainer.workloadContinuerAndValueTaskSource.SetResultAndGetIsComplete(elapsed))
310+ if (await this.__fieldsContainer.workloadContinuerAndValueTaskSource.SetResultAndGetIsComplete(startedClock.GetElapsed()))
415311 {
416312 {{ finalReturn }}
417313 }
@@ -429,24 +325,51 @@ private async void __StartWorkload()
429325 . Replace ( "$CoreImpl$" , coreImpl ) ;
430326 }
431327
432- // Roslyn's `await foreach` resolution: prefer a public instance DisposeAsync with all-optional
433- // params whose awaiter's GetResult returns void; otherwise fall back to IAsyncDisposable.
434- // Returns null if neither shape matches, in which case the template skips the dispose call.
435- private MethodInfo ? ResolveDisposeAsync ( )
328+ protected abstract string GetCallAndConsumeImpl ( string workloadMethodCall ) ;
329+ }
330+
331+ internal class AsyncDeclarationsProvider ( BenchmarkCase benchmark , Type resultType ) : AsyncDeclarationsProviderBase ( benchmark )
332+ {
333+ protected override string GetCallAndConsumeImpl ( string workloadMethodCall )
436334 {
437- var disposeAsyncMethod = enumeratorType
438- . GetMethods ( BindingFlags . Public | BindingFlags . Instance )
439- . FirstOrDefault ( m => m . Name == nameof ( IAsyncDisposable . DisposeAsync )
440- && m . GetParameters ( ) . All ( p => p . IsOptional )
441- && m . ReturnType . GetMethod ( nameof ( Task . GetAwaiter ) , BindingFlags . Public | BindingFlags . Instance )
442- ? . ReturnType
443- . GetMethod ( nameof ( TaskAwaiter . GetResult ) , BindingFlags . Public | BindingFlags . Instance )
444- ? . ReturnType == typeof ( void ) ) ;
445- if ( disposeAsyncMethod is not null )
446- return disposeAsyncMethod ;
447- if ( typeof ( IAsyncDisposable ) . IsAssignableFrom ( enumeratorType ) )
448- return typeof ( IAsyncDisposable ) . GetMethod ( nameof ( IAsyncDisposable . DisposeAsync ) ) ;
449- return null ;
335+ string awaitStatement ;
336+ if ( resultType == typeof ( void ) )
337+ {
338+ awaitStatement = "await awaitable;" ;
339+ }
340+ else
341+ {
342+ var resultTypeName = resultType . GetCorrectCSharpTypeName ( ) ;
343+ awaitStatement = $ """
344+ { resultTypeName } result = await awaitable;
345+ { typeof ( DeadCodeEliminationHelper ) . GetCorrectCSharpTypeName ( ) } .KeepAliveWithoutBoxing<{ resultTypeName } >(in result);
346+ """ ;
347+ }
348+ return $$ """
349+ // Necessary because of error CS4004: Cannot await in an unsafe context
350+ {{ Descriptor . WorkloadMethod . ReturnType . GetCorrectCSharpTypeName ( ) }} awaitable;
351+ unsafe { awaitable = {{ workloadMethodCall }} }
352+ {{ awaitStatement }}
353+ """ ;
354+ }
355+ }
356+
357+ internal class AsyncEnumerableDeclarationsProvider ( BenchmarkCase benchmark , Type itemType , Type moveNextAwaitableType ) : AsyncDeclarationsProviderBase ( benchmark )
358+ {
359+ protected override Type WorkloadAwaitableReturnType => moveNextAwaitableType ;
360+
361+ protected override string GetCallAndConsumeImpl ( string workloadMethodCall )
362+ {
363+ string itemTypeName = itemType . GetCorrectCSharpTypeName ( ) ;
364+ return $$ """
365+ // Necessary because of error CS4004: Cannot await in an unsafe context
366+ {{ Descriptor . WorkloadMethod . ReturnType . GetCorrectCSharpTypeName ( ) }} enumerable;
367+ unsafe { enumerable = {{ workloadMethodCall }} }
368+ await foreach ({{ itemTypeName }} item in enumerable)
369+ {
370+ {{ typeof ( DeadCodeEliminationHelper ) . GetCorrectCSharpTypeName ( ) }} .KeepAliveWithoutBoxing<{{ itemTypeName }} >(in item);
371+ }
372+ """ ;
450373 }
451374 }
452375}
0 commit comments