Skip to content

Commit a28cc76

Browse files
committed
Auto merge of #157309 - cjgillot:coroutine-hir-desugar, r=oli-obk
Desugar async blocks in HIR instead of MIR Implements MCP rust-lang/compiler-team#997 Based on #157166 In the current implementation, `gen`/`async`/`async gen` blocks and closures have type `Coroutine(..)` and `CoroutineClosure(..)`. Those types implement `Iterator`, `Future` or `AsyncIterator` depending on the initial desugaring. This creates a lot of complexity: - trait solvers must check which kind of coroutine each time; - MIR StateTransform needs to fixup types depending on the coroutine kind. I propose to change the desugaring for coroutines to: - `gen { .. }` becomes `CoroutineIterator::from_coroutine(#[coroutine] { .. })`; - `async { .. }` becomes `CoroutineFuture::from_coroutine(#[coroutine] { .. })`; - `async gen { .. }` becomes `CoroutineAsyncIterator::from_coroutine(#[coroutine] { .. })`. This way, all coroutines implement `std::ops::Coroutine` and `core` is responsible for translating this to user-friendly traits. All the complexity is pushed to error-reporting code, which is not soundness-critical. Coroutine closures are a little more complex, as we need to keep the `CoroutineClosure` type for borrow-checking. Main design point: I create two methods on `TyCtxt` that are meant to do the back-and-forth between wrapped and unwrapped coroutines. `coroutine_desugared_type` wraps a coroutine inside the adapter struct. `try_unwrap_desugared_coroutine` unwraps it. r? @oli-obk cc @lcnr @RalfJung cc @estebank as I modify quite a lot of diagnostic code
2 parents c55fad5 + 12ac72b commit a28cc76

221 files changed

Lines changed: 4029 additions & 4267 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

compiler/rustc_ast_lowering/src/expr.rs

Lines changed: 86 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,8 @@ use rustc_ast::node_id::NodeMap;
66
use rustc_ast::*;
77
use rustc_data_structures::stack::ensure_sufficient_stack;
88
use rustc_errors::msg;
9-
use rustc_hir as hir;
109
use rustc_hir::def::{DefKind, Res};
11-
use rustc_hir::{HirId, Target, find_attr};
10+
use rustc_hir::{self as hir, HirId, LangItem, Target, find_attr};
1211
use rustc_middle::span_bug;
1312
use rustc_middle::ty::TyCtxt;
1413
use rustc_session::errors::report_lit_error;
@@ -358,9 +357,10 @@ impl<'hir> LoweringContext<'_, 'hir> {
358357
GenBlockKind::Gen => hir::CoroutineDesugaring::Gen,
359358
GenBlockKind::AsyncGen => hir::CoroutineDesugaring::AsyncGen,
360359
};
361-
self.make_desugared_coroutine_expr(
360+
return self.make_desugared_coroutine_expr(
362361
*capture_clause,
363362
e.id,
363+
expr_hir_id,
364364
None,
365365
*decl_span,
366366
e.span,
@@ -374,7 +374,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
374374
expr
375375
})
376376
},
377-
)
377+
);
378378
}
379379
ExprKind::Block(blk, opt_label) => {
380380
// Different from loops, label of block resolves to block id rather than
@@ -815,26 +815,29 @@ impl<'hir> LoweringContext<'_, 'hir> {
815815
&mut self,
816816
capture_clause: CaptureBy,
817817
closure_node_id: NodeId,
818+
closure_hir_id: HirId,
818819
return_ty: Option<hir::FnRetTy<'hir>>,
819820
fn_decl_span: Span,
820821
span: Span,
821822
desugaring_kind: hir::CoroutineDesugaring,
822823
coroutine_source: hir::CoroutineSource,
823824
body: impl FnOnce(&mut Self) -> hir::Expr<'hir>,
824-
) -> hir::ExprKind<'hir> {
825+
) -> hir::Expr<'hir> {
825826
let closure_def_id = self.local_def_id(closure_node_id);
826827
let coroutine_kind = hir::CoroutineKind::Desugared(desugaring_kind, coroutine_source);
827828

829+
let span = self.lower_span(span);
830+
let unstable_span = self.mark_span_with_reason(
831+
DesugaringKind::Async,
832+
span,
833+
Some(Arc::clone(&self.allow_gen_future)),
834+
);
835+
828836
// The `async` desugaring takes a resume argument and maintains a `task_context`,
829837
// whereas a generator does not.
830838
let (inputs, params, task_context): (&[_], &[_], _) = match desugaring_kind {
831839
hir::CoroutineDesugaring::Async | hir::CoroutineDesugaring::AsyncGen => {
832840
// Resume argument type: `ResumeTy`
833-
let unstable_span = self.mark_span_with_reason(
834-
DesugaringKind::Async,
835-
self.lower_span(span),
836-
Some(Arc::clone(&self.allow_gen_future)),
837-
);
838841
let resume_ty =
839842
self.make_lang_item_qpath(hir::LangItem::ResumeTy, unstable_span, None);
840843
let input_ty = hir::Ty {
@@ -850,21 +853,15 @@ impl<'hir> LoweringContext<'_, 'hir> {
850853
Ident::with_dummy_span(sym::_task_context),
851854
hir::BindingMode::MUT,
852855
);
853-
let param = hir::Param {
854-
hir_id: self.next_id(),
855-
pat,
856-
ty_span: self.lower_span(span),
857-
span: self.lower_span(span),
858-
};
856+
let param = hir::Param { hir_id: self.next_id(), pat, ty_span: span, span };
859857
let params = arena_vec![self; param];
860858

861859
(inputs, params, Some(task_context_hid))
862860
}
863861
hir::CoroutineDesugaring::Gen => (&[], &[], None),
864862
};
865863

866-
let output =
867-
return_ty.unwrap_or_else(|| hir::FnRetTy::DefaultReturn(self.lower_span(span)));
864+
let output = return_ty.unwrap_or_else(|| hir::FnRetTy::DefaultReturn(span));
868865

869866
let fn_decl = self.arena.alloc(hir::FnDecl {
870867
inputs,
@@ -886,7 +883,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
886883
});
887884

888885
// `static |<_task_context?>| -> <return_ty> { <body> }`:
889-
hir::ExprKind::Closure(self.arena.alloc(hir::Closure {
886+
let coroutine_closure = hir::ExprKind::Closure(self.arena.alloc(hir::Closure {
890887
def_id: closure_def_id,
891888
binder: hir::ClosureBinder::Default,
892889
capture_clause: self.lower_capture_clause(capture_clause),
@@ -898,7 +895,19 @@ impl<'hir> LoweringContext<'_, 'hir> {
898895
kind: hir::ClosureKind::Coroutine(coroutine_kind),
899896
constness: hir::Constness::NotConst,
900897
explicit_captures: &[],
901-
}))
898+
}));
899+
let coroutine_closure = hir::Expr { hir_id: closure_hir_id, kind: coroutine_closure, span };
900+
901+
let from_coroutine = match desugaring_kind {
902+
hir::CoroutineDesugaring::Async => LangItem::FutureFromCoroutine,
903+
hir::CoroutineDesugaring::AsyncGen => LangItem::AsyncIteratorFromCoroutine,
904+
hir::CoroutineDesugaring::Gen => LangItem::IterFromCoroutine,
905+
};
906+
self.expr_call_lang_item_fn_mut(
907+
unstable_span,
908+
from_coroutine,
909+
arena_vec![self; coroutine_closure],
910+
)
902911
}
903912

904913
/// Forwards a possible `#[track_caller]` annotation from `outer_hir_id` to
@@ -951,19 +960,20 @@ impl<'hir> LoweringContext<'_, 'hir> {
951960
/// }
952961
/// ```
953962
fn lower_expr_await(&mut self, await_kw_span: Span, expr: &Expr) -> hir::ExprKind<'hir> {
963+
let full_span = expr.span.to(await_kw_span);
954964
let expr = self.arena.alloc(self.lower_expr_mut(expr));
955-
self.make_lowered_await(await_kw_span, expr, FutureKind::Future)
965+
self.make_lowered_await(await_kw_span, full_span, expr, FutureKind::Future)
956966
}
957967

958968
/// Takes an expr that has already been lowered and generates a desugared await loop around it
959969
fn make_lowered_await(
960970
&mut self,
961971
await_kw_span: Span,
972+
// Pass the span separately, as `expr.span` may be a desugaring.
973+
full_span: Span,
962974
expr: &'hir hir::Expr<'hir>,
963975
await_kind: FutureKind,
964976
) -> hir::ExprKind<'hir> {
965-
let full_span = expr.span.to(await_kw_span);
966-
967977
let is_async_gen = match self.coroutine_kind {
968978
Some(hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Async, _)) => false,
969979
Some(hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::AsyncGen, _)) => true,
@@ -974,19 +984,19 @@ impl<'hir> LoweringContext<'_, 'hir> {
974984
// is not accidentally orphaned.
975985
let stmt_id = self.next_id();
976986
let expr_err = self.expr(
977-
expr.span,
987+
full_span,
978988
hir::ExprKind::Err(self.dcx().emit_err(AwaitOnlyInAsyncFnAndBlocks {
979989
await_kw_span,
980990
item_span: self.current_item,
981991
})),
982992
);
983993
return hir::ExprKind::Block(
984994
self.block_all(
985-
expr.span,
995+
full_span,
986996
arena_vec![self; hir::Stmt {
987997
hir_id: stmt_id,
988998
kind: hir::StmtKind::Semi(expr),
989-
span: expr.span,
999+
span: full_span,
9901000
}],
9911001
Some(self.arena.alloc(expr_err)),
9921002
),
@@ -1645,18 +1655,54 @@ impl<'hir> LoweringContext<'_, 'hir> {
16451655
.emit();
16461656
}
16471657

1648-
let is_async_gen = match self.coroutine_kind {
1649-
Some(hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Gen, _)) => false,
1650-
Some(hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::AsyncGen, _)) => true,
1651-
Some(hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Async, _)) => {
1652-
// Lower to a block `{ EXPR; <error> }` so that the awaited expr
1653-
// is not accidentally orphaned.
1658+
let Some(coroutine_kind) = self.coroutine_kind else {
1659+
let suggestion = self.current_item.map(|s| s.shrink_to_lo());
1660+
self.dcx().emit_err(YieldInClosure { span, suggestion });
1661+
self.coroutine_kind = Some(hir::CoroutineKind::Coroutine(Movability::Movable));
1662+
return hir::ExprKind::Yield(yielded, hir::YieldSource::Yield);
1663+
};
1664+
1665+
match coroutine_kind {
1666+
// Raw and Gen coroutines, nothing to do.
1667+
hir::CoroutineKind::Coroutine(_)
1668+
| hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Gen, _) => {
1669+
hir::ExprKind::Yield(yielded, hir::YieldSource::Yield)
1670+
}
1671+
// `yield $expr` is transformed into `task_context = yield async_gen_ready($expr)`.
1672+
// This ensures that we store our resumed `ResumeContext` correctly, and also that
1673+
// the apparent value of the `yield` expression is `()`.
1674+
hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::AsyncGen, _) => {
1675+
let desugar_span = self.mark_span_with_reason(
1676+
DesugaringKind::Async,
1677+
span,
1678+
Some(Arc::clone(&self.allow_async_gen)),
1679+
);
1680+
let wrapped_yielded = self.expr_call_lang_item_fn(
1681+
desugar_span,
1682+
hir::LangItem::AsyncGenReady,
1683+
std::slice::from_ref(yielded),
1684+
);
1685+
let yield_expr = self.arena.alloc(
1686+
self.expr(span, hir::ExprKind::Yield(wrapped_yielded, hir::YieldSource::Yield)),
1687+
);
1688+
1689+
let Some(task_context_hid) = self.task_context else {
1690+
unreachable!("use of `await` outside of an async context.");
1691+
};
1692+
let task_context_ident = Ident::with_dummy_span(sym::_task_context);
1693+
let lhs = self.expr_ident(desugar_span, task_context_ident, task_context_hid);
1694+
1695+
hir::ExprKind::Assign(lhs, yield_expr, self.lower_span(span))
1696+
}
1697+
// Lower to a block `{ EXPR; <error> }` so that the awaited expr
1698+
// is not accidentally orphaned.
1699+
hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Async, _) => {
16541700
let stmt_id = self.next_id();
16551701
let expr_err = self.expr(
16561702
yielded.span,
16571703
hir::ExprKind::Err(self.dcx().emit_err(AsyncCoroutinesNotSupported { span })),
16581704
);
1659-
return hir::ExprKind::Block(
1705+
hir::ExprKind::Block(
16601706
self.block_all(
16611707
yielded.span,
16621708
arena_vec![self; hir::Stmt {
@@ -1667,45 +1713,8 @@ impl<'hir> LoweringContext<'_, 'hir> {
16671713
Some(self.arena.alloc(expr_err)),
16681714
),
16691715
None,
1670-
);
1671-
}
1672-
Some(hir::CoroutineKind::Coroutine(_)) => false,
1673-
None => {
1674-
let suggestion = self.current_item.map(|s| s.shrink_to_lo());
1675-
self.dcx().emit_err(YieldInClosure { span, suggestion });
1676-
self.coroutine_kind = Some(hir::CoroutineKind::Coroutine(Movability::Movable));
1677-
1678-
false
1716+
)
16791717
}
1680-
};
1681-
1682-
if is_async_gen {
1683-
// `yield $expr` is transformed into `task_context = yield async_gen_ready($expr)`.
1684-
// This ensures that we store our resumed `ResumeContext` correctly, and also that
1685-
// the apparent value of the `yield` expression is `()`.
1686-
let desugar_span = self.mark_span_with_reason(
1687-
DesugaringKind::Async,
1688-
span,
1689-
Some(Arc::clone(&self.allow_async_gen)),
1690-
);
1691-
let wrapped_yielded = self.expr_call_lang_item_fn(
1692-
desugar_span,
1693-
hir::LangItem::AsyncGenReady,
1694-
std::slice::from_ref(yielded),
1695-
);
1696-
let yield_expr = self.arena.alloc(
1697-
self.expr(span, hir::ExprKind::Yield(wrapped_yielded, hir::YieldSource::Yield)),
1698-
);
1699-
1700-
let Some(task_context_hid) = self.task_context else {
1701-
unreachable!("use of `await` outside of an async context.");
1702-
};
1703-
let task_context_ident = Ident::with_dummy_span(sym::_task_context);
1704-
let lhs = self.expr_ident(desugar_span, task_context_ident, task_context_hid);
1705-
1706-
hir::ExprKind::Assign(lhs, yield_expr, self.lower_span(span))
1707-
} else {
1708-
hir::ExprKind::Yield(yielded, hir::YieldSource::Yield)
17091718
}
17101719
}
17111720

@@ -1800,7 +1809,12 @@ impl<'hir> LoweringContext<'_, 'hir> {
18001809
));
18011810
// `unsafe { ... }`
18021811
let iter = self.arena.alloc(self.expr_unsafe(head_span, iter));
1803-
let kind = self.make_lowered_await(head_span, iter, FutureKind::AsyncIterator);
1812+
let kind = self.make_lowered_await(
1813+
head_span,
1814+
head_span,
1815+
iter,
1816+
FutureKind::AsyncIterator,
1817+
);
18041818
self.arena.alloc(hir::Expr { hir_id: self.next_id(), kind, span: head_span })
18051819
}
18061820
};

compiler/rustc_ast_lowering/src/expr/closure.rs

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -327,18 +327,21 @@ impl<'hir> LoweringContext<'_, 'hir> {
327327
// Transform `async |x: u8| -> X { ... }` into
328328
// `|x: u8| || -> X { ... }`.
329329
let body_id = this.lower_body(|this| {
330-
let ((parameters, expr), _) = this.with_move_expr_bindings(None, |this| {
331-
this.lower_coroutine_body_with_moved_arguments(
332-
&inner_decl,
333-
|this| this.with_new_scopes(fn_decl_span, |this| this.lower_expr_mut(body)),
334-
fn_decl_span,
335-
body.span,
336-
coroutine_kind,
337-
hir::CoroutineSource::Closure,
338-
)
339-
});
330+
let ((parameters, expr, coroutine_hir_id), _) =
331+
this.with_move_expr_bindings(None, |this| {
332+
this.lower_coroutine_body_with_moved_arguments(
333+
&inner_decl,
334+
|this| {
335+
this.with_new_scopes(fn_decl_span, |this| this.lower_expr_mut(body))
336+
},
337+
fn_decl_span,
338+
body.span,
339+
coroutine_kind,
340+
hir::CoroutineSource::Closure,
341+
)
342+
});
340343

341-
this.maybe_forward_track_caller(body.span, closure_hir_id, expr.hir_id);
344+
this.maybe_forward_track_caller(body.span, closure_hir_id, coroutine_hir_id);
342345

343346
(parameters, expr)
344347
});

compiler/rustc_ast_lowering/src/item.rs

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1450,18 +1450,18 @@ impl<'hir> LoweringContext<'_, 'hir> {
14501450
};
14511451
// FIXME(contracts): Support contracts on async fn.
14521452
self.lower_body(|this| {
1453-
let (parameters, expr) = this.lower_coroutine_body_with_moved_arguments(
1454-
decl,
1455-
|this| this.lower_block_expr(body),
1456-
fn_decl_span,
1457-
body.span,
1458-
coroutine_kind,
1459-
hir::CoroutineSource::Fn,
1460-
);
1453+
let (parameters, expr, coroutine_hir_id) = this
1454+
.lower_coroutine_body_with_moved_arguments(
1455+
decl,
1456+
|this| this.lower_block_expr(body),
1457+
fn_decl_span,
1458+
body.span,
1459+
coroutine_kind,
1460+
hir::CoroutineSource::Fn,
1461+
);
14611462

14621463
// FIXME(async_fn_track_caller): Can this be moved above?
1463-
let hir_id = expr.hir_id;
1464-
this.maybe_forward_track_caller(body.span, fn_id, hir_id);
1464+
this.maybe_forward_track_caller(body.span, fn_id, coroutine_hir_id);
14651465

14661466
(parameters, expr)
14671467
})
@@ -1479,7 +1479,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
14791479
body_span: Span,
14801480
coroutine_kind: CoroutineKind,
14811481
coroutine_source: hir::CoroutineSource,
1482-
) -> (&'hir [hir::Param<'hir>], hir::Expr<'hir>) {
1482+
) -> (&'hir [hir::Param<'hir>], hir::Expr<'hir>, HirId) {
14831483
let mut parameters: Vec<hir::Param<'_>> = Vec::new();
14841484
let mut statements: Vec<hir::Stmt<'_>> = Vec::new();
14851485

@@ -1643,6 +1643,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
16431643
CoroutineKind::AsyncGen { .. } => hir::CoroutineDesugaring::AsyncGen,
16441644
};
16451645
let closure_id = coroutine_kind.closure_id();
1646+
let closure_hir_id = self.lower_node_id(closure_id);
16461647

16471648
let coroutine_expr = self.make_desugared_coroutine_expr(
16481649
// The default capture mode here is by-ref. Later on during upvar analysis,
@@ -1651,6 +1652,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
16511652
// all async closures would default to `FnOnce` as their calling mode.
16521653
CaptureBy::Ref,
16531654
closure_id,
1655+
closure_hir_id,
16541656
None,
16551657
fn_decl_span,
16561658
body_span,
@@ -1659,13 +1661,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
16591661
mkbody,
16601662
);
16611663

1662-
let expr = hir::Expr {
1663-
hir_id: self.lower_node_id(closure_id),
1664-
kind: coroutine_expr,
1665-
span: self.lower_span(body_span),
1666-
};
1667-
1668-
(self.arena.alloc_from_iter(parameters), expr)
1664+
(self.arena.alloc_from_iter(parameters), coroutine_expr, closure_hir_id)
16691665
}
16701666

16711667
fn lower_method_sig(

0 commit comments

Comments
 (0)