Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ internal enum ContinuationFlags
ContinueOnThreadPool = 1 << 0,
ContinueOnCapturedSynchronizationContext = 1 << 1,
ContinueOnCapturedTaskScheduler = 1 << 2,
AllContinueFlags = ContinueOnCapturedSynchronizationContext | ContinueOnThreadPool | ContinueOnCapturedTaskScheduler,

// The flags encode where in the continuation various members are stored.
// If the encoded index is 0, it means no such member is present.
Expand Down Expand Up @@ -414,15 +415,6 @@ internal unsafe bool HandleSuspended(ref RuntimeAsyncAwaitState state)
Continuation headContinuation = sentinelContinuation.Next!;
sentinelContinuation.Next = null;

// Head continuation should be the result of async call to AwaitAwaiter or UnsafeAwaitAwaiter.
// These never have special continuation context handling.
const ContinuationFlags continueFlags =
ContinuationFlags.ContinueOnCapturedSynchronizationContext |
ContinuationFlags.ContinueOnThreadPool |
ContinuationFlags.ContinueOnCapturedTaskScheduler;

Debug.Assert((headContinuation.Flags & continueFlags) == 0);

SetContinuationState(headContinuation);

try
Expand Down Expand Up @@ -457,8 +449,8 @@ internal unsafe bool HandleSuspended(ref RuntimeAsyncAwaitState state)
// Since we see a VTS notifier, something was directly or indirectly
// awaiting an async thunk for a ValueTask-returning method.
// That can only happen in nontransparent/user code.
Continuation nextUserContinuation = headContinuation.Next!;
while ((nextUserContinuation.Flags & continueFlags) == 0 && nextUserContinuation.Next != null)
Continuation nextUserContinuation = headContinuation;
while ((nextUserContinuation.Flags & ContinuationFlags.AllContinueFlags) == 0 && nextUserContinuation.Next != null)
{
nextUserContinuation = nextUserContinuation.Next;
}
Expand All @@ -475,7 +467,7 @@ internal unsafe bool HandleSuspended(ref RuntimeAsyncAwaitState state)
}

// Clear continuation flags, so that continuation runs transparently
nextUserContinuation.Flags &= ~continueFlags;
nextUserContinuation.Flags &= ~ContinuationFlags.AllContinueFlags;
valueTaskSourceNotifier.OnCompleted(s_runContinuationAction, this, configFlags);
}
else
Expand Down Expand Up @@ -534,6 +526,12 @@ private unsafe void DispatchContinuations()
}
}

Continuation headContinuation = MoveContinuationState();
if ((headContinuation.Flags & ContinuationFlags.AllContinueFlags) != 0 && QueueContinuationFollowUpActionIfNecessary(headContinuation))
{
return;
}

RuntimeAsyncStackState stackState = default;

ref RuntimeAsyncAwaitState awaitState = ref t_runtimeAsyncAwaitState;
Expand All @@ -543,7 +541,7 @@ private unsafe void DispatchContinuations()

AsyncDispatcherInfo asyncDispatcherInfo;
asyncDispatcherInfo.Next = refDispatcherInfo;
asyncDispatcherInfo.NextContinuation = MoveContinuationState();
asyncDispatcherInfo.NextContinuation = headContinuation;
refDispatcherInfo = &asyncDispatcherInfo;

while (true)
Expand Down Expand Up @@ -633,6 +631,12 @@ private unsafe void DispatchContinuations()
[StackTraceHidden]
private unsafe void InstrumentedDispatchContinuations(AsyncInstrumentation.Flags flags)
{
Continuation headContinuation = MoveContinuationState();
if ((headContinuation.Flags & ContinuationFlags.AllContinueFlags) != 0 && QueueContinuationFollowUpActionIfNecessary(headContinuation))
{
return;
}

Comment thread
jakobbotsch marked this conversation as resolved.
RuntimeAsyncStackState stackState = default;

ref RuntimeAsyncAwaitState awaitState = ref t_runtimeAsyncAwaitState;
Expand All @@ -642,7 +646,7 @@ private unsafe void InstrumentedDispatchContinuations(AsyncInstrumentation.Flags

AsyncDispatcherInfo asyncDispatcherInfo;
asyncDispatcherInfo.Next = refDispatcherInfo;
asyncDispatcherInfo.NextContinuation = MoveContinuationState();
asyncDispatcherInfo.NextContinuation = headContinuation;
refDispatcherInfo = &asyncDispatcherInfo;

RuntimeAsyncInstrumentationHelpers.ResumeRuntimeAsyncContext(this, ref asyncDispatcherInfo, flags);
Expand Down Expand Up @@ -759,7 +763,9 @@ private unsafe void InstrumentedDispatchContinuations(AsyncInstrumentation.Flags

private bool QueueContinuationFollowUpActionIfNecessary(Continuation continuation)
{
if ((continuation.Flags & ContinuationFlags.ContinueOnThreadPool) != 0)
ContinuationFlags flags = continuation.Flags;
continuation.Flags &= ~ContinuationFlags.AllContinueFlags;
if ((flags & ContinuationFlags.ContinueOnThreadPool) != 0)
{
SynchronizationContext? ctx = Thread.CurrentThreadAssumedInitialized._synchronizationContext;
if (ctx == null || ctx.GetType() == typeof(SynchronizationContext))
Expand All @@ -777,7 +783,7 @@ private bool QueueContinuationFollowUpActionIfNecessary(Continuation continuatio
return true;
}

if ((continuation.Flags & ContinuationFlags.ContinueOnCapturedSynchronizationContext) != 0)
if ((flags & ContinuationFlags.ContinueOnCapturedSynchronizationContext) != 0)
{
object continuationContext = continuation.GetContinuationContext();
Debug.Assert(continuationContext is SynchronizationContext { });
Expand All @@ -803,7 +809,7 @@ private bool QueueContinuationFollowUpActionIfNecessary(Continuation continuatio
return true;
}

if ((continuation.Flags & ContinuationFlags.ContinueOnCapturedTaskScheduler) != 0)
if ((flags & ContinuationFlags.ContinueOnCapturedTaskScheduler) != 0)
{
object continuationContext = continuation.GetContinuationContext();
Debug.Assert(continuationContext is TaskScheduler { });
Expand Down
6 changes: 5 additions & 1 deletion src/coreclr/jit/compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -4850,8 +4850,12 @@ class Compiler
IL_OFFSET rawILOffset);

void impSetupAsyncCall(GenTreeCall* call, OPCODE opcode, unsigned prefixFlags, const DebugInfo& callDI);
void impAddAsyncArgsToInlinedCall(GenTreeCall* call);
bool impCurrentMethodIsKnownToPreserveSynchronizationContext();
bool impComputedIsAsyncThunk = false;
bool impIsAsyncThunk = false;

void impInsertAsyncContinuationForLdvirtftnCall(GenTreeCall* call);
void impInsertAsyncArgsForLdvirtftnCall(GenTreeCall* call);

CORINFO_CLASS_HANDLE impGetSpecialIntrinsicExactReturnType(GenTreeCall* call);

Expand Down
117 changes: 108 additions & 9 deletions src/coreclr/jit/importercalls.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,7 @@ var_types Compiler::impImportCall(OPCODE opcode,

if (call->AsCall()->IsAsync())
{
impInsertAsyncContinuationForLdvirtftnCall(call->AsCall());
impInsertAsyncArgsForLdvirtftnCall(call->AsCall());
}

GenTree* thisPtr = impPopStack().val;
Expand Down Expand Up @@ -941,6 +941,11 @@ var_types Compiler::impImportCall(OPCODE opcode,
}
}

if (asyncContinuation != nullptr)
{
impAddAsyncArgsToInlinedCall(call->AsCall());
}

//-------------------------------------------------------------------------
// The "this" pointer

Expand Down Expand Up @@ -6970,13 +6975,39 @@ void Compiler::impCheckForPInvokeCall(
//
void Compiler::impSetupAsyncCall(GenTreeCall* call, OPCODE opcode, unsigned prefixFlags, const DebugInfo& callDI)
{
AsyncCallInfo asyncInfo;

if (compIsForInlining())
{
compInlineResult->NoteFatal(InlineObservation::CALLEE_AWAIT);
return;
}
GenTreeCall* inlCall = impInlineInfo->iciCall;
JITDUMP("Call [%06u] being inlined has an async call [%06u]", dspTreeID(inlCall), dspTreeID(call));
assert(inlCall->IsAsync());
if (inlCall->GetAsyncInfo().ContinuationContextHandling != ContinuationContextHandling::None)
{
// Caller is relying on the async infrastructure to move the
// execution to the right place after returning from the callee.
JITDUMP(" and caller needs continuation context handling");

AsyncCallInfo asyncInfo;
if (!impCurrentMethodIsKnownToPreserveSynchronizationContext())
{
// May need to actually move the execution; we do not currently
// handle this case.
JITDUMP(" and callee may mutate synchronization context; cannot inline\n");
compInlineResult->NoteFatal(InlineObservation::CALLEE_AWAIT);
return;
}

JITDUMP(" but callee is known to preserve synchronization context; inlining anyway\n");
// These cases are selected only in a few cases:
assert((prefixFlags & PREFIX_IS_TASK_AWAIT) == 0);
assert((info.compMethodInfo->options & CORINFO_ASYNC_SAVE_CONTEXTS) == 0);
asyncInfo.ContinuationContextHandling = inlCall->GetAsyncInfo().ContinuationContextHandling;
}
Comment thread
jakobbotsch marked this conversation as resolved.
else
{
JITDUMP(" but inlining call has no continuation context handling\n");
}
}

unsigned newSourceTypes = ICorDebugInfo::ASYNC;
newSourceTypes |= (unsigned)callDI.GetLocation().GetSourceTypes() & ~ICorDebugInfo::CALL_INSTRUCTION;
Expand Down Expand Up @@ -7033,9 +7064,75 @@ void Compiler::impSetupAsyncCall(GenTreeCall* call, OPCODE opcode, unsigned pref
}

//------------------------------------------------------------------------
// impInsertAsyncContinuationForLdvirtftnCall:
// Insert the async continuation argument for a call the EE asked to be
// performed via ldvirtftn.
// impAddAsyncArgsToInlinedCall:
// Add necessary async contexts to the specified inlined async call.
//
// Arguments:
// call - The async call
//
void Compiler::impAddAsyncArgsToInlinedCall(GenTreeCall* call)
{
if (!compIsForInlining())
{
return;
}

if ((info.compMethodInfo->options & CORINFO_ASYNC_SAVE_CONTEXTS) != 0)
{
// This async call in the inlinee needs its own context handling.
return;
}

GenTreeCall* inlCall = impInlineInfo->iciCall;
CallArg* execArg = inlCall->gtArgs.FindWellKnownArg(WellKnownArg::AsyncExecutionContext);
CallArg* syncArg = inlCall->gtArgs.FindWellKnownArg(WellKnownArg::AsyncSynchronizationContext);
if ((execArg == nullptr) && (syncArg == nullptr))
{
// Caller also has no async contexts handling
return;
}

// We are inlining an async call that does not save contexts into a call
// that does. We currently allow this only in cases where the tail of the
// inlinee can run in the caller's context, and hence we propagate the
// caller's context here. It means we do not need to worry about switching
// into the caller's context when the inlinee is returning to the caller
// after the await.
assert(execArg->GetNode()->OperIs(GT_LCL_VAR) && syncArg->GetNode()->OperIs(GT_LCL_VAR));
JITDUMP("Inheriting contexts [%06u] and [%06u] from caller node\n", dspTreeID(execArg->GetNode()),
dspTreeID(syncArg->GetNode()));

GenTree* execNode = gtCloneExpr(execArg->GetNode());
GenTree* syncNode = gtCloneExpr(syncArg->GetNode());
call->gtArgs.PushFront(this, NewCallArg::Primitive(syncNode).WellKnown(WellKnownArg::AsyncSynchronizationContext));
call->gtArgs.PushFront(this, NewCallArg::Primitive(execNode).WellKnown(WellKnownArg::AsyncExecutionContext));
}

//------------------------------------------------------------------------
// impCurrentMethodIsKnownToPreserveSynchronizationContext:
// Check if the current method is known not to mutate Thread._synchronizationContext.
//
// Returns:
// True if so.
//
bool Compiler::impCurrentMethodIsKnownToPreserveSynchronizationContext()
{
if (!impComputedIsAsyncThunk)
{
bool otherVariantIsThunk;
CORINFO_METHOD_HANDLE otherVariant =
info.compCompHnd->getAsyncOtherVariant(info.compMethodHnd, &otherVariantIsThunk);
impIsAsyncThunk = (otherVariant != NO_METHOD_HANDLE) && !otherVariantIsThunk;
impComputedIsAsyncThunk = true;
}

return impIsAsyncThunk;
}

//------------------------------------------------------------------------
// impInsertAsyncArgsForLdvirtftnCall:
// Insert the async args for a call the EE asked to be performed via
// ldvirtftn.
//
// Arguments:
// call - The call
Expand All @@ -7044,7 +7141,7 @@ void Compiler::impSetupAsyncCall(GenTreeCall* call, OPCODE opcode, unsigned pref
// Should be called before the 'this' arg is inserted, but after other IL args
// have been inserted.
//
void Compiler::impInsertAsyncContinuationForLdvirtftnCall(GenTreeCall* call)
void Compiler::impInsertAsyncArgsForLdvirtftnCall(GenTreeCall* call)
{
assert(call->AsCall()->IsAsync());

Expand All @@ -7058,6 +7155,8 @@ void Compiler::impInsertAsyncContinuationForLdvirtftnCall(GenTreeCall* call)
call->AsCall()->gtArgs.PushBack(this, NewCallArg::Primitive(gtNewNull(), TYP_REF)
.WellKnown(WellKnownArg::AsyncContinuation));
}

impAddAsyncArgsToInlinedCall(call);
}

//------------------------------------------------------------------------
Expand Down
Loading