@@ -551,19 +551,15 @@ fn make_coroutine_state_argument_pinned<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body
551551 ) ;
552552}
553553
554- /// Transforms the `body` of the coroutine applying the following transforms:
555- ///
556- /// - Eliminates all the `get_context` calls that async lowering created.
557- /// - Replace all `Local` `ResumeTy` types with `&mut Context<'_>` (`context_mut_ref`).
558- ///
559- /// The `Local`s that have their types replaced are:
560- /// - The `resume` argument itself.
561- /// - The argument to `get_context`.
562- /// - The yielded value of a `yield`.
563- ///
554+ /// Async desugaring uses an unsafe binder type `ResumeTy` to circumvert borrow-checking.
564555/// The `ResumeTy` hides a `&mut Context<'_>` behind an unsafe raw pointer, and the
565556/// `get_context` function is being used to convert that back to a `&mut Context<'_>`.
566557///
558+ /// The actual should be `&mut Context<'_>`. This performs the substitution:
559+ /// - create a new local `_r` of type `ResumeTy`;
560+ /// - assign `ResumeTy(transmute::<&mut Context<'_>, NonNull<Context<'_>>>(_2))` to that local;
561+ /// - let all the code use `_r` instead of `_2`.
562+ ///
567563/// Ideally the async lowering would not use the `ResumeTy`/`get_context` indirection,
568564/// but rather directly use `&mut Context<'_>`, however that would currently
569565/// lead to higher-kinded lifetime errors.
@@ -575,93 +571,90 @@ fn make_coroutine_state_argument_pinned<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body
575571#[ tracing:: instrument( level = "trace" , skip( tcx, body) , ret) ]
576572fn transform_async_context < ' tcx > ( tcx : TyCtxt < ' tcx > , body : & mut Body < ' tcx > ) {
577573 let context_mut_ref = Ty :: new_task_context ( tcx) ;
574+ let resume_ty_def_id = tcx. require_lang_item ( LangItem :: ResumeTy , body. span ) ;
575+ let resume_nonnull_ty = tcx. instantiate_and_normalize_erasing_regions (
576+ ty:: GenericArgs :: empty ( ) ,
577+ body. typing_env ( tcx) ,
578+ tcx. type_of ( tcx. adt_def ( resume_ty_def_id) . non_enum_variant ( ) . fields [ FieldIdx :: ZERO ] . did ) ,
579+ ) ;
578580
579- // replace the type of the `resume` argument
580- replace_resume_ty_local ( tcx, body, CTX_ARG , context_mut_ref) ;
581+ // Replace all occurrences of `CTX_ARG` with `resume_local: ResumeTy`,
582+ // and set `CTX_ARG: &mut Context<'_>`.
583+ let resume_local = body. local_decls . push ( LocalDecl :: new ( context_mut_ref, body. span ) ) ;
584+ body. local_decls . swap ( CTX_ARG , resume_local) ;
585+ RenameLocalVisitor { from : CTX_ARG , to : resume_local, tcx } . visit_body ( body) ;
581586
582- let get_context_def_id = tcx. require_lang_item ( LangItem :: GetContext , body. span ) ;
587+ // Now `CTX_ARG` is `&mut Context` and `resume_local` is a `ResumeTy`.
588+ // Insert a `resume_local = ResumeTy(CTX_ARG as *mut Context<'static>)`
589+ // at the function entry to make the bridge.
590+ let source_info = SourceInfo :: outermost ( body. span ) ;
591+ let nonnull_local = body. local_decls . push ( LocalDecl :: new ( resume_nonnull_ty, body. span ) ) ;
592+ let nonnull_rhs =
593+ Rvalue :: Cast ( CastKind :: Transmute , Operand :: Move ( CTX_ARG . into ( ) ) , resume_nonnull_ty) ;
594+ let nonnull_assign = StatementKind :: Assign ( Box :: new ( ( nonnull_local. into ( ) , nonnull_rhs) ) ) ;
595+ let resume_rhs = Rvalue :: Aggregate (
596+ Box :: new ( AggregateKind :: Adt (
597+ resume_ty_def_id,
598+ VariantIdx :: ZERO ,
599+ ty:: GenericArgs :: empty ( ) ,
600+ None ,
601+ None ,
602+ ) ) ,
603+ indexvec ! [ Operand :: Move ( nonnull_local. into( ) ) ] ,
604+ ) ;
605+ let resume_assign = StatementKind :: Assign ( Box :: new ( ( resume_local. into ( ) , resume_rhs) ) ) ;
606+ body. basic_blocks . as_mut_preserves_cfg ( ) [ START_BLOCK ] . statements . splice (
607+ 0 ..0 ,
608+ [ Statement :: new ( source_info, nonnull_assign) , Statement :: new ( source_info, resume_assign) ] ,
609+ ) ;
610+ }
611+
612+ /// HIR uses `get_context` to unwrap a `&mut Context<'_>` from a `ResumeTy`.
613+ /// Both types are just a single pointer, but liveness analysis does not know that and
614+ /// supposes that the operand and the destination are live at the same time.
615+ /// Forcibly inline those calls to avoid this.
616+ fn eliminate_get_context_calls < ' tcx > ( tcx : TyCtxt < ' tcx > , body : & mut Body < ' tcx > ) {
617+ let context_mut_ref = Ty :: new_task_context ( tcx) ;
618+ let resume_ty_def_id = tcx. require_lang_item ( LangItem :: ResumeTy , body. span ) ;
619+ let resume_nonnull_ty = tcx. instantiate_and_normalize_erasing_regions (
620+ ty:: GenericArgs :: empty ( ) ,
621+ body. typing_env ( tcx) ,
622+ tcx. type_of ( tcx. adt_def ( resume_ty_def_id) . non_enum_variant ( ) . fields [ FieldIdx :: ZERO ] . did ) ,
623+ ) ;
583624
584- for bb in body . basic_blocks . indices ( ) {
585- let bb_data = & body[ bb ] ;
625+ let get_context_def_id = tcx . require_lang_item ( LangItem :: GetContext , body . span ) ;
626+ for bb_data in body. basic_blocks . as_mut ( ) . iter_mut ( ) {
586627 if bb_data. is_cleanup {
587628 continue ;
588629 }
589630
590- match & bb_data. terminator ( ) . kind {
591- TerminatorKind :: Call { func, .. } => {
592- let func_ty = func. ty ( body, tcx) ;
593- if let ty:: FnDef ( def_id, _) = * func_ty. kind ( )
594- && def_id == get_context_def_id
595- {
596- let local = eliminate_get_context_call ( & mut body[ bb] ) ;
597- replace_resume_ty_local ( tcx, body, local, context_mut_ref) ;
598- }
599- }
600- TerminatorKind :: Yield { resume_arg, .. } => {
601- replace_resume_ty_local ( tcx, body, resume_arg. local , context_mut_ref) ;
602- }
603- _ => { }
631+ let terminator = bb_data. terminator_mut ( ) ;
632+ if let TerminatorKind :: Call { func, args, destination, target, .. } = & terminator. kind
633+ && let func_ty = func. ty ( & body. local_decls , tcx)
634+ && let ty:: FnDef ( def_id, _) = * func_ty. kind ( )
635+ && def_id == get_context_def_id
636+ && let [ arg] = & * * args
637+ && let Some ( place) = arg. node . place ( )
638+ {
639+ let arg =
640+ Rvalue :: Cast (
641+ CastKind :: Transmute ,
642+ Operand :: Copy ( place. project_deeper (
643+ & [ PlaceElem :: Field ( FieldIdx :: ZERO , resume_nonnull_ty) ] ,
644+ tcx,
645+ ) ) ,
646+ context_mut_ref,
647+ ) ;
648+ let assign = Statement :: new (
649+ terminator. source_info ,
650+ StatementKind :: Assign ( Box :: new ( ( * destination, arg) ) ) ,
651+ ) ;
652+ terminator. kind = TerminatorKind :: Goto { target : target. unwrap ( ) } ;
653+ bb_data. statements . push ( assign) ;
604654 }
605655 }
606656}
607657
608- fn eliminate_get_context_call < ' tcx > ( bb_data : & mut BasicBlockData < ' tcx > ) -> Local {
609- let terminator = bb_data. terminator . take ( ) . unwrap ( ) ;
610- let TerminatorKind :: Call { args, destination, target, .. } = terminator. kind else {
611- bug ! ( ) ;
612- } ;
613- let [ arg] = * Box :: try_from ( args) . unwrap ( ) ;
614- let local = arg. node . place ( ) . unwrap ( ) . local ;
615-
616- let arg = Rvalue :: Use ( arg. node , WithRetag :: Yes ) ;
617- let assign =
618- Statement :: new ( terminator. source_info , StatementKind :: Assign ( Box :: new ( ( destination, arg) ) ) ) ;
619- bb_data. statements . push ( assign) ;
620- bb_data. terminator = Some ( Terminator {
621- source_info : terminator. source_info ,
622- kind : TerminatorKind :: Goto { target : target. unwrap ( ) } ,
623- } ) ;
624- local
625- }
626-
627- #[ cfg_attr( not( debug_assertions) , allow( unused) ) ]
628- #[ tracing:: instrument( level = "trace" , skip( tcx, body) , ret) ]
629- fn replace_resume_ty_local < ' tcx > (
630- tcx : TyCtxt < ' tcx > ,
631- body : & mut Body < ' tcx > ,
632- local : Local ,
633- context_mut_ref : Ty < ' tcx > ,
634- ) {
635- let local_ty = std:: mem:: replace ( & mut body. local_decls [ local] . ty , context_mut_ref) ;
636- // We have to replace the `ResumeTy` that is used for type and borrow checking
637- // with `&mut Context<'_>` in MIR.
638- #[ cfg( debug_assertions) ]
639- {
640- if let ty:: Adt ( resume_ty_adt, _) = local_ty. kind ( ) {
641- let expected_adt = tcx. adt_def ( tcx. require_lang_item ( LangItem :: ResumeTy , body. span ) ) ;
642- assert_eq ! ( * resume_ty_adt, expected_adt) ;
643- } else {
644- panic ! ( "expected `ResumeTy`, found `{:?}`" , local_ty) ;
645- } ;
646- }
647- }
648-
649- /// Transforms the `body` of the coroutine applying the following transform:
650- ///
651- /// - Remove the `resume` argument.
652- ///
653- /// Ideally the async lowering would not add the `resume` argument.
654- ///
655- /// The async lowering step and the type / lifetime inference / checking are
656- /// still using the `resume` argument for the time being. After this transform,
657- /// the coroutine body doesn't have the `resume` argument.
658- fn transform_gen_context < ' tcx > ( body : & mut Body < ' tcx > ) {
659- // This leaves the local representing the `resume` argument in place,
660- // but turns it into a regular local variable. This is cheaper than
661- // adjusting all local references in the body after removing it.
662- body. arg_count = 1 ;
663- }
664-
665658/// Replaces the entry point of `body` with a block that switches on the coroutine discriminant and
666659/// dispatches to blocks according to `cases`.
667660///
@@ -883,6 +876,10 @@ fn create_coroutine_resume_function<'tcx>(
883876 // Run derefer to fix Derefs that are not in the first place
884877 deref_finder ( tcx, body, false ) ;
885878
879+ if transform. coroutine_kind . is_async_desugaring ( ) {
880+ transform_async_context ( tcx, body) ;
881+ }
882+
886883 if let Some ( dumper) = MirDumper :: new ( tcx, "coroutine_resume" , body) {
887884 dumper. dump_mir ( body) ;
888885 }
@@ -1025,12 +1022,10 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform {
10251022 // (finally in open_drop_for_tuple) before async drop expansion.
10261023 // Async drops, produced by this drop elaboration, will be expanded,
10271024 // and corresponding futures kept in layout.
1028- let coroutine_is_async = coroutine_kind. is_async_desugaring ( ) ;
10291025 let has_async_drops = has_async_drops ( body) ;
10301026
1031- // Replace all occurrences of `ResumeTy` with `&mut Context<'_>` within async bodies.
1032- if coroutine_is_async {
1033- transform_async_context ( tcx, body) ;
1027+ if coroutine_kind. is_async_desugaring ( ) {
1028+ eliminate_get_context_calls ( tcx, body) ;
10341029 }
10351030
10361031 let always_live_locals = always_storage_live_locals ( body) ;
@@ -1097,13 +1092,9 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform {
10971092 } ) ,
10981093 ) ;
10991094
1100- // Update our MIR struct to reflect the changes we've made
1101- body. arg_count = 2 ; // self, resume arg
1102- body. spread_arg = None ;
1103-
11041095 // Remove the context argument within generator bodies.
11051096 if matches ! ( coroutine_kind, CoroutineKind :: Desugared ( CoroutineDesugaring :: Gen , _) ) {
1106- transform_gen_context ( body) ;
1097+ body. arg_count = 1 ;
11071098 }
11081099
11091100 // The original arguments to the function are no longer arguments, mark them as such.
@@ -1150,7 +1141,7 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform {
11501141 body. coroutine . as_mut ( ) . unwrap ( ) . coroutine_drop = Some ( drop_shim) ;
11511142
11521143 // For coroutine with sync drop, generating async proxy for `future_drop_poll` call
1153- let proxy_shim = create_coroutine_drop_shim_proxy_async ( tcx, body) ;
1144+ let proxy_shim = create_coroutine_drop_shim_proxy_async ( tcx, body, coroutine_kind ) ;
11541145 body. coroutine . as_mut ( ) . unwrap ( ) . coroutine_drop_proxy_async = Some ( proxy_shim) ;
11551146 }
11561147
0 commit comments