Skip to content

Commit 5b96154

Browse files
authored
Prevent dead-code-elimination for async results. (#3125)
Support ref struct async results. Fix await foreach codegen template. De-duplicate code.
1 parent 0ee60ea commit 5b96154

18 files changed

Lines changed: 530 additions & 562 deletions

src/BenchmarkDotNet.Analyzers/AsyncTypeShapes.cs

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,15 @@ public static bool IsAsyncEnumerable(ITypeSymbol type, INamedTypeSymbol? asyncEn
2525
return true;
2626
}
2727

28-
if (TryFindPatternGetAsyncEnumerator(type) is { } enumeratorType
29-
&& HasPatternMoveNextAsync(enumeratorType)
30-
&& HasPublicInstanceProperty(enumeratorType, "Current"))
28+
if (TryFindPatternGetAsyncEnumerator(type) is { } enumeratorType)
3129
{
32-
return true;
30+
// Roslyn commits to a found pattern `GetAsyncEnumerator` — if its return type doesn't
31+
// satisfy the await-foreach enumerator shape it reports an error instead of falling back
32+
// to `IAsyncEnumerable<T>`, even when the source also implements the interface. We mirror
33+
// that here so the analyzer's view of binding matches what `await foreach` would actually
34+
// accept.
35+
return HasPatternMoveNextAsync(enumeratorType)
36+
&& HasPublicInstanceProperty(enumeratorType, "Current");
3337
}
3438

3539
if (asyncEnumerableInterfaceSymbol != null)

src/BenchmarkDotNet/Code/CodeGenerator.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -113,14 +113,14 @@ private static DeclarationsProvider GetDeclarationsProvider(BenchmarkCase benchm
113113
{
114114
var method = benchmark.Descriptor.WorkloadMethod;
115115

116-
if (method.ReturnType.IsAwaitable())
116+
if (method.ReturnType.IsAwaitable(out var awaitableInfo))
117117
{
118-
return new AsyncDeclarationsProvider(benchmark);
118+
return new AsyncDeclarationsProvider(benchmark, awaitableInfo.ResultType);
119119
}
120120

121-
if (method.ReturnType.IsAsyncEnumerable(out var itemType, out var enumeratorType, out var moveNextAwaitableType))
121+
if (method.ReturnType.IsAsyncEnumerable(out var asyncEnumerableInfo))
122122
{
123-
return new AsyncEnumerableDeclarationsProvider(benchmark, itemType, enumeratorType, moveNextAwaitableType);
123+
return new AsyncEnumerableDeclarationsProvider(benchmark, asyncEnumerableInfo.ItemType, asyncEnumerableInfo.MoveNextAsyncMethod.ReturnType);
124124
}
125125

126126
if (method.ReturnType == typeof(void) && method.HasAttribute<AsyncStateMachineAttribute>())

src/BenchmarkDotNet/Code/DeclarationsProvider.cs

Lines changed: 57 additions & 134 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ private void Replace(SmartStringBuilder smartStringBuilder, MethodInfo? method,
4747
userImpl = string.Empty;
4848
needsExplicitReturn = true;
4949
}
50-
else if (method.ReturnType.IsAwaitable())
50+
else if (method.ReturnType.IsAwaitable(out _))
5151
{
5252
modifier = "async";
5353
userImpl = $"await {GetMethodPrefix(method)}.{method.Name}();";
@@ -97,7 +97,7 @@ protected string GetPassArgumentsDirect()
9797
);
9898
}
9999

100-
internal class SyncDeclarationsProvider(BenchmarkCase benchmark) : DeclarationsProvider(benchmark)
100+
internal sealed class SyncDeclarationsProvider(BenchmarkCase benchmark) : DeclarationsProvider(benchmark)
101101
{
102102
public override string[] GetExtraFields() => [];
103103

@@ -175,8 +175,13 @@ private string GetPassArguments()
175175
);
176176
}
177177

178-
internal class AsyncDeclarationsProvider(BenchmarkCase benchmark) : DeclarationsProvider(benchmark)
178+
internal abstract class AsyncDeclarationsProviderBase(BenchmarkCase benchmark) : DeclarationsProvider(benchmark)
179179
{
180+
// Type used to drive the WorkloadCore builder selection. For ordinary awaitables it's the workload
181+
// method's own return type, but `IAsyncEnumerable<T>` has no GetAwaiter, so AsyncEnumerableDeclarationsProvider
182+
// overrides this to expose the MoveNextAsync awaitable as a proxy.
183+
protected virtual Type WorkloadAwaitableReturnType => Descriptor.WorkloadMethod.ReturnType;
184+
180185
public override string[] GetExtraFields() =>
181186
[
182187
$"public {typeof(WorkloadValueTaskSource).GetCorrectCSharpTypeName()} workloadContinuerAndValueTaskSource;",
@@ -193,87 +198,6 @@ protected override string GetExtraGlobalSetupImpl()
193198
protected override string GetExtraGlobalCleanupImpl()
194199
=> "this.__fieldsContainer.workloadContinuerAndValueTaskSource.Complete();";
195200

196-
protected override SmartStringBuilder ReplaceCore(SmartStringBuilder smartStringBuilder)
197-
{
198-
// Unlike sync calls, async calls suffer from unrolling, so we multiply the invokeCount by the unroll factor and delegate the implementation to *NoUnroll methods.
199-
int unrollFactor = Benchmark.Job.ResolveValue(RunMode.UnrollFactorCharacteristic, EnvironmentResolver.Instance);
200-
string passArguments = GetPassArgumentsDirect();
201-
string workloadMethodCall = GetWorkloadMethodCall(passArguments);
202-
bool hasAsyncMethodBuilderAttribute = TryGetAsyncMethodBuilderAttribute(out var asyncMethodBuilderAttribute);
203-
Type workloadCoreReturnType = GetWorkloadCoreReturnType(hasAsyncMethodBuilderAttribute, Descriptor.WorkloadMethod.ReturnType);
204-
string finalReturn = GetFinalReturn(workloadCoreReturnType);
205-
string coreImpl = $$"""
206-
private {{CoreReturnType}} OverheadActionUnroll({{CoreParameters}})
207-
{
208-
return this.OverheadActionNoUnroll(invokeCount * {{unrollFactor}}, clock);
209-
}
210-
211-
private {{CoreReturnType}} OverheadActionNoUnroll({{CoreParameters}})
212-
{
213-
{{StartClockSyncCode}}
214-
while (--invokeCount >= 0)
215-
{
216-
this.__Overhead({{passArguments}});
217-
}
218-
{{ReturnSyncCode}}
219-
}
220-
221-
private {{CoreReturnType}} WorkloadActionUnroll({{CoreParameters}})
222-
{
223-
return this.WorkloadActionNoUnroll(invokeCount * {{unrollFactor}}, clock);
224-
}
225-
226-
private {{CoreReturnType}} WorkloadActionNoUnroll({{CoreParameters}})
227-
{
228-
this.__fieldsContainer.invokeCount = invokeCount;
229-
this.__fieldsContainer.clock = clock;
230-
// The source is allocated and the workload loop started in __GlobalSetup,
231-
// so this hot path is branchless and allocation-free.
232-
return this.__fieldsContainer.workloadContinuerAndValueTaskSource.Continue();
233-
}
234-
235-
private async void __StartWorkload()
236-
{
237-
await __WorkloadCore();
238-
}
239-
240-
{{asyncMethodBuilderAttribute}}
241-
private async {{workloadCoreReturnType.GetCorrectCSharpTypeName()}} __WorkloadCore()
242-
{
243-
try
244-
{
245-
if (await this.__fieldsContainer.workloadContinuerAndValueTaskSource.GetIsComplete())
246-
{
247-
{{finalReturn}}
248-
}
249-
while (true)
250-
{
251-
{{typeof(StartedClock).GetCorrectCSharpTypeName()}} startedClock = {{typeof(ClockExtensions).GetCorrectCSharpTypeName()}}.Start(this.__fieldsContainer.clock);
252-
while (--this.__fieldsContainer.invokeCount >= 0)
253-
{
254-
// Necessary because of error CS4004: Cannot await in an unsafe context
255-
{{Descriptor.WorkloadMethod.ReturnType.GetCorrectCSharpTypeName()}} awaitable;
256-
unsafe { awaitable = {{workloadMethodCall}} }
257-
await awaitable;
258-
}
259-
if (await this.__fieldsContainer.workloadContinuerAndValueTaskSource.SetResultAndGetIsComplete(startedClock.GetElapsed()))
260-
{
261-
{{finalReturn}}
262-
}
263-
}
264-
}
265-
catch (global::System.Exception e)
266-
{
267-
__fieldsContainer.workloadContinuerAndValueTaskSource.SetException(e);
268-
{{finalReturn}}
269-
}
270-
}
271-
""";
272-
273-
return smartStringBuilder
274-
.Replace("$CoreImpl$", coreImpl);
275-
}
276-
277201
protected bool TryGetAsyncMethodBuilderAttribute(out string asyncMethodBuilderAttribute)
278202
{
279203
asyncMethodBuilderAttribute = string.Empty;
@@ -322,32 +246,15 @@ protected static string GetFinalReturn(Type workloadCoreReturnType)
322246
? "return;"
323247
: $"return default({finalReturnType.GetCorrectCSharpTypeName()});";
324248
}
325-
}
326249

327-
internal class AsyncEnumerableDeclarationsProvider(BenchmarkCase benchmark, Type itemType, Type enumeratorType, Type moveNextAwaitableType) : AsyncDeclarationsProvider(benchmark)
328-
{
329250
protected override SmartStringBuilder ReplaceCore(SmartStringBuilder smartStringBuilder)
330251
{
252+
// Unlike sync calls, async calls suffer from unrolling, so we multiply the invokeCount by the unroll factor and delegate the implementation to *NoUnroll methods.
331253
int unrollFactor = Benchmark.Job.ResolveValue(RunMode.UnrollFactorCharacteristic, EnvironmentResolver.Instance);
332254
string passArguments = GetPassArgumentsDirect();
333255
string workloadMethodCall = GetWorkloadMethodCall(passArguments);
334-
string itemTypeName = itemType.GetCorrectCSharpTypeName();
335-
string enumerableTypeName = Descriptor.WorkloadMethod.ReturnType.GetCorrectCSharpTypeName();
336-
string enumeratorTypeName = enumeratorType.GetCorrectCSharpTypeName();
337-
// We hand-roll the `await foreach` desugaring (explicit GetAsyncEnumerator + while-loop) instead
338-
// of using the C# `await foreach` keyword: that keeps the IL byte-for-byte aligned with
339-
// AsyncEnumerableCoreEmitter, which doesn't wrap the iteration in the try/catch + wrap field
340-
// pattern Roslyn emits for the keyword form.
341-
string disposeAsyncCall = ResolveDisposeAsync() is { } disposeAsyncMethod
342-
? $"await enumerator.{disposeAsyncMethod.Name}();"
343-
: string.Empty;
344-
// IAsyncEnumerable<T> has no GetAwaiter, so its own return type can't drive the WorkloadCore
345-
// builder. Use MoveNextAsync's return type as the proxy and feed it through the same resolver
346-
// as the awaitable path: any result the awaitable produces (typically `bool`) is discarded by
347-
// `__StartWorkload`'s `await`, and `[AsyncCallerType]` / `[AsyncMethodBuilder]` overrides on
348-
// the workload method still apply.
349256
bool hasAsyncMethodBuilderAttribute = TryGetAsyncMethodBuilderAttribute(out var asyncMethodBuilderAttribute);
350-
Type workloadCoreReturnType = GetWorkloadCoreReturnType(hasAsyncMethodBuilderAttribute, moveNextAwaitableType);
257+
Type workloadCoreReturnType = GetWorkloadCoreReturnType(hasAsyncMethodBuilderAttribute, WorkloadAwaitableReturnType);
351258
string finalReturn = GetFinalReturn(workloadCoreReturnType);
352259
string coreImpl = $$"""
353260
private {{CoreReturnType}} OverheadActionUnroll({{CoreParameters}})
@@ -383,7 +290,7 @@ private async void __StartWorkload()
383290
{
384291
await __WorkloadCore();
385292
}
386-
293+
387294
{{asyncMethodBuilderAttribute}}
388295
private async {{workloadCoreReturnType.GetCorrectCSharpTypeName()}} __WorkloadCore()
389296
{
@@ -395,23 +302,12 @@ private async void __StartWorkload()
395302
}
396303
while (true)
397304
{
398-
{{itemTypeName}} lastItem = default({{itemTypeName}});
399305
{{typeof(StartedClock).GetCorrectCSharpTypeName()}} startedClock = {{typeof(ClockExtensions).GetCorrectCSharpTypeName()}}.Start(this.__fieldsContainer.clock);
400306
while (--this.__fieldsContainer.invokeCount >= 0)
401307
{
402-
// Necessary because of error CS4004: Cannot await in an unsafe context
403-
{{enumerableTypeName}} enumerable;
404-
unsafe { enumerable = {{workloadMethodCall}} }
405-
{{enumeratorTypeName}} enumerator = enumerable.GetAsyncEnumerator();
406-
while (await enumerator.MoveNextAsync())
407-
{
408-
lastItem = enumerator.Current;
409-
}
410-
{{disposeAsyncCall}}
308+
{{GetCallAndConsumeImpl(workloadMethodCall)}}
411309
}
412-
{{typeof(ClockSpan).GetCorrectCSharpTypeName()}} elapsed = startedClock.GetElapsed();
413-
{{typeof(DeadCodeEliminationHelper).GetCorrectCSharpTypeName()}}.KeepAliveWithoutBoxing(lastItem);
414-
if (await this.__fieldsContainer.workloadContinuerAndValueTaskSource.SetResultAndGetIsComplete(elapsed))
310+
if (await this.__fieldsContainer.workloadContinuerAndValueTaskSource.SetResultAndGetIsComplete(startedClock.GetElapsed()))
415311
{
416312
{{finalReturn}}
417313
}
@@ -429,24 +325,51 @@ private async void __StartWorkload()
429325
.Replace("$CoreImpl$", coreImpl);
430326
}
431327

432-
// Roslyn's `await foreach` resolution: prefer a public instance DisposeAsync with all-optional
433-
// params whose awaiter's GetResult returns void; otherwise fall back to IAsyncDisposable.
434-
// Returns null if neither shape matches, in which case the template skips the dispose call.
435-
private MethodInfo? ResolveDisposeAsync()
328+
protected abstract string GetCallAndConsumeImpl(string workloadMethodCall);
329+
}
330+
331+
internal class AsyncDeclarationsProvider(BenchmarkCase benchmark, Type resultType) : AsyncDeclarationsProviderBase(benchmark)
332+
{
333+
protected override string GetCallAndConsumeImpl(string workloadMethodCall)
436334
{
437-
var disposeAsyncMethod = enumeratorType
438-
.GetMethods(BindingFlags.Public | BindingFlags.Instance)
439-
.FirstOrDefault(m => m.Name == nameof(IAsyncDisposable.DisposeAsync)
440-
&& m.GetParameters().All(p => p.IsOptional)
441-
&& m.ReturnType.GetMethod(nameof(Task.GetAwaiter), BindingFlags.Public | BindingFlags.Instance)
442-
?.ReturnType
443-
.GetMethod(nameof(TaskAwaiter.GetResult), BindingFlags.Public | BindingFlags.Instance)
444-
?.ReturnType == typeof(void));
445-
if (disposeAsyncMethod is not null)
446-
return disposeAsyncMethod;
447-
if (typeof(IAsyncDisposable).IsAssignableFrom(enumeratorType))
448-
return typeof(IAsyncDisposable).GetMethod(nameof(IAsyncDisposable.DisposeAsync));
449-
return null;
335+
string awaitStatement;
336+
if (resultType == typeof(void))
337+
{
338+
awaitStatement = "await awaitable;";
339+
}
340+
else
341+
{
342+
var resultTypeName = resultType.GetCorrectCSharpTypeName();
343+
awaitStatement = $"""
344+
{resultTypeName} result = await awaitable;
345+
{typeof(DeadCodeEliminationHelper).GetCorrectCSharpTypeName()}.KeepAliveWithoutBoxing<{resultTypeName}>(in result);
346+
""";
347+
}
348+
return $$"""
349+
// Necessary because of error CS4004: Cannot await in an unsafe context
350+
{{Descriptor.WorkloadMethod.ReturnType.GetCorrectCSharpTypeName()}} awaitable;
351+
unsafe { awaitable = {{workloadMethodCall}} }
352+
{{awaitStatement}}
353+
""";
354+
}
355+
}
356+
357+
internal class AsyncEnumerableDeclarationsProvider(BenchmarkCase benchmark, Type itemType, Type moveNextAwaitableType) : AsyncDeclarationsProviderBase(benchmark)
358+
{
359+
protected override Type WorkloadAwaitableReturnType => moveNextAwaitableType;
360+
361+
protected override string GetCallAndConsumeImpl(string workloadMethodCall)
362+
{
363+
string itemTypeName = itemType.GetCorrectCSharpTypeName();
364+
return $$"""
365+
// Necessary because of error CS4004: Cannot await in an unsafe context
366+
{{Descriptor.WorkloadMethod.ReturnType.GetCorrectCSharpTypeName()}} enumerable;
367+
unsafe { enumerable = {{workloadMethodCall}} }
368+
await foreach ({{itemTypeName}} item in enumerable)
369+
{
370+
{{typeof(DeadCodeEliminationHelper).GetCorrectCSharpTypeName()}}.KeepAliveWithoutBoxing<{{itemTypeName}}>(in item);
371+
}
372+
""";
450373
}
451374
}
452375
}

src/BenchmarkDotNet/Engines/Consumer.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ public void Consume<T>(in T value)
146146
else if (default(T) == null && !typeof(T).IsValueType)
147147
Consume((object?)value);
148148
else
149-
DeadCodeEliminationHelper.KeepAliveWithoutBoxingReadonly(value); // non-primitive and nullable value types
149+
DeadCodeEliminationHelper.KeepAliveWithoutBoxing(in value); // non-primitive and nullable value types
150150
}
151151
}
152152
}

0 commit comments

Comments
 (0)