@@ -562,19 +562,15 @@ fn make_coroutine_state_argument_pinned<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body
562562 ) ;
563563}
564564
565- /// Transforms the `body` of the coroutine applying the following transforms:
566- ///
567- /// - Eliminates all the `get_context` calls that async lowering created.
568- /// - Replace all `Local` `ResumeTy` types with `&mut Context<'_>` (`context_mut_ref`).
569- ///
570- /// The `Local`s that have their types replaced are:
571- /// - The `resume` argument itself.
572- /// - The argument to `get_context`.
573- /// - The yielded value of a `yield`.
574- ///
565+ /// Async desugaring uses an unsafe binder type `ResumeTy` to circumvert borrow-checking.
575566/// The `ResumeTy` hides a `&mut Context<'_>` behind an unsafe raw pointer, and the
576567/// `get_context` function is being used to convert that back to a `&mut Context<'_>`.
577568///
569+ /// The actual should be `&mut Context<'_>`. This performs the substitution:
570+ /// - create a new local `_r` of type `ResumeTy`;
571+ /// - assign `ResumeTy(transmute::<&mut Context<'_>, NonNull<Context<'_>>>(_2))` to that local;
572+ /// - let all the code use `_r` instead of `_2`.
573+ ///
578574/// Ideally the async lowering would not use the `ResumeTy`/`get_context` indirection,
579575/// but rather directly use `&mut Context<'_>`, however that would currently
580576/// lead to higher-kinded lifetime errors.
@@ -586,93 +582,90 @@ fn make_coroutine_state_argument_pinned<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body
586582#[ tracing:: instrument( level = "trace" , skip( tcx, body) , ret) ]
587583fn transform_async_context < ' tcx > ( tcx : TyCtxt < ' tcx > , body : & mut Body < ' tcx > ) {
588584 let context_mut_ref = Ty :: new_task_context ( tcx) ;
585+ let resume_ty_def_id = tcx. require_lang_item ( LangItem :: ResumeTy , body. span ) ;
586+ let resume_nonnull_ty = tcx. instantiate_and_normalize_erasing_regions (
587+ ty:: GenericArgs :: empty ( ) ,
588+ body. typing_env ( tcx) ,
589+ tcx. type_of ( tcx. adt_def ( resume_ty_def_id) . non_enum_variant ( ) . fields [ FieldIdx :: ZERO ] . did ) ,
590+ ) ;
589591
590- // replace the type of the `resume` argument
591- replace_resume_ty_local ( tcx, body, CTX_ARG , context_mut_ref) ;
592+ // Replace all occurrences of `CTX_ARG` with `resume_local: ResumeTy`,
593+ // and set `CTX_ARG: &mut Context<'_>`.
594+ let resume_local = body. local_decls . push ( LocalDecl :: new ( context_mut_ref, body. span ) ) ;
595+ body. local_decls . swap ( CTX_ARG , resume_local) ;
596+ RenameLocalVisitor { from : CTX_ARG , to : resume_local, tcx } . visit_body ( body) ;
592597
593- let get_context_def_id = tcx. require_lang_item ( LangItem :: GetContext , body. span ) ;
598+ // Now `CTX_ARG` is `&mut Context` and `resume_local` is a `ResumeTy`.
599+ // Insert a `resume_local = ResumeTy(CTX_ARG as *mut Context<'static>)`
600+ // at the function entry to make the bridge.
601+ let source_info = SourceInfo :: outermost ( body. span ) ;
602+ let nonnull_local = body. local_decls . push ( LocalDecl :: new ( resume_nonnull_ty, body. span ) ) ;
603+ let nonnull_rhs =
604+ Rvalue :: Cast ( CastKind :: Transmute , Operand :: Move ( CTX_ARG . into ( ) ) , resume_nonnull_ty) ;
605+ let nonnull_assign = StatementKind :: Assign ( Box :: new ( ( nonnull_local. into ( ) , nonnull_rhs) ) ) ;
606+ let resume_rhs = Rvalue :: Aggregate (
607+ Box :: new ( AggregateKind :: Adt (
608+ resume_ty_def_id,
609+ VariantIdx :: ZERO ,
610+ ty:: GenericArgs :: empty ( ) ,
611+ None ,
612+ None ,
613+ ) ) ,
614+ indexvec ! [ Operand :: Move ( nonnull_local. into( ) ) ] ,
615+ ) ;
616+ let resume_assign = StatementKind :: Assign ( Box :: new ( ( resume_local. into ( ) , resume_rhs) ) ) ;
617+ body. basic_blocks . as_mut_preserves_cfg ( ) [ START_BLOCK ] . statements . splice (
618+ 0 ..0 ,
619+ [ Statement :: new ( source_info, nonnull_assign) , Statement :: new ( source_info, resume_assign) ] ,
620+ ) ;
621+ }
594622
595- for bb in body. basic_blocks . indices ( ) {
596- let bb_data = & body[ bb] ;
623+ /// HIR uses `get_context` to unwrap a `&mut Context<'_>` from a `ResumeTy`.
624+ /// Both types are just a single pointer, but liveness analysis does not know that and
625+ /// supposes that the operand and the destination are live at the same time.
626+ /// Forcibly inline those calls to avoid this.
627+ fn eliminate_get_context_calls < ' tcx > ( tcx : TyCtxt < ' tcx > , body : & mut Body < ' tcx > ) {
628+ let context_mut_ref = Ty :: new_task_context ( tcx) ;
629+ let resume_ty_def_id = tcx. require_lang_item ( LangItem :: ResumeTy , body. span ) ;
630+ let resume_nonnull_ty = tcx. instantiate_and_normalize_erasing_regions (
631+ ty:: GenericArgs :: empty ( ) ,
632+ body. typing_env ( tcx) ,
633+ tcx. type_of ( tcx. adt_def ( resume_ty_def_id) . non_enum_variant ( ) . fields [ FieldIdx :: ZERO ] . did ) ,
634+ ) ;
635+
636+ let get_context_def_id = tcx. require_lang_item ( LangItem :: GetContext , body. span ) ;
637+ for bb_data in body. basic_blocks . as_mut ( ) . iter_mut ( ) {
597638 if bb_data. is_cleanup {
598639 continue ;
599640 }
600641
601- match & bb_data. terminator ( ) . kind {
602- TerminatorKind :: Call { func, .. } => {
603- let func_ty = func. ty ( body, tcx) ;
604- if let ty:: FnDef ( def_id, _) = * func_ty. kind ( )
605- && def_id == get_context_def_id
606- {
607- let local = eliminate_get_context_call ( & mut body[ bb] ) ;
608- replace_resume_ty_local ( tcx, body, local, context_mut_ref) ;
609- }
610- }
611- TerminatorKind :: Yield { resume_arg, .. } => {
612- replace_resume_ty_local ( tcx, body, resume_arg. local , context_mut_ref) ;
613- }
614- _ => { }
642+ let terminator = bb_data. terminator_mut ( ) ;
643+ if let TerminatorKind :: Call { func, args, destination, target, .. } = & terminator. kind
644+ && let func_ty = func. ty ( & body. local_decls , tcx)
645+ && let ty:: FnDef ( def_id, _) = * func_ty. kind ( )
646+ && def_id == get_context_def_id
647+ && let [ arg] = & * * args
648+ && let Some ( place) = arg. node . place ( )
649+ {
650+ let arg =
651+ Rvalue :: Cast (
652+ CastKind :: Transmute ,
653+ Operand :: Copy ( place. project_deeper (
654+ & [ PlaceElem :: Field ( FieldIdx :: ZERO , resume_nonnull_ty) ] ,
655+ tcx,
656+ ) ) ,
657+ context_mut_ref,
658+ ) ;
659+ let assign = Statement :: new (
660+ terminator. source_info ,
661+ StatementKind :: Assign ( Box :: new ( ( * destination, arg) ) ) ,
662+ ) ;
663+ terminator. kind = TerminatorKind :: Goto { target : target. unwrap ( ) } ;
664+ bb_data. statements . push ( assign) ;
615665 }
616666 }
617667}
618668
619- fn eliminate_get_context_call < ' tcx > ( bb_data : & mut BasicBlockData < ' tcx > ) -> Local {
620- let terminator = bb_data. terminator . take ( ) . unwrap ( ) ;
621- let TerminatorKind :: Call { args, destination, target, .. } = terminator. kind else {
622- bug ! ( ) ;
623- } ;
624- let [ arg] = * Box :: try_from ( args) . unwrap ( ) ;
625- let local = arg. node . place ( ) . unwrap ( ) . local ;
626-
627- let arg = Rvalue :: Use ( arg. node , WithRetag :: Yes ) ;
628- let assign =
629- Statement :: new ( terminator. source_info , StatementKind :: Assign ( Box :: new ( ( destination, arg) ) ) ) ;
630- bb_data. statements . push ( assign) ;
631- bb_data. terminator = Some ( Terminator {
632- source_info : terminator. source_info ,
633- kind : TerminatorKind :: Goto { target : target. unwrap ( ) } ,
634- } ) ;
635- local
636- }
637-
638- #[ cfg_attr( not( debug_assertions) , allow( unused) ) ]
639- #[ tracing:: instrument( level = "trace" , skip( tcx, body) , ret) ]
640- fn replace_resume_ty_local < ' tcx > (
641- tcx : TyCtxt < ' tcx > ,
642- body : & mut Body < ' tcx > ,
643- local : Local ,
644- context_mut_ref : Ty < ' tcx > ,
645- ) {
646- let local_ty = std:: mem:: replace ( & mut body. local_decls [ local] . ty , context_mut_ref) ;
647- // We have to replace the `ResumeTy` that is used for type and borrow checking
648- // with `&mut Context<'_>` in MIR.
649- #[ cfg( debug_assertions) ]
650- {
651- if let ty:: Adt ( resume_ty_adt, _) = local_ty. kind ( ) {
652- let expected_adt = tcx. adt_def ( tcx. require_lang_item ( LangItem :: ResumeTy , body. span ) ) ;
653- assert_eq ! ( * resume_ty_adt, expected_adt) ;
654- } else {
655- panic ! ( "expected `ResumeTy`, found `{:?}`" , local_ty) ;
656- } ;
657- }
658- }
659-
660- /// Transforms the `body` of the coroutine applying the following transform:
661- ///
662- /// - Remove the `resume` argument.
663- ///
664- /// Ideally the async lowering would not add the `resume` argument.
665- ///
666- /// The async lowering step and the type / lifetime inference / checking are
667- /// still using the `resume` argument for the time being. After this transform,
668- /// the coroutine body doesn't have the `resume` argument.
669- fn transform_gen_context < ' tcx > ( body : & mut Body < ' tcx > ) {
670- // This leaves the local representing the `resume` argument in place,
671- // but turns it into a regular local variable. This is cheaper than
672- // adjusting all local references in the body after removing it.
673- body. arg_count = 1 ;
674- }
675-
676669struct LivenessInfo {
677670 /// Which locals are live across any suspension point.
678671 saved_locals : CoroutineSavedLocals ,
@@ -1293,6 +1286,10 @@ fn create_coroutine_resume_function<'tcx>(
12931286 // Run derefer to fix Derefs that are not in the first place
12941287 deref_finder ( tcx, body, false ) ;
12951288
1289+ if transform. coroutine_kind . is_async_desugaring ( ) {
1290+ transform_async_context ( tcx, body) ;
1291+ }
1292+
12961293 if let Some ( dumper) = MirDumper :: new ( tcx, "coroutine_resume" , body) {
12971294 dumper. dump_mir ( body) ;
12981295 }
@@ -1508,12 +1505,10 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform {
15081505 // (finally in open_drop_for_tuple) before async drop expansion.
15091506 // Async drops, produced by this drop elaboration, will be expanded,
15101507 // and corresponding futures kept in layout.
1511- let coroutine_is_async = coroutine_kind. is_async_desugaring ( ) ;
15121508 let has_async_drops = has_async_drops ( body) ;
15131509
1514- // Replace all occurrences of `ResumeTy` with `&mut Context<'_>` within async bodies.
1515- if coroutine_is_async {
1516- transform_async_context ( tcx, body) ;
1510+ if coroutine_kind. is_async_desugaring ( ) {
1511+ eliminate_get_context_calls ( tcx, body) ;
15171512 }
15181513
15191514 let always_live_locals = always_storage_live_locals ( body) ;
@@ -1580,13 +1575,9 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform {
15801575 } ) ,
15811576 ) ;
15821577
1583- // Update our MIR struct to reflect the changes we've made
1584- body. arg_count = 2 ; // self, resume arg
1585- body. spread_arg = None ;
1586-
15871578 // Remove the context argument within generator bodies.
15881579 if matches ! ( coroutine_kind, CoroutineKind :: Desugared ( CoroutineDesugaring :: Gen , _) ) {
1589- transform_gen_context ( body) ;
1580+ body. arg_count = 1 ;
15901581 }
15911582
15921583 // The original arguments to the function are no longer arguments, mark them as such.
@@ -1633,7 +1624,7 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform {
16331624 body. coroutine . as_mut ( ) . unwrap ( ) . coroutine_drop = Some ( drop_shim) ;
16341625
16351626 // For coroutine with sync drop, generating async proxy for `future_drop_poll` call
1636- let proxy_shim = create_coroutine_drop_shim_proxy_async ( tcx, body) ;
1627+ let proxy_shim = create_coroutine_drop_shim_proxy_async ( tcx, body, coroutine_kind ) ;
16371628 body. coroutine . as_mut ( ) . unwrap ( ) . coroutine_drop_proxy_async = Some ( proxy_shim) ;
16381629 }
16391630
0 commit comments