diff --git a/src/coreclr/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.CoreCLR.cs b/src/coreclr/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.CoreCLR.cs index a874ddb5b58989..b2e3d4eb907bcf 100644 --- a/src/coreclr/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.CoreCLR.cs +++ b/src/coreclr/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.CoreCLR.cs @@ -22,6 +22,7 @@ internal enum ContinuationFlags ContinueOnThreadPool = 1 << 0, ContinueOnCapturedSynchronizationContext = 1 << 1, ContinueOnCapturedTaskScheduler = 1 << 2, + AllContinueFlags = ContinueOnCapturedSynchronizationContext | ContinueOnThreadPool | ContinueOnCapturedTaskScheduler, // The flags encode where in the continuation various members are stored. // If the encoded index is 0, it means no such member is present. @@ -414,15 +415,6 @@ internal unsafe bool HandleSuspended(ref RuntimeAsyncAwaitState state) Continuation headContinuation = sentinelContinuation.Next!; sentinelContinuation.Next = null; - // Head continuation should be the result of async call to AwaitAwaiter or UnsafeAwaitAwaiter. - // These never have special continuation context handling. - const ContinuationFlags continueFlags = - ContinuationFlags.ContinueOnCapturedSynchronizationContext | - ContinuationFlags.ContinueOnThreadPool | - ContinuationFlags.ContinueOnCapturedTaskScheduler; - - Debug.Assert((headContinuation.Flags & continueFlags) == 0); - SetContinuationState(headContinuation); try @@ -457,8 +449,8 @@ internal unsafe bool HandleSuspended(ref RuntimeAsyncAwaitState state) // Since we see a VTS notifier, something was directly or indirectly // awaiting an async thunk for a ValueTask-returning method. // That can only happen in nontransparent/user code. - Continuation nextUserContinuation = headContinuation.Next!; - while ((nextUserContinuation.Flags & continueFlags) == 0 && nextUserContinuation.Next != null) + Continuation nextUserContinuation = headContinuation; + while ((nextUserContinuation.Flags & ContinuationFlags.AllContinueFlags) == 0 && nextUserContinuation.Next != null) { nextUserContinuation = nextUserContinuation.Next; } @@ -475,7 +467,7 @@ internal unsafe bool HandleSuspended(ref RuntimeAsyncAwaitState state) } // Clear continuation flags, so that continuation runs transparently - nextUserContinuation.Flags &= ~continueFlags; + nextUserContinuation.Flags &= ~ContinuationFlags.AllContinueFlags; valueTaskSourceNotifier.OnCompleted(s_runContinuationAction, this, configFlags); } else @@ -534,6 +526,12 @@ private unsafe void DispatchContinuations() } } + Continuation headContinuation = MoveContinuationState(); + if ((headContinuation.Flags & ContinuationFlags.AllContinueFlags) != 0 && QueueContinuationFollowUpActionIfNecessary(headContinuation)) + { + return; + } + RuntimeAsyncStackState stackState = default; ref RuntimeAsyncAwaitState awaitState = ref t_runtimeAsyncAwaitState; @@ -543,7 +541,7 @@ private unsafe void DispatchContinuations() AsyncDispatcherInfo asyncDispatcherInfo; asyncDispatcherInfo.Next = refDispatcherInfo; - asyncDispatcherInfo.NextContinuation = MoveContinuationState(); + asyncDispatcherInfo.NextContinuation = headContinuation; refDispatcherInfo = &asyncDispatcherInfo; while (true) @@ -633,6 +631,12 @@ private unsafe void DispatchContinuations() [StackTraceHidden] private unsafe void InstrumentedDispatchContinuations(AsyncInstrumentation.Flags flags) { + Continuation headContinuation = MoveContinuationState(); + if ((headContinuation.Flags & ContinuationFlags.AllContinueFlags) != 0 && QueueContinuationFollowUpActionIfNecessary(headContinuation)) + { + return; + } + RuntimeAsyncStackState stackState = default; ref RuntimeAsyncAwaitState awaitState = ref t_runtimeAsyncAwaitState; @@ -642,7 +646,7 @@ private unsafe void InstrumentedDispatchContinuations(AsyncInstrumentation.Flags AsyncDispatcherInfo asyncDispatcherInfo; asyncDispatcherInfo.Next = refDispatcherInfo; - asyncDispatcherInfo.NextContinuation = MoveContinuationState(); + asyncDispatcherInfo.NextContinuation = headContinuation; refDispatcherInfo = &asyncDispatcherInfo; RuntimeAsyncInstrumentationHelpers.ResumeRuntimeAsyncContext(this, ref asyncDispatcherInfo, flags); @@ -759,7 +763,9 @@ private unsafe void InstrumentedDispatchContinuations(AsyncInstrumentation.Flags private bool QueueContinuationFollowUpActionIfNecessary(Continuation continuation) { - if ((continuation.Flags & ContinuationFlags.ContinueOnThreadPool) != 0) + ContinuationFlags flags = continuation.Flags; + continuation.Flags &= ~ContinuationFlags.AllContinueFlags; + if ((flags & ContinuationFlags.ContinueOnThreadPool) != 0) { SynchronizationContext? ctx = Thread.CurrentThreadAssumedInitialized._synchronizationContext; if (ctx == null || ctx.GetType() == typeof(SynchronizationContext)) @@ -777,7 +783,7 @@ private bool QueueContinuationFollowUpActionIfNecessary(Continuation continuatio return true; } - if ((continuation.Flags & ContinuationFlags.ContinueOnCapturedSynchronizationContext) != 0) + if ((flags & ContinuationFlags.ContinueOnCapturedSynchronizationContext) != 0) { object continuationContext = continuation.GetContinuationContext(); Debug.Assert(continuationContext is SynchronizationContext { }); @@ -803,7 +809,7 @@ private bool QueueContinuationFollowUpActionIfNecessary(Continuation continuatio return true; } - if ((continuation.Flags & ContinuationFlags.ContinueOnCapturedTaskScheduler) != 0) + if ((flags & ContinuationFlags.ContinueOnCapturedTaskScheduler) != 0) { object continuationContext = continuation.GetContinuationContext(); Debug.Assert(continuationContext is TaskScheduler { }); diff --git a/src/coreclr/jit/compiler.h b/src/coreclr/jit/compiler.h index 35be79978cc0e3..2856d23d2d8a60 100644 --- a/src/coreclr/jit/compiler.h +++ b/src/coreclr/jit/compiler.h @@ -4850,8 +4850,12 @@ class Compiler IL_OFFSET rawILOffset); void impSetupAsyncCall(GenTreeCall* call, OPCODE opcode, unsigned prefixFlags, const DebugInfo& callDI); + void impAddAsyncArgsToInlinedCall(GenTreeCall* call); + bool impCurrentMethodIsKnownToPreserveSynchronizationContext(); + bool impComputedIsAsyncThunk = false; + bool impIsAsyncThunk = false; - void impInsertAsyncContinuationForLdvirtftnCall(GenTreeCall* call); + void impInsertAsyncArgsForLdvirtftnCall(GenTreeCall* call); CORINFO_CLASS_HANDLE impGetSpecialIntrinsicExactReturnType(GenTreeCall* call); diff --git a/src/coreclr/jit/importercalls.cpp b/src/coreclr/jit/importercalls.cpp index 6ef07513ee4b10..9d896f0f08ea06 100644 --- a/src/coreclr/jit/importercalls.cpp +++ b/src/coreclr/jit/importercalls.cpp @@ -415,7 +415,7 @@ var_types Compiler::impImportCall(OPCODE opcode, if (call->AsCall()->IsAsync()) { - impInsertAsyncContinuationForLdvirtftnCall(call->AsCall()); + impInsertAsyncArgsForLdvirtftnCall(call->AsCall()); } GenTree* thisPtr = impPopStack().val; @@ -941,6 +941,11 @@ var_types Compiler::impImportCall(OPCODE opcode, } } + if (asyncContinuation != nullptr) + { + impAddAsyncArgsToInlinedCall(call->AsCall()); + } + //------------------------------------------------------------------------- // The "this" pointer @@ -6970,13 +6975,39 @@ void Compiler::impCheckForPInvokeCall( // void Compiler::impSetupAsyncCall(GenTreeCall* call, OPCODE opcode, unsigned prefixFlags, const DebugInfo& callDI) { + AsyncCallInfo asyncInfo; + if (compIsForInlining()) { - compInlineResult->NoteFatal(InlineObservation::CALLEE_AWAIT); - return; - } + GenTreeCall* inlCall = impInlineInfo->iciCall; + JITDUMP("Call [%06u] being inlined has an async call [%06u]", dspTreeID(inlCall), dspTreeID(call)); + assert(inlCall->IsAsync()); + if (inlCall->GetAsyncInfo().ContinuationContextHandling != ContinuationContextHandling::None) + { + // Caller is relying on the async infrastructure to move the + // execution to the right place after returning from the callee. + JITDUMP(" and caller needs continuation context handling"); - AsyncCallInfo asyncInfo; + if (!impCurrentMethodIsKnownToPreserveSynchronizationContext()) + { + // May need to actually move the execution; we do not currently + // handle this case. + JITDUMP(" and callee may mutate synchronization context; cannot inline\n"); + compInlineResult->NoteFatal(InlineObservation::CALLEE_AWAIT); + return; + } + + JITDUMP(" but callee is known to preserve synchronization context; inlining anyway\n"); + // These cases are selected only in a few cases: + assert((prefixFlags & PREFIX_IS_TASK_AWAIT) == 0); + assert((info.compMethodInfo->options & CORINFO_ASYNC_SAVE_CONTEXTS) == 0); + asyncInfo.ContinuationContextHandling = inlCall->GetAsyncInfo().ContinuationContextHandling; + } + else + { + JITDUMP(" but inlining call has no continuation context handling\n"); + } + } unsigned newSourceTypes = ICorDebugInfo::ASYNC; newSourceTypes |= (unsigned)callDI.GetLocation().GetSourceTypes() & ~ICorDebugInfo::CALL_INSTRUCTION; @@ -7033,9 +7064,75 @@ void Compiler::impSetupAsyncCall(GenTreeCall* call, OPCODE opcode, unsigned pref } //------------------------------------------------------------------------ -// impInsertAsyncContinuationForLdvirtftnCall: -// Insert the async continuation argument for a call the EE asked to be -// performed via ldvirtftn. +// impAddAsyncArgsToInlinedCall: +// Add necessary async contexts to the specified inlined async call. +// +// Arguments: +// call - The async call +// +void Compiler::impAddAsyncArgsToInlinedCall(GenTreeCall* call) +{ + if (!compIsForInlining()) + { + return; + } + + if ((info.compMethodInfo->options & CORINFO_ASYNC_SAVE_CONTEXTS) != 0) + { + // This async call in the inlinee needs its own context handling. + return; + } + + GenTreeCall* inlCall = impInlineInfo->iciCall; + CallArg* execArg = inlCall->gtArgs.FindWellKnownArg(WellKnownArg::AsyncExecutionContext); + CallArg* syncArg = inlCall->gtArgs.FindWellKnownArg(WellKnownArg::AsyncSynchronizationContext); + if ((execArg == nullptr) && (syncArg == nullptr)) + { + // Caller also has no async contexts handling + return; + } + + // We are inlining an async call that does not save contexts into a call + // that does. We currently allow this only in cases where the tail of the + // inlinee can run in the caller's context, and hence we propagate the + // caller's context here. It means we do not need to worry about switching + // into the caller's context when the inlinee is returning to the caller + // after the await. + assert(execArg->GetNode()->OperIs(GT_LCL_VAR) && syncArg->GetNode()->OperIs(GT_LCL_VAR)); + JITDUMP("Inheriting contexts [%06u] and [%06u] from caller node\n", dspTreeID(execArg->GetNode()), + dspTreeID(syncArg->GetNode())); + + GenTree* execNode = gtCloneExpr(execArg->GetNode()); + GenTree* syncNode = gtCloneExpr(syncArg->GetNode()); + call->gtArgs.PushFront(this, NewCallArg::Primitive(syncNode).WellKnown(WellKnownArg::AsyncSynchronizationContext)); + call->gtArgs.PushFront(this, NewCallArg::Primitive(execNode).WellKnown(WellKnownArg::AsyncExecutionContext)); +} + +//------------------------------------------------------------------------ +// impCurrentMethodIsKnownToPreserveSynchronizationContext: +// Check if the current method is known not to mutate Thread._synchronizationContext. +// +// Returns: +// True if so. +// +bool Compiler::impCurrentMethodIsKnownToPreserveSynchronizationContext() +{ + if (!impComputedIsAsyncThunk) + { + bool otherVariantIsThunk; + CORINFO_METHOD_HANDLE otherVariant = + info.compCompHnd->getAsyncOtherVariant(info.compMethodHnd, &otherVariantIsThunk); + impIsAsyncThunk = (otherVariant != NO_METHOD_HANDLE) && !otherVariantIsThunk; + impComputedIsAsyncThunk = true; + } + + return impIsAsyncThunk; +} + +//------------------------------------------------------------------------ +// impInsertAsyncArgsForLdvirtftnCall: +// Insert the async args for a call the EE asked to be performed via +// ldvirtftn. // // Arguments: // call - The call @@ -7044,7 +7141,7 @@ void Compiler::impSetupAsyncCall(GenTreeCall* call, OPCODE opcode, unsigned pref // Should be called before the 'this' arg is inserted, but after other IL args // have been inserted. // -void Compiler::impInsertAsyncContinuationForLdvirtftnCall(GenTreeCall* call) +void Compiler::impInsertAsyncArgsForLdvirtftnCall(GenTreeCall* call) { assert(call->AsCall()->IsAsync()); @@ -7058,6 +7155,8 @@ void Compiler::impInsertAsyncContinuationForLdvirtftnCall(GenTreeCall* call) call->AsCall()->gtArgs.PushBack(this, NewCallArg::Primitive(gtNewNull(), TYP_REF) .WellKnown(WellKnownArg::AsyncContinuation)); } + + impAddAsyncArgsToInlinedCall(call); } //------------------------------------------------------------------------