Skip to content

Commit e40bd6a

Browse files
committed
Change type of async context parameter after state transform.
1 parent c0bb140 commit e40bd6a

57 files changed

Lines changed: 839 additions & 730 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

compiler/rustc_mir_transform/src/coroutine.rs

Lines changed: 87 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -562,19 +562,15 @@ fn make_coroutine_state_argument_pinned<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body
562562
);
563563
}
564564

565-
/// Transforms the `body` of the coroutine applying the following transforms:
566-
///
567-
/// - Eliminates all the `get_context` calls that async lowering created.
568-
/// - Replace all `Local` `ResumeTy` types with `&mut Context<'_>` (`context_mut_ref`).
569-
///
570-
/// The `Local`s that have their types replaced are:
571-
/// - The `resume` argument itself.
572-
/// - The argument to `get_context`.
573-
/// - The yielded value of a `yield`.
574-
///
565+
/// Async desugaring uses an unsafe binder type `ResumeTy` to circumvert borrow-checking.
575566
/// The `ResumeTy` hides a `&mut Context<'_>` behind an unsafe raw pointer, and the
576567
/// `get_context` function is being used to convert that back to a `&mut Context<'_>`.
577568
///
569+
/// The actual should be `&mut Context<'_>`. This performs the substitution:
570+
/// - create a new local `_r` of type `ResumeTy`;
571+
/// - assign `ResumeTy(transmute::<&mut Context<'_>, NonNull<Context<'_>>>(_2))` to that local;
572+
/// - let all the code use `_r` instead of `_2`.
573+
///
578574
/// Ideally the async lowering would not use the `ResumeTy`/`get_context` indirection,
579575
/// but rather directly use `&mut Context<'_>`, however that would currently
580576
/// lead to higher-kinded lifetime errors.
@@ -586,93 +582,90 @@ fn make_coroutine_state_argument_pinned<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body
586582
#[tracing::instrument(level = "trace", skip(tcx, body), ret)]
587583
fn transform_async_context<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
588584
let context_mut_ref = Ty::new_task_context(tcx);
585+
let resume_ty_def_id = tcx.require_lang_item(LangItem::ResumeTy, body.span);
586+
let resume_nonnull_ty = tcx.instantiate_and_normalize_erasing_regions(
587+
ty::GenericArgs::empty(),
588+
body.typing_env(tcx),
589+
tcx.type_of(tcx.adt_def(resume_ty_def_id).non_enum_variant().fields[FieldIdx::ZERO].did),
590+
);
589591

590-
// replace the type of the `resume` argument
591-
replace_resume_ty_local(tcx, body, CTX_ARG, context_mut_ref);
592+
// Replace all occurrences of `CTX_ARG` with `resume_local: ResumeTy`,
593+
// and set `CTX_ARG: &mut Context<'_>`.
594+
let resume_local = body.local_decls.push(LocalDecl::new(context_mut_ref, body.span));
595+
body.local_decls.swap(CTX_ARG, resume_local);
596+
RenameLocalVisitor { from: CTX_ARG, to: resume_local, tcx }.visit_body(body);
592597

593-
let get_context_def_id = tcx.require_lang_item(LangItem::GetContext, body.span);
598+
// Now `CTX_ARG` is `&mut Context` and `resume_local` is a `ResumeTy`.
599+
// Insert a `resume_local = ResumeTy(CTX_ARG as *mut Context<'static>)`
600+
// at the function entry to make the bridge.
601+
let source_info = SourceInfo::outermost(body.span);
602+
let nonnull_local = body.local_decls.push(LocalDecl::new(resume_nonnull_ty, body.span));
603+
let nonnull_rhs =
604+
Rvalue::Cast(CastKind::Transmute, Operand::Move(CTX_ARG.into()), resume_nonnull_ty);
605+
let nonnull_assign = StatementKind::Assign(Box::new((nonnull_local.into(), nonnull_rhs)));
606+
let resume_rhs = Rvalue::Aggregate(
607+
Box::new(AggregateKind::Adt(
608+
resume_ty_def_id,
609+
VariantIdx::ZERO,
610+
ty::GenericArgs::empty(),
611+
None,
612+
None,
613+
)),
614+
indexvec![Operand::Move(nonnull_local.into())],
615+
);
616+
let resume_assign = StatementKind::Assign(Box::new((resume_local.into(), resume_rhs)));
617+
body.basic_blocks.as_mut_preserves_cfg()[START_BLOCK].statements.splice(
618+
0..0,
619+
[Statement::new(source_info, nonnull_assign), Statement::new(source_info, resume_assign)],
620+
);
621+
}
594622

595-
for bb in body.basic_blocks.indices() {
596-
let bb_data = &body[bb];
623+
/// HIR uses `get_context` to unwrap a `&mut Context<'_>` from a `ResumeTy`.
624+
/// Both types are just a single pointer, but liveness analysis does not know that and
625+
/// supposes that the operand and the destination are live at the same time.
626+
/// Forcibly inline those calls to avoid this.
627+
fn eliminate_get_context_calls<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
628+
let context_mut_ref = Ty::new_task_context(tcx);
629+
let resume_ty_def_id = tcx.require_lang_item(LangItem::ResumeTy, body.span);
630+
let resume_nonnull_ty = tcx.instantiate_and_normalize_erasing_regions(
631+
ty::GenericArgs::empty(),
632+
body.typing_env(tcx),
633+
tcx.type_of(tcx.adt_def(resume_ty_def_id).non_enum_variant().fields[FieldIdx::ZERO].did),
634+
);
635+
636+
let get_context_def_id = tcx.require_lang_item(LangItem::GetContext, body.span);
637+
for bb_data in body.basic_blocks.as_mut().iter_mut() {
597638
if bb_data.is_cleanup {
598639
continue;
599640
}
600641

601-
match &bb_data.terminator().kind {
602-
TerminatorKind::Call { func, .. } => {
603-
let func_ty = func.ty(body, tcx);
604-
if let ty::FnDef(def_id, _) = *func_ty.kind()
605-
&& def_id == get_context_def_id
606-
{
607-
let local = eliminate_get_context_call(&mut body[bb]);
608-
replace_resume_ty_local(tcx, body, local, context_mut_ref);
609-
}
610-
}
611-
TerminatorKind::Yield { resume_arg, .. } => {
612-
replace_resume_ty_local(tcx, body, resume_arg.local, context_mut_ref);
613-
}
614-
_ => {}
642+
let terminator = bb_data.terminator_mut();
643+
if let TerminatorKind::Call { func, args, destination, target, .. } = &terminator.kind
644+
&& let func_ty = func.ty(&body.local_decls, tcx)
645+
&& let ty::FnDef(def_id, _) = *func_ty.kind()
646+
&& def_id == get_context_def_id
647+
&& let [arg] = &**args
648+
&& let Some(place) = arg.node.place()
649+
{
650+
let arg =
651+
Rvalue::Cast(
652+
CastKind::Transmute,
653+
Operand::Copy(place.project_deeper(
654+
&[PlaceElem::Field(FieldIdx::ZERO, resume_nonnull_ty)],
655+
tcx,
656+
)),
657+
context_mut_ref,
658+
);
659+
let assign = Statement::new(
660+
terminator.source_info,
661+
StatementKind::Assign(Box::new((*destination, arg))),
662+
);
663+
terminator.kind = TerminatorKind::Goto { target: target.unwrap() };
664+
bb_data.statements.push(assign);
615665
}
616666
}
617667
}
618668

619-
fn eliminate_get_context_call<'tcx>(bb_data: &mut BasicBlockData<'tcx>) -> Local {
620-
let terminator = bb_data.terminator.take().unwrap();
621-
let TerminatorKind::Call { args, destination, target, .. } = terminator.kind else {
622-
bug!();
623-
};
624-
let [arg] = *Box::try_from(args).unwrap();
625-
let local = arg.node.place().unwrap().local;
626-
627-
let arg = Rvalue::Use(arg.node, WithRetag::Yes);
628-
let assign =
629-
Statement::new(terminator.source_info, StatementKind::Assign(Box::new((destination, arg))));
630-
bb_data.statements.push(assign);
631-
bb_data.terminator = Some(Terminator {
632-
source_info: terminator.source_info,
633-
kind: TerminatorKind::Goto { target: target.unwrap() },
634-
});
635-
local
636-
}
637-
638-
#[cfg_attr(not(debug_assertions), allow(unused))]
639-
#[tracing::instrument(level = "trace", skip(tcx, body), ret)]
640-
fn replace_resume_ty_local<'tcx>(
641-
tcx: TyCtxt<'tcx>,
642-
body: &mut Body<'tcx>,
643-
local: Local,
644-
context_mut_ref: Ty<'tcx>,
645-
) {
646-
let local_ty = std::mem::replace(&mut body.local_decls[local].ty, context_mut_ref);
647-
// We have to replace the `ResumeTy` that is used for type and borrow checking
648-
// with `&mut Context<'_>` in MIR.
649-
#[cfg(debug_assertions)]
650-
{
651-
if let ty::Adt(resume_ty_adt, _) = local_ty.kind() {
652-
let expected_adt = tcx.adt_def(tcx.require_lang_item(LangItem::ResumeTy, body.span));
653-
assert_eq!(*resume_ty_adt, expected_adt);
654-
} else {
655-
panic!("expected `ResumeTy`, found `{:?}`", local_ty);
656-
};
657-
}
658-
}
659-
660-
/// Transforms the `body` of the coroutine applying the following transform:
661-
///
662-
/// - Remove the `resume` argument.
663-
///
664-
/// Ideally the async lowering would not add the `resume` argument.
665-
///
666-
/// The async lowering step and the type / lifetime inference / checking are
667-
/// still using the `resume` argument for the time being. After this transform,
668-
/// the coroutine body doesn't have the `resume` argument.
669-
fn transform_gen_context<'tcx>(body: &mut Body<'tcx>) {
670-
// This leaves the local representing the `resume` argument in place,
671-
// but turns it into a regular local variable. This is cheaper than
672-
// adjusting all local references in the body after removing it.
673-
body.arg_count = 1;
674-
}
675-
676669
struct LivenessInfo {
677670
/// Which locals are live across any suspension point.
678671
saved_locals: CoroutineSavedLocals,
@@ -1293,6 +1286,10 @@ fn create_coroutine_resume_function<'tcx>(
12931286
// Run derefer to fix Derefs that are not in the first place
12941287
deref_finder(tcx, body, false);
12951288

1289+
if transform.coroutine_kind.is_async_desugaring() {
1290+
transform_async_context(tcx, body);
1291+
}
1292+
12961293
if let Some(dumper) = MirDumper::new(tcx, "coroutine_resume", body) {
12971294
dumper.dump_mir(body);
12981295
}
@@ -1508,12 +1505,10 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform {
15081505
// (finally in open_drop_for_tuple) before async drop expansion.
15091506
// Async drops, produced by this drop elaboration, will be expanded,
15101507
// and corresponding futures kept in layout.
1511-
let coroutine_is_async = coroutine_kind.is_async_desugaring();
15121508
let has_async_drops = has_async_drops(body);
15131509

1514-
// Replace all occurrences of `ResumeTy` with `&mut Context<'_>` within async bodies.
1515-
if coroutine_is_async {
1516-
transform_async_context(tcx, body);
1510+
if coroutine_kind.is_async_desugaring() {
1511+
eliminate_get_context_calls(tcx, body);
15171512
}
15181513

15191514
let always_live_locals = always_storage_live_locals(body);
@@ -1580,13 +1575,9 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform {
15801575
}),
15811576
);
15821577

1583-
// Update our MIR struct to reflect the changes we've made
1584-
body.arg_count = 2; // self, resume arg
1585-
body.spread_arg = None;
1586-
15871578
// Remove the context argument within generator bodies.
15881579
if matches!(coroutine_kind, CoroutineKind::Desugared(CoroutineDesugaring::Gen, _)) {
1589-
transform_gen_context(body);
1580+
body.arg_count = 1;
15901581
}
15911582

15921583
// The original arguments to the function are no longer arguments, mark them as such.
@@ -1633,7 +1624,7 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform {
16331624
body.coroutine.as_mut().unwrap().coroutine_drop = Some(drop_shim);
16341625

16351626
// For coroutine with sync drop, generating async proxy for `future_drop_poll` call
1636-
let proxy_shim = create_coroutine_drop_shim_proxy_async(tcx, body);
1627+
let proxy_shim = create_coroutine_drop_shim_proxy_async(tcx, body, coroutine_kind);
16371628
body.coroutine.as_mut().unwrap().coroutine_drop_proxy_async = Some(proxy_shim);
16381629
}
16391630

compiler/rustc_mir_transform/src/coroutine/drop.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,10 @@ pub(super) fn create_coroutine_drop_shim_async<'tcx>(
308308
// Run derefer to fix Derefs that are not in the first place
309309
deref_finder(tcx, &mut body, false);
310310

311+
if transform.coroutine_kind.is_async_desugaring() {
312+
transform_async_context(tcx, &mut body);
313+
}
314+
311315
if let Some(dumper) = MirDumper::new(tcx, "coroutine_drop_async", &body) {
312316
dumper.dump_mir(&body);
313317
}
@@ -320,6 +324,7 @@ pub(super) fn create_coroutine_drop_shim_async<'tcx>(
320324
pub(super) fn create_coroutine_drop_shim_proxy_async<'tcx>(
321325
tcx: TyCtxt<'tcx>,
322326
body: &Body<'tcx>,
327+
coroutine_kind: CoroutineKind,
323328
) -> Body<'tcx> {
324329
let mut body = body.clone();
325330
// Take the coroutine info out of the body, since the drop shim is
@@ -357,6 +362,10 @@ pub(super) fn create_coroutine_drop_shim_proxy_async<'tcx>(
357362
// Run derefer to fix Derefs that are not in the first place
358363
deref_finder(tcx, &mut body, false);
359364

365+
if coroutine_kind.is_async_desugaring() {
366+
transform_async_context(tcx, &mut body);
367+
}
368+
360369
if let Some(dumper) = MirDumper::new(tcx, "coroutine_drop_proxy_async", &body) {
361370
dumper.dump_mir(&body);
362371
}

tests/mir-opt/coroutine/async_await.a-{closure#0}.StateTransform.diff

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
- fn a::{closure#0}(_1: {async fn body of a()}, _2: std::future::ResumeTy) -> ()
55
- yields ()
66
- {
7+
- debug _task_context => _2;
8+
- let mut _0: ();
79
+ fn a::{closure#0}(_1: Pin<&mut {async fn body of a()}>, _2: &mut Context<'_>) -> Poll<()> {
810
+ coroutine layout {
911
+ variant_fields = {
@@ -13,16 +15,19 @@
1315
+ }
1416
+ storage_conflicts = BitMatrix(0x0) {}
1517
+ }
16-
debug _task_context => _2;
17-
- let mut _0: ();
18+
+ debug _task_context => _6;
1819
+ let mut _0: std::task::Poll<()>;
1920
+ let mut _3: ();
2021
+ let mut _4: u32;
2122
+ let mut _5: &mut {async fn body of a()};
23+
+ let mut _6: std::future::ResumeTy;
24+
+ let mut _7: std::ptr::NonNull<std::task::Context<'_>>;
2225

2326
bb0: {
2427
- _0 = const ();
2528
- drop(_1) -> [return: bb1, unwind: bb2];
29+
+ _7 = move _2 as std::ptr::NonNull<std::task::Context<'_>> (Transmute);
30+
+ _6 = std::future::ResumeTy(move _7);
2631
+ _5 = copy (_1.0: &mut {async fn body of a()});
2732
+ _4 = discriminant((*_5));
2833
+ switchInt(move _4) -> [0: bb5, 1: bb3, otherwise: bb4];

0 commit comments

Comments
 (0)