Skip to content

Commit 39465ed

Browse files
committed
Auto merge of #157166 - cjgillot:retype-context-after, r=oli-obk
Change type of async context parameter after state transform. `Future::poll` expects a `&mut std::task::Context<'_>`. Meanwhile, async coroutines use `std::future::ResumeTy` as resume parameter. This is meant to workaround the limitations of borrowck, which cannot prove that coroutines implement `for<'a, 'b> CoroutineTrait<&'a mut Context<'b>>`. In the coroutine state transform, we need to change the signature from `ResumeTy` to the proper `&mut Context<'_>`. The current code attempts to find locals that have type `ResumeTy` to change their type. This is needlessly complex and relies on undocumented behaviour of the MIR builder. Instead, this PR proposes to replace the `ResumeTy` argument with a new local, with value `ResumeTy(transmute(context))`. Based on #156875.
2 parents d56483a + b7e26ac commit 39465ed

63 files changed

Lines changed: 854 additions & 733 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

compiler/rustc_mir_transform/src/coroutine/drop.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,10 @@ pub(super) fn create_coroutine_drop_shim_async<'tcx>(
308308
// Run derefer to fix Derefs that are not in the first place
309309
deref_finder(tcx, &mut body, false);
310310

311+
if transform.coroutine_kind.is_async_desugaring() {
312+
transform_async_context(tcx, &mut body);
313+
}
314+
311315
if let Some(dumper) = MirDumper::new(tcx, "coroutine_drop_async", &body) {
312316
dumper.dump_mir(&body);
313317
}
@@ -320,6 +324,7 @@ pub(super) fn create_coroutine_drop_shim_async<'tcx>(
320324
pub(super) fn create_coroutine_drop_shim_proxy_async<'tcx>(
321325
tcx: TyCtxt<'tcx>,
322326
body: &Body<'tcx>,
327+
coroutine_kind: CoroutineKind,
323328
) -> Body<'tcx> {
324329
let mut body = body.clone();
325330
// Take the coroutine info out of the body, since the drop shim is
@@ -357,6 +362,10 @@ pub(super) fn create_coroutine_drop_shim_proxy_async<'tcx>(
357362
// Run derefer to fix Derefs that are not in the first place
358363
deref_finder(tcx, &mut body, false);
359364

365+
if coroutine_kind.is_async_desugaring() {
366+
transform_async_context(tcx, &mut body);
367+
}
368+
360369
if let Some(dumper) = MirDumper::new(tcx, "coroutine_drop_proxy_async", &body) {
361370
dumper.dump_mir(&body);
362371
}

compiler/rustc_mir_transform/src/coroutine/mod.rs

Lines changed: 87 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -551,19 +551,15 @@ fn make_coroutine_state_argument_pinned<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body
551551
);
552552
}
553553

554-
/// Transforms the `body` of the coroutine applying the following transforms:
555-
///
556-
/// - Eliminates all the `get_context` calls that async lowering created.
557-
/// - Replace all `Local` `ResumeTy` types with `&mut Context<'_>` (`context_mut_ref`).
558-
///
559-
/// The `Local`s that have their types replaced are:
560-
/// - The `resume` argument itself.
561-
/// - The argument to `get_context`.
562-
/// - The yielded value of a `yield`.
563-
///
554+
/// Async desugaring uses an unsafe binder type `ResumeTy` to circumvert borrow-checking.
564555
/// The `ResumeTy` hides a `&mut Context<'_>` behind an unsafe raw pointer, and the
565556
/// `get_context` function is being used to convert that back to a `&mut Context<'_>`.
566557
///
558+
/// The actual should be `&mut Context<'_>`. This performs the substitution:
559+
/// - create a new local `_r` of type `ResumeTy`;
560+
/// - assign `ResumeTy(transmute::<&mut Context<'_>, NonNull<Context<'_>>>(_2))` to that local;
561+
/// - let all the code use `_r` instead of `_2`.
562+
///
567563
/// Ideally the async lowering would not use the `ResumeTy`/`get_context` indirection,
568564
/// but rather directly use `&mut Context<'_>`, however that would currently
569565
/// lead to higher-kinded lifetime errors.
@@ -575,93 +571,90 @@ fn make_coroutine_state_argument_pinned<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body
575571
#[tracing::instrument(level = "trace", skip(tcx, body), ret)]
576572
fn transform_async_context<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
577573
let context_mut_ref = Ty::new_task_context(tcx);
574+
let resume_ty_def_id = tcx.require_lang_item(LangItem::ResumeTy, body.span);
575+
let resume_nonnull_ty = tcx.instantiate_and_normalize_erasing_regions(
576+
ty::GenericArgs::empty(),
577+
body.typing_env(tcx),
578+
tcx.type_of(tcx.adt_def(resume_ty_def_id).non_enum_variant().fields[FieldIdx::ZERO].did),
579+
);
578580

579-
// replace the type of the `resume` argument
580-
replace_resume_ty_local(tcx, body, CTX_ARG, context_mut_ref);
581+
// Replace all occurrences of `CTX_ARG` with `resume_local: ResumeTy`,
582+
// and set `CTX_ARG: &mut Context<'_>`.
583+
let resume_local = body.local_decls.push(LocalDecl::new(context_mut_ref, body.span));
584+
body.local_decls.swap(CTX_ARG, resume_local);
585+
RenameLocalVisitor { from: CTX_ARG, to: resume_local, tcx }.visit_body(body);
581586

582-
let get_context_def_id = tcx.require_lang_item(LangItem::GetContext, body.span);
587+
// Now `CTX_ARG` is `&mut Context` and `resume_local` is a `ResumeTy`.
588+
// Insert a `resume_local = ResumeTy(CTX_ARG as *mut Context<'static>)`
589+
// at the function entry to make the bridge.
590+
let source_info = SourceInfo::outermost(body.span);
591+
let nonnull_local = body.local_decls.push(LocalDecl::new(resume_nonnull_ty, body.span));
592+
let nonnull_rhs =
593+
Rvalue::Cast(CastKind::Transmute, Operand::Move(CTX_ARG.into()), resume_nonnull_ty);
594+
let nonnull_assign = StatementKind::Assign(Box::new((nonnull_local.into(), nonnull_rhs)));
595+
let resume_rhs = Rvalue::Aggregate(
596+
Box::new(AggregateKind::Adt(
597+
resume_ty_def_id,
598+
VariantIdx::ZERO,
599+
ty::GenericArgs::empty(),
600+
None,
601+
None,
602+
)),
603+
indexvec![Operand::Move(nonnull_local.into())],
604+
);
605+
let resume_assign = StatementKind::Assign(Box::new((resume_local.into(), resume_rhs)));
606+
body.basic_blocks.as_mut_preserves_cfg()[START_BLOCK].statements.splice(
607+
0..0,
608+
[Statement::new(source_info, nonnull_assign), Statement::new(source_info, resume_assign)],
609+
);
610+
}
611+
612+
/// HIR uses `get_context` to unwrap a `&mut Context<'_>` from a `ResumeTy`.
613+
/// Both types are just a single pointer, but liveness analysis does not know that and
614+
/// supposes that the operand and the destination are live at the same time.
615+
/// Forcibly inline those calls to avoid this.
616+
fn eliminate_get_context_calls<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
617+
let context_mut_ref = Ty::new_task_context(tcx);
618+
let resume_ty_def_id = tcx.require_lang_item(LangItem::ResumeTy, body.span);
619+
let resume_nonnull_ty = tcx.instantiate_and_normalize_erasing_regions(
620+
ty::GenericArgs::empty(),
621+
body.typing_env(tcx),
622+
tcx.type_of(tcx.adt_def(resume_ty_def_id).non_enum_variant().fields[FieldIdx::ZERO].did),
623+
);
583624

584-
for bb in body.basic_blocks.indices() {
585-
let bb_data = &body[bb];
625+
let get_context_def_id = tcx.require_lang_item(LangItem::GetContext, body.span);
626+
for bb_data in body.basic_blocks.as_mut().iter_mut() {
586627
if bb_data.is_cleanup {
587628
continue;
588629
}
589630

590-
match &bb_data.terminator().kind {
591-
TerminatorKind::Call { func, .. } => {
592-
let func_ty = func.ty(body, tcx);
593-
if let ty::FnDef(def_id, _) = *func_ty.kind()
594-
&& def_id == get_context_def_id
595-
{
596-
let local = eliminate_get_context_call(&mut body[bb]);
597-
replace_resume_ty_local(tcx, body, local, context_mut_ref);
598-
}
599-
}
600-
TerminatorKind::Yield { resume_arg, .. } => {
601-
replace_resume_ty_local(tcx, body, resume_arg.local, context_mut_ref);
602-
}
603-
_ => {}
631+
let terminator = bb_data.terminator_mut();
632+
if let TerminatorKind::Call { func, args, destination, target, .. } = &terminator.kind
633+
&& let func_ty = func.ty(&body.local_decls, tcx)
634+
&& let ty::FnDef(def_id, _) = *func_ty.kind()
635+
&& def_id == get_context_def_id
636+
&& let [arg] = &**args
637+
&& let Some(place) = arg.node.place()
638+
{
639+
let arg =
640+
Rvalue::Cast(
641+
CastKind::Transmute,
642+
Operand::Copy(place.project_deeper(
643+
&[PlaceElem::Field(FieldIdx::ZERO, resume_nonnull_ty)],
644+
tcx,
645+
)),
646+
context_mut_ref,
647+
);
648+
let assign = Statement::new(
649+
terminator.source_info,
650+
StatementKind::Assign(Box::new((*destination, arg))),
651+
);
652+
terminator.kind = TerminatorKind::Goto { target: target.unwrap() };
653+
bb_data.statements.push(assign);
604654
}
605655
}
606656
}
607657

608-
fn eliminate_get_context_call<'tcx>(bb_data: &mut BasicBlockData<'tcx>) -> Local {
609-
let terminator = bb_data.terminator.take().unwrap();
610-
let TerminatorKind::Call { args, destination, target, .. } = terminator.kind else {
611-
bug!();
612-
};
613-
let [arg] = *Box::try_from(args).unwrap();
614-
let local = arg.node.place().unwrap().local;
615-
616-
let arg = Rvalue::Use(arg.node, WithRetag::Yes);
617-
let assign =
618-
Statement::new(terminator.source_info, StatementKind::Assign(Box::new((destination, arg))));
619-
bb_data.statements.push(assign);
620-
bb_data.terminator = Some(Terminator {
621-
source_info: terminator.source_info,
622-
kind: TerminatorKind::Goto { target: target.unwrap() },
623-
});
624-
local
625-
}
626-
627-
#[cfg_attr(not(debug_assertions), allow(unused))]
628-
#[tracing::instrument(level = "trace", skip(tcx, body), ret)]
629-
fn replace_resume_ty_local<'tcx>(
630-
tcx: TyCtxt<'tcx>,
631-
body: &mut Body<'tcx>,
632-
local: Local,
633-
context_mut_ref: Ty<'tcx>,
634-
) {
635-
let local_ty = std::mem::replace(&mut body.local_decls[local].ty, context_mut_ref);
636-
// We have to replace the `ResumeTy` that is used for type and borrow checking
637-
// with `&mut Context<'_>` in MIR.
638-
#[cfg(debug_assertions)]
639-
{
640-
if let ty::Adt(resume_ty_adt, _) = local_ty.kind() {
641-
let expected_adt = tcx.adt_def(tcx.require_lang_item(LangItem::ResumeTy, body.span));
642-
assert_eq!(*resume_ty_adt, expected_adt);
643-
} else {
644-
panic!("expected `ResumeTy`, found `{:?}`", local_ty);
645-
};
646-
}
647-
}
648-
649-
/// Transforms the `body` of the coroutine applying the following transform:
650-
///
651-
/// - Remove the `resume` argument.
652-
///
653-
/// Ideally the async lowering would not add the `resume` argument.
654-
///
655-
/// The async lowering step and the type / lifetime inference / checking are
656-
/// still using the `resume` argument for the time being. After this transform,
657-
/// the coroutine body doesn't have the `resume` argument.
658-
fn transform_gen_context<'tcx>(body: &mut Body<'tcx>) {
659-
// This leaves the local representing the `resume` argument in place,
660-
// but turns it into a regular local variable. This is cheaper than
661-
// adjusting all local references in the body after removing it.
662-
body.arg_count = 1;
663-
}
664-
665658
/// Replaces the entry point of `body` with a block that switches on the coroutine discriminant and
666659
/// dispatches to blocks according to `cases`.
667660
///
@@ -883,6 +876,10 @@ fn create_coroutine_resume_function<'tcx>(
883876
// Run derefer to fix Derefs that are not in the first place
884877
deref_finder(tcx, body, false);
885878

879+
if transform.coroutine_kind.is_async_desugaring() {
880+
transform_async_context(tcx, body);
881+
}
882+
886883
if let Some(dumper) = MirDumper::new(tcx, "coroutine_resume", body) {
887884
dumper.dump_mir(body);
888885
}
@@ -1025,12 +1022,10 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform {
10251022
// (finally in open_drop_for_tuple) before async drop expansion.
10261023
// Async drops, produced by this drop elaboration, will be expanded,
10271024
// and corresponding futures kept in layout.
1028-
let coroutine_is_async = coroutine_kind.is_async_desugaring();
10291025
let has_async_drops = has_async_drops(body);
10301026

1031-
// Replace all occurrences of `ResumeTy` with `&mut Context<'_>` within async bodies.
1032-
if coroutine_is_async {
1033-
transform_async_context(tcx, body);
1027+
if coroutine_kind.is_async_desugaring() {
1028+
eliminate_get_context_calls(tcx, body);
10341029
}
10351030

10361031
let always_live_locals = always_storage_live_locals(body);
@@ -1097,13 +1092,9 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform {
10971092
}),
10981093
);
10991094

1100-
// Update our MIR struct to reflect the changes we've made
1101-
body.arg_count = 2; // self, resume arg
1102-
body.spread_arg = None;
1103-
11041095
// Remove the context argument within generator bodies.
11051096
if matches!(coroutine_kind, CoroutineKind::Desugared(CoroutineDesugaring::Gen, _)) {
1106-
transform_gen_context(body);
1097+
body.arg_count = 1;
11071098
}
11081099

11091100
// The original arguments to the function are no longer arguments, mark them as such.
@@ -1150,7 +1141,7 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform {
11501141
body.coroutine.as_mut().unwrap().coroutine_drop = Some(drop_shim);
11511142

11521143
// For coroutine with sync drop, generating async proxy for `future_drop_poll` call
1153-
let proxy_shim = create_coroutine_drop_shim_proxy_async(tcx, body);
1144+
let proxy_shim = create_coroutine_drop_shim_proxy_async(tcx, body, coroutine_kind);
11541145
body.coroutine.as_mut().unwrap().coroutine_drop_proxy_async = Some(proxy_shim);
11551146
}
11561147

tests/mir-opt/coroutine/async_await.a-{closure#0}.StateTransform.diff

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
- fn a::{closure#0}(_1: {async fn body of a()}, _2: std::future::ResumeTy) -> ()
55
- yields ()
66
- {
7+
- debug _task_context => _2;
8+
- let mut _0: ();
79
+ fn a::{closure#0}(_1: Pin<&mut {async fn body of a()}>, _2: &mut Context<'_>) -> Poll<()> {
810
+ coroutine layout {
911
+ variant_fields = {
@@ -13,16 +15,19 @@
1315
+ }
1416
+ storage_conflicts = BitMatrix(0x0) {}
1517
+ }
16-
debug _task_context => _2;
17-
- let mut _0: ();
18+
+ debug _task_context => _6;
1819
+ let mut _0: std::task::Poll<()>;
1920
+ let mut _3: ();
2021
+ let mut _4: u32;
2122
+ let mut _5: &mut {async fn body of a()};
23+
+ let mut _6: std::future::ResumeTy;
24+
+ let mut _7: std::ptr::NonNull<std::task::Context<'_>>;
2225

2326
bb0: {
2427
- _0 = const ();
2528
- drop(_1) -> [return: bb1, unwind: bb2];
29+
+ _7 = move _2 as std::ptr::NonNull<std::task::Context<'_>> (Transmute);
30+
+ _6 = std::future::ResumeTy(move _7);
2631
+ _5 = copy (_1.0: &mut {async fn body of a()});
2732
+ _4 = discriminant((*_5));
2833
+ switchInt(move _4) -> [0: bb5, 1: bb3, otherwise: bb4];

0 commit comments

Comments
 (0)