Skip to content

Commit 8c16d8a

Browse files
committed
Change type of async context parameter after state transform.
1 parent b542677 commit 8c16d8a

10 files changed

Lines changed: 329 additions & 273 deletions

compiler/rustc_hir/src/hir.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2327,6 +2327,13 @@ impl CoroutineKind {
23272327
matches!(self, CoroutineKind::Desugared(_, CoroutineSource::Fn))
23282328
}
23292329

2330+
pub fn is_async_desugaring(self) -> bool {
2331+
matches!(
2332+
self,
2333+
CoroutineKind::Desugared(CoroutineDesugaring::Async | CoroutineDesugaring::AsyncGen, _)
2334+
)
2335+
}
2336+
23302337
pub fn to_plural_string(&self) -> String {
23312338
match self {
23322339
CoroutineKind::Desugared(d, CoroutineSource::Fn) => format!("{d:#}fn bodies"),

compiler/rustc_mir_transform/src/coroutine.rs

Lines changed: 94 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -563,19 +563,15 @@ fn make_coroutine_state_argument_pinned<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body
563563
);
564564
}
565565

566-
/// Transforms the `body` of the coroutine applying the following transforms:
567-
///
568-
/// - Eliminates all the `get_context` calls that async lowering created.
569-
/// - Replace all `Local` `ResumeTy` types with `&mut Context<'_>` (`context_mut_ref`).
570-
///
571-
/// The `Local`s that have their types replaced are:
572-
/// - The `resume` argument itself.
573-
/// - The argument to `get_context`.
574-
/// - The yielded value of a `yield`.
575-
///
566+
/// Async desugaring uses an unsafe binder type `ResumeTy` to circumvert borrow-checking.
576567
/// The `ResumeTy` hides a `&mut Context<'_>` behind an unsafe raw pointer, and the
577568
/// `get_context` function is being used to convert that back to a `&mut Context<'_>`.
578569
///
570+
/// The actual should be `&mut Context<'_>`. This performs the substitution:
571+
/// - create a new local `_r` of type `ResumeTy`;
572+
/// - assign `ResumeTy(transmute::<&mut Context<'_>, NonNull<Context<'_>>>(_2))` to that local;
573+
/// - let all the code use `_r` instead of `_2`.
574+
///
579575
/// Ideally the async lowering would not use the `ResumeTy`/`get_context` indirection,
580576
/// but rather directly use `&mut Context<'_>`, however that would currently
581577
/// lead to higher-kinded lifetime errors.
@@ -584,95 +580,90 @@ fn make_coroutine_state_argument_pinned<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body
584580
/// The async lowering step and the type / lifetime inference / checking are
585581
/// still using the `ResumeTy` indirection for the time being, and that indirection
586582
/// is removed here. After this transform, the coroutine body only knows about `&mut Context<'_>`.
587-
#[tracing::instrument(level = "trace", skip(tcx, body), ret)]
588-
fn transform_async_context<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) -> Ty<'tcx> {
583+
fn transform_async_context<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
589584
let context_mut_ref = Ty::new_task_context(tcx);
585+
let resume_ty_def_id = tcx.require_lang_item(LangItem::ResumeTy, body.span);
586+
let resume_nonnull_ty = tcx.instantiate_and_normalize_erasing_regions(
587+
ty::GenericArgs::empty(),
588+
body.typing_env(tcx),
589+
tcx.type_of(tcx.adt_def(resume_ty_def_id).non_enum_variant().fields[FieldIdx::ZERO].did),
590+
);
590591

591-
// replace the type of the `resume` argument
592-
replace_resume_ty_local(tcx, body, CTX_ARG, context_mut_ref);
592+
// Replace all occurrences of `CTX_ARG` with `resume_local: ResumeTy`,
593+
// and set `CTX_ARG: &mut Context<'_>`.
594+
let resume_local = body.local_decls.push(LocalDecl::new(context_mut_ref, body.span));
595+
body.local_decls.swap(CTX_ARG, resume_local);
596+
RenameLocalVisitor { from: CTX_ARG, to: resume_local, tcx }.visit_body(body);
593597

594-
let get_context_def_id = tcx.require_lang_item(LangItem::GetContext, body.span);
598+
// Now `CTX_ARG` is `&mut Context` and `resume_local` is a `ResumeTy`.
599+
// Insert a `resume_local = ResumeTy(CTX_ARG as *mut Context<'static>)`
600+
// at the function entry to make the bridge.
601+
let source_info = SourceInfo::outermost(body.span);
602+
let nonnull_local = body.local_decls.push(LocalDecl::new(resume_nonnull_ty, body.span));
603+
let nonnull_rhs =
604+
Rvalue::Cast(CastKind::Transmute, Operand::Move(CTX_ARG.into()), resume_nonnull_ty);
605+
let nonnull_assign = StatementKind::Assign(Box::new((nonnull_local.into(), nonnull_rhs)));
606+
let resume_rhs = Rvalue::Aggregate(
607+
Box::new(AggregateKind::Adt(
608+
resume_ty_def_id,
609+
VariantIdx::ZERO,
610+
ty::GenericArgs::empty(),
611+
None,
612+
None,
613+
)),
614+
indexvec![Operand::Move(nonnull_local.into())],
615+
);
616+
let resume_assign = StatementKind::Assign(Box::new((resume_local.into(), resume_rhs)));
617+
body.basic_blocks.as_mut_preserves_cfg()[START_BLOCK].statements.splice(
618+
0..0,
619+
[Statement::new(source_info, nonnull_assign), Statement::new(source_info, resume_assign)],
620+
);
621+
}
622+
623+
/// HIR uses `get_context` to unwrap a `&mut Context<'_>` from a `ResumeTy`.
624+
/// Both types are just a single pointer, but liveness analysis does not know that and
625+
/// supposes that the operand and the destination are live at the same time.
626+
/// Forcibly inline those calls to avoid this.
627+
fn eliminate_get_context_calls<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
628+
let context_mut_ref = Ty::new_task_context(tcx);
629+
let resume_ty_def_id = tcx.require_lang_item(LangItem::ResumeTy, body.span);
630+
let resume_nonnull_ty = tcx.instantiate_and_normalize_erasing_regions(
631+
ty::GenericArgs::empty(),
632+
body.typing_env(tcx),
633+
tcx.type_of(tcx.adt_def(resume_ty_def_id).non_enum_variant().fields[FieldIdx::ZERO].did),
634+
);
595635

596-
for bb in body.basic_blocks.indices() {
597-
let bb_data = &body[bb];
636+
let get_context_def_id = tcx.require_lang_item(LangItem::GetContext, body.span);
637+
for bb_data in body.basic_blocks.as_mut().iter_mut() {
598638
if bb_data.is_cleanup {
599639
continue;
600640
}
601641

602-
match &bb_data.terminator().kind {
603-
TerminatorKind::Call { func, .. } => {
604-
let func_ty = func.ty(body, tcx);
605-
if let ty::FnDef(def_id, _) = *func_ty.kind()
606-
&& def_id == get_context_def_id
607-
{
608-
let local = eliminate_get_context_call(&mut body[bb]);
609-
replace_resume_ty_local(tcx, body, local, context_mut_ref);
610-
}
611-
}
612-
TerminatorKind::Yield { resume_arg, .. } => {
613-
replace_resume_ty_local(tcx, body, resume_arg.local, context_mut_ref);
614-
}
615-
_ => {}
642+
let terminator = bb_data.terminator_mut();
643+
if let TerminatorKind::Call { func, args, destination, target, .. } = &terminator.kind
644+
&& let func_ty = func.ty(&body.local_decls, tcx)
645+
&& let ty::FnDef(def_id, _) = *func_ty.kind()
646+
&& def_id == get_context_def_id
647+
&& let [arg] = &**args
648+
&& let Some(place) = arg.node.place()
649+
{
650+
let arg =
651+
Rvalue::Cast(
652+
CastKind::Transmute,
653+
Operand::Copy(place.project_deeper(
654+
&[PlaceElem::Field(FieldIdx::ZERO, resume_nonnull_ty)],
655+
tcx,
656+
)),
657+
context_mut_ref,
658+
);
659+
let assign = Statement::new(
660+
terminator.source_info,
661+
StatementKind::Assign(Box::new((*destination, arg))),
662+
);
663+
terminator.kind = TerminatorKind::Goto { target: target.unwrap() };
664+
bb_data.statements.push(assign);
616665
}
617666
}
618-
context_mut_ref
619-
}
620-
621-
fn eliminate_get_context_call<'tcx>(bb_data: &mut BasicBlockData<'tcx>) -> Local {
622-
let terminator = bb_data.terminator.take().unwrap();
623-
let TerminatorKind::Call { args, destination, target, .. } = terminator.kind else {
624-
bug!();
625-
};
626-
let [arg] = *Box::try_from(args).unwrap();
627-
let local = arg.node.place().unwrap().local;
628-
629-
let arg = Rvalue::Use(arg.node, WithRetag::Yes);
630-
let assign =
631-
Statement::new(terminator.source_info, StatementKind::Assign(Box::new((destination, arg))));
632-
bb_data.statements.push(assign);
633-
bb_data.terminator = Some(Terminator {
634-
source_info: terminator.source_info,
635-
kind: TerminatorKind::Goto { target: target.unwrap() },
636-
});
637-
local
638-
}
639-
640-
#[cfg_attr(not(debug_assertions), allow(unused))]
641-
#[tracing::instrument(level = "trace", skip(tcx, body), ret)]
642-
fn replace_resume_ty_local<'tcx>(
643-
tcx: TyCtxt<'tcx>,
644-
body: &mut Body<'tcx>,
645-
local: Local,
646-
context_mut_ref: Ty<'tcx>,
647-
) {
648-
let local_ty = std::mem::replace(&mut body.local_decls[local].ty, context_mut_ref);
649-
// We have to replace the `ResumeTy` that is used for type and borrow checking
650-
// with `&mut Context<'_>` in MIR.
651-
#[cfg(debug_assertions)]
652-
{
653-
if let ty::Adt(resume_ty_adt, _) = local_ty.kind() {
654-
let expected_adt = tcx.adt_def(tcx.require_lang_item(LangItem::ResumeTy, body.span));
655-
assert_eq!(*resume_ty_adt, expected_adt);
656-
} else {
657-
panic!("expected `ResumeTy`, found `{:?}`", local_ty);
658-
};
659-
}
660-
}
661-
662-
/// Transforms the `body` of the coroutine applying the following transform:
663-
///
664-
/// - Remove the `resume` argument.
665-
///
666-
/// Ideally the async lowering would not add the `resume` argument.
667-
///
668-
/// The async lowering step and the type / lifetime inference / checking are
669-
/// still using the `resume` argument for the time being. After this transform,
670-
/// the coroutine body doesn't have the `resume` argument.
671-
fn transform_gen_context<'tcx>(body: &mut Body<'tcx>) {
672-
// This leaves the local representing the `resume` argument in place,
673-
// but turns it into a regular local variable. This is cheaper than
674-
// adjusting all local references in the body after removing it.
675-
body.arg_count = 1;
676667
}
677668

678669
struct LivenessInfo {
@@ -1292,6 +1283,10 @@ fn create_coroutine_resume_function<'tcx>(
12921283

12931284
pm::run_passes_no_validate(tcx, body, &[&abort_unwinding_calls::AbortUnwindingCalls], None);
12941285

1286+
if transform.coroutine_kind.is_async_desugaring() {
1287+
transform_async_context(tcx, body);
1288+
}
1289+
12951290
if let Some(dumper) = MirDumper::new(tcx, "coroutine_resume", body) {
12961291
dumper.dump_mir(body);
12971292
}
@@ -1507,18 +1502,15 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform {
15071502
// (finally in open_drop_for_tuple) before async drop expansion.
15081503
// Async drops, produced by this drop elaboration, will be expanded,
15091504
// and corresponding futures kept in layout.
1510-
let has_async_drops = matches!(
1511-
coroutine_kind,
1512-
CoroutineKind::Desugared(CoroutineDesugaring::Async | CoroutineDesugaring::AsyncGen, _)
1513-
) && has_expandable_async_drops(tcx, body, coroutine_ty);
1505+
let has_async_drops = coroutine_kind.is_async_desugaring()
1506+
&& has_expandable_async_drops(tcx, body, coroutine_ty);
15141507

1515-
// Replace all occurrences of `ResumeTy` with `&mut Context<'_>` within async bodies.
1516-
if matches!(
1517-
coroutine_kind,
1518-
CoroutineKind::Desugared(CoroutineDesugaring::Async | CoroutineDesugaring::AsyncGen, _)
1519-
) {
1520-
let context_mut_ref = transform_async_context(tcx, body);
1521-
expand_async_drops(tcx, body, context_mut_ref, coroutine_kind, coroutine_ty);
1508+
if coroutine_kind.is_async_desugaring() {
1509+
eliminate_get_context_calls(tcx, body);
1510+
}
1511+
1512+
if has_async_drops {
1513+
expand_async_drops(tcx, body, coroutine_kind, coroutine_ty);
15221514

15231515
if let Some(dumper) = MirDumper::new(tcx, "coroutine_async_drop_expand", body) {
15241516
dumper.dump_mir(body);
@@ -1591,13 +1583,9 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform {
15911583
}),
15921584
);
15931585

1594-
// Update our MIR struct to reflect the changes we've made
1595-
body.arg_count = 2; // self, resume arg
1596-
body.spread_arg = None;
1597-
15981586
// Remove the context argument within generator bodies.
15991587
if matches!(coroutine_kind, CoroutineKind::Desugared(CoroutineDesugaring::Gen, _)) {
1600-
transform_gen_context(body);
1588+
body.arg_count = 1;
16011589
}
16021590

16031591
// The original arguments to the function are no longer arguments, mark them as such.
@@ -1652,7 +1640,7 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform {
16521640
body.coroutine.as_mut().unwrap().coroutine_drop = Some(drop_shim);
16531641

16541642
// For coroutine with sync drop, generating async proxy for `future_drop_poll` call
1655-
let mut proxy_shim = create_coroutine_drop_shim_proxy_async(tcx, body);
1643+
let mut proxy_shim = create_coroutine_drop_shim_proxy_async(tcx, body, coroutine_kind);
16561644
deref_finder(tcx, &mut proxy_shim, false);
16571645
body.coroutine.as_mut().unwrap().coroutine_drop_proxy_async = Some(proxy_shim);
16581646
}

0 commit comments

Comments
 (0)