@@ -563,19 +563,15 @@ fn make_coroutine_state_argument_pinned<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body
563563 ) ;
564564}
565565
566- /// Transforms the `body` of the coroutine applying the following transforms:
567- ///
568- /// - Eliminates all the `get_context` calls that async lowering created.
569- /// - Replace all `Local` `ResumeTy` types with `&mut Context<'_>` (`context_mut_ref`).
570- ///
571- /// The `Local`s that have their types replaced are:
572- /// - The `resume` argument itself.
573- /// - The argument to `get_context`.
574- /// - The yielded value of a `yield`.
575- ///
566+ /// Async desugaring uses an unsafe binder type `ResumeTy` to circumvert borrow-checking.
576567/// The `ResumeTy` hides a `&mut Context<'_>` behind an unsafe raw pointer, and the
577568/// `get_context` function is being used to convert that back to a `&mut Context<'_>`.
578569///
570+ /// The actual should be `&mut Context<'_>`. This performs the substitution:
571+ /// - create a new local `_r` of type `ResumeTy`;
572+ /// - assign `ResumeTy(transmute::<&mut Context<'_>, NonNull<Context<'_>>>(_2))` to that local;
573+ /// - let all the code use `_r` instead of `_2`.
574+ ///
579575/// Ideally the async lowering would not use the `ResumeTy`/`get_context` indirection,
580576/// but rather directly use `&mut Context<'_>`, however that would currently
581577/// lead to higher-kinded lifetime errors.
@@ -584,95 +580,90 @@ fn make_coroutine_state_argument_pinned<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body
584580/// The async lowering step and the type / lifetime inference / checking are
585581/// still using the `ResumeTy` indirection for the time being, and that indirection
586582/// is removed here. After this transform, the coroutine body only knows about `&mut Context<'_>`.
587- #[ tracing:: instrument( level = "trace" , skip( tcx, body) , ret) ]
588- fn transform_async_context < ' tcx > ( tcx : TyCtxt < ' tcx > , body : & mut Body < ' tcx > ) -> Ty < ' tcx > {
583+ fn transform_async_context < ' tcx > ( tcx : TyCtxt < ' tcx > , body : & mut Body < ' tcx > ) {
589584 let context_mut_ref = Ty :: new_task_context ( tcx) ;
585+ let resume_ty_def_id = tcx. require_lang_item ( LangItem :: ResumeTy , body. span ) ;
586+ let resume_nonnull_ty = tcx. instantiate_and_normalize_erasing_regions (
587+ ty:: GenericArgs :: empty ( ) ,
588+ body. typing_env ( tcx) ,
589+ tcx. type_of ( tcx. adt_def ( resume_ty_def_id) . non_enum_variant ( ) . fields [ FieldIdx :: ZERO ] . did ) ,
590+ ) ;
590591
591- // replace the type of the `resume` argument
592- replace_resume_ty_local ( tcx, body, CTX_ARG , context_mut_ref) ;
592+ // Replace all occurrences of `CTX_ARG` with `resume_local: ResumeTy`,
593+ // and set `CTX_ARG: &mut Context<'_>`.
594+ let resume_local = body. local_decls . push ( LocalDecl :: new ( context_mut_ref, body. span ) ) ;
595+ body. local_decls . swap ( CTX_ARG , resume_local) ;
596+ RenameLocalVisitor { from : CTX_ARG , to : resume_local, tcx } . visit_body ( body) ;
593597
594- let get_context_def_id = tcx. require_lang_item ( LangItem :: GetContext , body. span ) ;
598+ // Now `CTX_ARG` is `&mut Context` and `resume_local` is a `ResumeTy`.
599+ // Insert a `resume_local = ResumeTy(CTX_ARG as *mut Context<'static>)`
600+ // at the function entry to make the bridge.
601+ let source_info = SourceInfo :: outermost ( body. span ) ;
602+ let nonnull_local = body. local_decls . push ( LocalDecl :: new ( resume_nonnull_ty, body. span ) ) ;
603+ let nonnull_rhs =
604+ Rvalue :: Cast ( CastKind :: Transmute , Operand :: Move ( CTX_ARG . into ( ) ) , resume_nonnull_ty) ;
605+ let nonnull_assign = StatementKind :: Assign ( Box :: new ( ( nonnull_local. into ( ) , nonnull_rhs) ) ) ;
606+ let resume_rhs = Rvalue :: Aggregate (
607+ Box :: new ( AggregateKind :: Adt (
608+ resume_ty_def_id,
609+ VariantIdx :: ZERO ,
610+ ty:: GenericArgs :: empty ( ) ,
611+ None ,
612+ None ,
613+ ) ) ,
614+ indexvec ! [ Operand :: Move ( nonnull_local. into( ) ) ] ,
615+ ) ;
616+ let resume_assign = StatementKind :: Assign ( Box :: new ( ( resume_local. into ( ) , resume_rhs) ) ) ;
617+ body. basic_blocks . as_mut_preserves_cfg ( ) [ START_BLOCK ] . statements . splice (
618+ 0 ..0 ,
619+ [ Statement :: new ( source_info, nonnull_assign) , Statement :: new ( source_info, resume_assign) ] ,
620+ ) ;
621+ }
622+
623+ /// HIR uses `get_context` to unwrap a `&mut Context<'_>` from a `ResumeTy`.
624+ /// Both types are just a single pointer, but liveness analysis does not know that and
625+ /// supposes that the operand and the destination are live at the same time.
626+ /// Forcibly inline those calls to avoid this.
627+ fn eliminate_get_context_calls < ' tcx > ( tcx : TyCtxt < ' tcx > , body : & mut Body < ' tcx > ) {
628+ let context_mut_ref = Ty :: new_task_context ( tcx) ;
629+ let resume_ty_def_id = tcx. require_lang_item ( LangItem :: ResumeTy , body. span ) ;
630+ let resume_nonnull_ty = tcx. instantiate_and_normalize_erasing_regions (
631+ ty:: GenericArgs :: empty ( ) ,
632+ body. typing_env ( tcx) ,
633+ tcx. type_of ( tcx. adt_def ( resume_ty_def_id) . non_enum_variant ( ) . fields [ FieldIdx :: ZERO ] . did ) ,
634+ ) ;
595635
596- for bb in body . basic_blocks . indices ( ) {
597- let bb_data = & body[ bb ] ;
636+ let get_context_def_id = tcx . require_lang_item ( LangItem :: GetContext , body . span ) ;
637+ for bb_data in body. basic_blocks . as_mut ( ) . iter_mut ( ) {
598638 if bb_data. is_cleanup {
599639 continue ;
600640 }
601641
602- match & bb_data. terminator ( ) . kind {
603- TerminatorKind :: Call { func, .. } => {
604- let func_ty = func. ty ( body, tcx) ;
605- if let ty:: FnDef ( def_id, _) = * func_ty. kind ( )
606- && def_id == get_context_def_id
607- {
608- let local = eliminate_get_context_call ( & mut body[ bb] ) ;
609- replace_resume_ty_local ( tcx, body, local, context_mut_ref) ;
610- }
611- }
612- TerminatorKind :: Yield { resume_arg, .. } => {
613- replace_resume_ty_local ( tcx, body, resume_arg. local , context_mut_ref) ;
614- }
615- _ => { }
642+ let terminator = bb_data. terminator_mut ( ) ;
643+ if let TerminatorKind :: Call { func, args, destination, target, .. } = & terminator. kind
644+ && let func_ty = func. ty ( & body. local_decls , tcx)
645+ && let ty:: FnDef ( def_id, _) = * func_ty. kind ( )
646+ && def_id == get_context_def_id
647+ && let [ arg] = & * * args
648+ && let Some ( place) = arg. node . place ( )
649+ {
650+ let arg =
651+ Rvalue :: Cast (
652+ CastKind :: Transmute ,
653+ Operand :: Copy ( place. project_deeper (
654+ & [ PlaceElem :: Field ( FieldIdx :: ZERO , resume_nonnull_ty) ] ,
655+ tcx,
656+ ) ) ,
657+ context_mut_ref,
658+ ) ;
659+ let assign = Statement :: new (
660+ terminator. source_info ,
661+ StatementKind :: Assign ( Box :: new ( ( * destination, arg) ) ) ,
662+ ) ;
663+ terminator. kind = TerminatorKind :: Goto { target : target. unwrap ( ) } ;
664+ bb_data. statements . push ( assign) ;
616665 }
617666 }
618- context_mut_ref
619- }
620-
621- fn eliminate_get_context_call < ' tcx > ( bb_data : & mut BasicBlockData < ' tcx > ) -> Local {
622- let terminator = bb_data. terminator . take ( ) . unwrap ( ) ;
623- let TerminatorKind :: Call { args, destination, target, .. } = terminator. kind else {
624- bug ! ( ) ;
625- } ;
626- let [ arg] = * Box :: try_from ( args) . unwrap ( ) ;
627- let local = arg. node . place ( ) . unwrap ( ) . local ;
628-
629- let arg = Rvalue :: Use ( arg. node , WithRetag :: Yes ) ;
630- let assign =
631- Statement :: new ( terminator. source_info , StatementKind :: Assign ( Box :: new ( ( destination, arg) ) ) ) ;
632- bb_data. statements . push ( assign) ;
633- bb_data. terminator = Some ( Terminator {
634- source_info : terminator. source_info ,
635- kind : TerminatorKind :: Goto { target : target. unwrap ( ) } ,
636- } ) ;
637- local
638- }
639-
640- #[ cfg_attr( not( debug_assertions) , allow( unused) ) ]
641- #[ tracing:: instrument( level = "trace" , skip( tcx, body) , ret) ]
642- fn replace_resume_ty_local < ' tcx > (
643- tcx : TyCtxt < ' tcx > ,
644- body : & mut Body < ' tcx > ,
645- local : Local ,
646- context_mut_ref : Ty < ' tcx > ,
647- ) {
648- let local_ty = std:: mem:: replace ( & mut body. local_decls [ local] . ty , context_mut_ref) ;
649- // We have to replace the `ResumeTy` that is used for type and borrow checking
650- // with `&mut Context<'_>` in MIR.
651- #[ cfg( debug_assertions) ]
652- {
653- if let ty:: Adt ( resume_ty_adt, _) = local_ty. kind ( ) {
654- let expected_adt = tcx. adt_def ( tcx. require_lang_item ( LangItem :: ResumeTy , body. span ) ) ;
655- assert_eq ! ( * resume_ty_adt, expected_adt) ;
656- } else {
657- panic ! ( "expected `ResumeTy`, found `{:?}`" , local_ty) ;
658- } ;
659- }
660- }
661-
662- /// Transforms the `body` of the coroutine applying the following transform:
663- ///
664- /// - Remove the `resume` argument.
665- ///
666- /// Ideally the async lowering would not add the `resume` argument.
667- ///
668- /// The async lowering step and the type / lifetime inference / checking are
669- /// still using the `resume` argument for the time being. After this transform,
670- /// the coroutine body doesn't have the `resume` argument.
671- fn transform_gen_context < ' tcx > ( body : & mut Body < ' tcx > ) {
672- // This leaves the local representing the `resume` argument in place,
673- // but turns it into a regular local variable. This is cheaper than
674- // adjusting all local references in the body after removing it.
675- body. arg_count = 1 ;
676667}
677668
678669struct LivenessInfo {
@@ -1292,6 +1283,10 @@ fn create_coroutine_resume_function<'tcx>(
12921283
12931284 pm:: run_passes_no_validate ( tcx, body, & [ & abort_unwinding_calls:: AbortUnwindingCalls ] , None ) ;
12941285
1286+ if transform. coroutine_kind . is_async_desugaring ( ) {
1287+ transform_async_context ( tcx, body) ;
1288+ }
1289+
12951290 if let Some ( dumper) = MirDumper :: new ( tcx, "coroutine_resume" , body) {
12961291 dumper. dump_mir ( body) ;
12971292 }
@@ -1507,18 +1502,15 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform {
15071502 // (finally in open_drop_for_tuple) before async drop expansion.
15081503 // Async drops, produced by this drop elaboration, will be expanded,
15091504 // and corresponding futures kept in layout.
1510- let has_async_drops = matches ! (
1511- coroutine_kind,
1512- CoroutineKind :: Desugared ( CoroutineDesugaring :: Async | CoroutineDesugaring :: AsyncGen , _)
1513- ) && has_expandable_async_drops ( tcx, body, coroutine_ty) ;
1505+ let has_async_drops = coroutine_kind. is_async_desugaring ( )
1506+ && has_expandable_async_drops ( tcx, body, coroutine_ty) ;
15141507
1515- // Replace all occurrences of `ResumeTy` with `&mut Context<'_>` within async bodies.
1516- if matches ! (
1517- coroutine_kind,
1518- CoroutineKind :: Desugared ( CoroutineDesugaring :: Async | CoroutineDesugaring :: AsyncGen , _)
1519- ) {
1520- let context_mut_ref = transform_async_context ( tcx, body) ;
1521- expand_async_drops ( tcx, body, context_mut_ref, coroutine_kind, coroutine_ty) ;
1508+ if coroutine_kind. is_async_desugaring ( ) {
1509+ eliminate_get_context_calls ( tcx, body) ;
1510+ }
1511+
1512+ if has_async_drops {
1513+ expand_async_drops ( tcx, body, coroutine_kind, coroutine_ty) ;
15221514
15231515 if let Some ( dumper) = MirDumper :: new ( tcx, "coroutine_async_drop_expand" , body) {
15241516 dumper. dump_mir ( body) ;
@@ -1591,13 +1583,9 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform {
15911583 } ) ,
15921584 ) ;
15931585
1594- // Update our MIR struct to reflect the changes we've made
1595- body. arg_count = 2 ; // self, resume arg
1596- body. spread_arg = None ;
1597-
15981586 // Remove the context argument within generator bodies.
15991587 if matches ! ( coroutine_kind, CoroutineKind :: Desugared ( CoroutineDesugaring :: Gen , _) ) {
1600- transform_gen_context ( body) ;
1588+ body. arg_count = 1 ;
16011589 }
16021590
16031591 // The original arguments to the function are no longer arguments, mark them as such.
@@ -1652,7 +1640,7 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform {
16521640 body. coroutine . as_mut ( ) . unwrap ( ) . coroutine_drop = Some ( drop_shim) ;
16531641
16541642 // For coroutine with sync drop, generating async proxy for `future_drop_poll` call
1655- let mut proxy_shim = create_coroutine_drop_shim_proxy_async ( tcx, body) ;
1643+ let mut proxy_shim = create_coroutine_drop_shim_proxy_async ( tcx, body, coroutine_kind ) ;
16561644 deref_finder ( tcx, & mut proxy_shim, false ) ;
16571645 body. coroutine . as_mut ( ) . unwrap ( ) . coroutine_drop_proxy_async = Some ( proxy_shim) ;
16581646 }
0 commit comments