Skip to content

Commit 55c4bca

Browse files
fix
1 parent 1e7decd commit 55c4bca

16 files changed

Lines changed: 392 additions & 39 deletions

File tree

crates/hir-def/src/expr_store/body.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,13 +74,15 @@ impl Body {
7474
let mut params = None;
7575

7676
let mut is_async_fn = false;
77+
let mut is_gen_fn = false;
7778
let InFile { file_id, value: body } = {
7879
match def {
7980
DefWithBodyId::FunctionId(f) => {
8081
let f = f.lookup(db);
8182
let src = f.source(db);
8283
params = src.value.param_list();
8384
is_async_fn = src.value.async_token().is_some();
85+
is_gen_fn = src.value.gen_token().is_some();
8486
src.map(|it| it.body().map(ast::Expr::from))
8587
}
8688
DefWithBodyId::ConstId(c) => {
@@ -101,7 +103,8 @@ impl Body {
101103
}
102104
};
103105
let module = def.module(db);
104-
let (body, source_map) = lower_body(db, def, file_id, module, params, body, is_async_fn);
106+
let (body, source_map) =
107+
lower_body(db, def, file_id, module, params, body, is_async_fn, is_gen_fn);
105108

106109
(Arc::new(body), source_map)
107110
}

crates/hir-def/src/expr_store/lower.rs

Lines changed: 79 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ pub(super) fn lower_body(
7171
parameters: Option<ast::ParamList>,
7272
body: Option<ast::Expr>,
7373
is_async_fn: bool,
74+
is_gen_fn: bool,
7475
) -> (Body, BodySourceMap) {
7576
// We cannot leave the root span map empty and let any identifier from it be treated as root,
7677
// because when inside nested macros `SyntaxContextId`s from the outer macro will be interleaved
@@ -175,6 +176,8 @@ pub(super) fn lower_body(
175176
DefWithBodyId::VariantId(..) => Awaitable::No("enum variant"),
176177
}
177178
},
179+
is_async_fn,
180+
is_gen_fn,
178181
);
179182
collector.store.inference_roots = Some(smallvec![(body_expr, RootExprOrigin::BodyRoot)]);
180183

@@ -375,12 +378,20 @@ pub(crate) fn lower_function(
375378
expr_collector.lower_type_ref_opt(ret_type.ty(), &mut ExprCollector::impl_trait_allocator)
376379
});
377380

378-
let return_type = if fn_.value.async_token().is_some() {
379-
let path = hir_expand::mod_path::path![core::future::Future];
381+
let return_type = if fn_.value.async_token().is_some() || fn_.value.gen_token().is_some() {
382+
let (path, assoc_name) =
383+
match (fn_.value.async_token().is_some(), fn_.value.gen_token().is_some()) {
384+
(true, true) => {
385+
(hir_expand::mod_path::path![core::async_iter::AsyncIterator], sym::Item)
386+
}
387+
(true, false) => (hir_expand::mod_path::path![core::future::Future], sym::Output),
388+
(false, true) => (hir_expand::mod_path::path![core::iter::Iterator], sym::Item),
389+
(false, false) => unreachable!(),
390+
};
380391
let mut generic_args: Vec<_> =
381392
std::iter::repeat_n(None, path.segments().len() - 1).collect();
382393
let binding = AssociatedTypeBinding {
383-
name: Name::new_symbol_root(sym::Output),
394+
name: Name::new_symbol_root(assoc_name),
384395
args: None,
385396
type_ref: Some(
386397
return_type
@@ -945,9 +956,15 @@ impl<'db> ExprCollector<'db> {
945956
})
946957
}
947958

948-
/// An `async fn` needs to capture all parameters in the generated `async` block, even if they have
959+
/// Coroutine-like functions need to capture all parameters in the generated block, even if they have
949960
/// non-captured patterns such as wildcards (to ensure consistent drop order).
950-
fn lower_async_fn(&mut self, params: &mut Vec<PatId>, body: ExprId) -> ExprId {
961+
fn lower_coroutine_fn(
962+
&mut self,
963+
params: &mut Vec<PatId>,
964+
body: ExprId,
965+
is_async_fn: bool,
966+
is_gen_fn: bool,
967+
) -> ExprId {
951968
let mut statements = Vec::new();
952969
for param in params {
953970
let name = match self.store.pats[*param] {
@@ -979,19 +996,25 @@ impl<'db> ExprCollector<'db> {
979996
*param = pat_id;
980997
}
981998

982-
let async_ = self.async_block(
983-
CoroutineSource::Fn,
999+
let closure_kind = match (is_async_fn, is_gen_fn) {
1000+
(true, true) => ClosureKind::AsyncGenBlock { source: CoroutineSource::Fn },
1001+
(true, false) => ClosureKind::AsyncBlock { source: CoroutineSource::Fn },
1002+
(false, true) => ClosureKind::GenBlock { source: CoroutineSource::Fn },
1003+
(false, false) => unreachable!(),
1004+
};
1005+
let coroutine = self.coroutine_block(
1006+
closure_kind,
9841007
CaptureBy::Value,
9851008
None,
9861009
statements.into_boxed_slice(),
9871010
Some(body),
9881011
);
989-
self.alloc_expr_desugared(async_)
1012+
self.alloc_expr_desugared(coroutine)
9901013
}
9911014

992-
fn async_block(
1015+
fn coroutine_block(
9931016
&mut self,
994-
source: CoroutineSource,
1017+
closure_kind: ClosureKind,
9951018
capture_by: CaptureBy,
9961019
id: Option<BlockId>,
9971020
statements: Box<[Statement]>,
@@ -1003,7 +1026,7 @@ impl<'db> ExprCollector<'db> {
10031026
arg_types: Box::default(),
10041027
ret_type: None,
10051028
body: block,
1006-
closure_kind: ClosureKind::AsyncBlock { source },
1029+
closure_kind,
10071030
capture_by,
10081031
}
10091032
}
@@ -1013,11 +1036,17 @@ impl<'db> ExprCollector<'db> {
10131036
params: &mut Vec<PatId>,
10141037
expr: Option<ast::Expr>,
10151038
awaitable: Awaitable,
1039+
is_async_fn: bool,
1040+
is_gen_fn: bool,
10161041
) -> ExprId {
10171042
self.awaitable_context.replace(awaitable);
10181043
self.with_label_rib(RibKind::Closure, |this| {
10191044
let body = this.collect_expr_opt(expr);
1020-
if awaitable == Awaitable::Yes { this.lower_async_fn(params, body) } else { body }
1045+
if is_async_fn || is_gen_fn {
1046+
this.lower_coroutine_fn(params, body, is_async_fn, is_gen_fn)
1047+
} else {
1048+
body
1049+
}
10211050
})
10221051
}
10231052

@@ -1173,8 +1202,44 @@ impl<'db> ExprCollector<'db> {
11731202
self.with_label_rib(RibKind::Closure, |this| {
11741203
this.with_awaitable_block(Awaitable::Yes, |this| {
11751204
this.collect_block_(e, |this, id, statements, tail| {
1176-
this.async_block(
1177-
CoroutineSource::Block,
1205+
this.coroutine_block(
1206+
ClosureKind::AsyncBlock { source: CoroutineSource::Block },
1207+
capture_by,
1208+
id,
1209+
statements,
1210+
tail,
1211+
)
1212+
})
1213+
})
1214+
})
1215+
}
1216+
Some(ast::BlockModifier::Gen(_)) => {
1217+
let capture_by =
1218+
if e.move_token().is_some() { CaptureBy::Value } else { CaptureBy::Ref };
1219+
self.with_label_rib(RibKind::Closure, |this| {
1220+
this.with_awaitable_block(Awaitable::No("non-async gen block"), |this| {
1221+
this.collect_block_(e, |this, id, statements, tail| {
1222+
this.coroutine_block(
1223+
ClosureKind::GenBlock { source: CoroutineSource::Block },
1224+
capture_by,
1225+
id,
1226+
statements,
1227+
tail,
1228+
)
1229+
})
1230+
})
1231+
})
1232+
}
1233+
Some(ast::BlockModifier::AsyncGen(_)) => {
1234+
let capture_by =
1235+
if e.move_token().is_some() { CaptureBy::Value } else { CaptureBy::Ref };
1236+
self.with_label_rib(RibKind::Closure, |this| {
1237+
this.with_awaitable_block(Awaitable::Yes, |this| {
1238+
this.collect_block_(e, |this, id, statements, tail| {
1239+
this.coroutine_block(
1240+
ClosureKind::AsyncGenBlock {
1241+
source: CoroutineSource::Block,
1242+
},
11781243
capture_by,
11791244
id,
11801245
statements,
@@ -1194,14 +1259,6 @@ impl<'db> ExprCollector<'db> {
11941259
})
11951260
})
11961261
}
1197-
// FIXME
1198-
Some(ast::BlockModifier::AsyncGen(_)) => {
1199-
self.with_awaitable_block(Awaitable::Yes, |this| this.collect_block(e))
1200-
}
1201-
Some(ast::BlockModifier::Gen(_)) => self
1202-
.with_awaitable_block(Awaitable::No("non-async gen block"), |this| {
1203-
this.collect_block(e)
1204-
}),
12051262
None => self.collect_block(e),
12061263
},
12071264
ast::Expr::LoopExpr(e) => {

crates/hir-def/src/expr_store/pretty.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -785,6 +785,14 @@ impl Printer<'_> {
785785
w!(self, "async ");
786786
print_pipes = false;
787787
}
788+
ClosureKind::GenBlock { .. } => {
789+
w!(self, "gen ");
790+
print_pipes = false;
791+
}
792+
ClosureKind::AsyncGenBlock { .. } => {
793+
w!(self, "async gen ");
794+
print_pipes = false;
795+
}
788796
ClosureKind::Closure | ClosureKind::Coroutine(Movability::Movable) => (),
789797
}
790798
match capture_by {

crates/hir-def/src/hir.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -529,6 +529,8 @@ pub enum ClosureKind {
529529
Closure,
530530
Coroutine(Movability),
531531
AsyncBlock { source: CoroutineSource },
532+
GenBlock { source: CoroutineSource },
533+
AsyncGenBlock { source: CoroutineSource },
532534
AsyncClosure,
533535
}
534536

crates/hir-def/src/lang_item.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,7 @@ language_item_table! { LangItems =>
381381
FnOnceOutput, sym::fn_once_output, TypeAliasId;
382382

383383
Future, sym::future_trait, TraitId;
384+
AsyncIterator, sym::async_iterator, TraitId;
384385
CoroutineState, sym::coroutine_state, EnumId;
385386
Coroutine, sym::coroutine, TraitId;
386387
CoroutineReturn, sym::coroutine_return, TypeAliasId;

crates/hir-def/src/signatures.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -580,6 +580,7 @@ bitflags! {
580580
const EXPLICIT_SAFE = 1 << 11;
581581
const HAS_LEGACY_CONST_GENERICS = 1 << 12;
582582
const RUSTC_INTRINSIC = 1 << 13;
583+
const GEN = 1 << 14;
583584
}
584585
}
585586

@@ -638,6 +639,9 @@ impl FunctionSignature {
638639
if source.value.async_token().is_some() {
639640
flags.insert(FnFlags::ASYNC);
640641
}
642+
if source.value.gen_token().is_some() {
643+
flags.insert(FnFlags::GEN);
644+
}
641645
if source.value.const_token().is_some() {
642646
flags.insert(FnFlags::CONST);
643647
}
@@ -701,6 +705,10 @@ impl FunctionSignature {
701705
self.flags.contains(FnFlags::ASYNC)
702706
}
703707

708+
pub fn is_gen(&self) -> bool {
709+
self.flags.contains(FnFlags::GEN)
710+
}
711+
704712
pub fn is_unsafe(&self) -> bool {
705713
self.flags.contains(FnFlags::UNSAFE)
706714
}

crates/hir-expand/src/mod_path.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -427,6 +427,7 @@ macro_rules! __known_path {
427427
(core::range::RangeFrom) => {};
428428
(core::range::RangeInclusive) => {};
429429
(core::range::RangeToInclusive) => {};
430+
(core::async_iter::AsyncIterator) => {};
430431
(core::future::Future) => {};
431432
(core::future::IntoFuture) => {};
432433
(core::fmt::Debug) => {};

crates/hir-ty/src/db.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,9 @@ impl InternedCoroutineId {
277277
expr,
278278
hir_def::hir::Expr::Closure {
279279
closure_kind: hir_def::hir::ClosureKind::Coroutine(_)
280-
| hir_def::hir::ClosureKind::AsyncBlock { .. },
280+
| hir_def::hir::ClosureKind::AsyncBlock { .. }
281+
| hir_def::hir::ClosureKind::GenBlock { .. }
282+
| hir_def::hir::ClosureKind::AsyncGenBlock { .. },
281283
..
282284
}
283285
),

crates/hir-ty/src/display.rs

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1707,6 +1707,64 @@ impl<'db> HirDisplay<'db> for Ty<'db> {
17071707
return_ty.hir_fmt(f)?;
17081708
write!(f, ">")?;
17091709
}
1710+
hir_def::hir::Expr::Closure {
1711+
closure_kind: hir_def::hir::ClosureKind::GenBlock { .. },
1712+
..
1713+
} => {
1714+
let iterator_trait = f.lang_items().Iterator;
1715+
let item = iterator_trait.and_then(|t| {
1716+
t.trait_items(db)
1717+
.associated_type_by_name(&Name::new_symbol_root(sym::Item))
1718+
});
1719+
write!(f, "impl ")?;
1720+
if let Some(t) = iterator_trait {
1721+
f.start_location_link(t.into());
1722+
}
1723+
write!(f, "Iterator")?;
1724+
if iterator_trait.is_some() {
1725+
f.end_location_link();
1726+
}
1727+
write!(f, "<")?;
1728+
if let Some(t) = item {
1729+
f.start_location_link(t.into());
1730+
}
1731+
write!(f, "Item")?;
1732+
if item.is_some() {
1733+
f.end_location_link();
1734+
}
1735+
write!(f, " = ")?;
1736+
yield_ty.hir_fmt(f)?;
1737+
write!(f, ">")?;
1738+
}
1739+
hir_def::hir::Expr::Closure {
1740+
closure_kind: hir_def::hir::ClosureKind::AsyncGenBlock { .. },
1741+
..
1742+
} => {
1743+
let async_iterator_trait = f.lang_items().AsyncIterator;
1744+
let item = async_iterator_trait.and_then(|t| {
1745+
t.trait_items(db)
1746+
.associated_type_by_name(&Name::new_symbol_root(sym::Item))
1747+
});
1748+
write!(f, "impl ")?;
1749+
if let Some(t) = async_iterator_trait {
1750+
f.start_location_link(t.into());
1751+
}
1752+
write!(f, "AsyncIterator")?;
1753+
if async_iterator_trait.is_some() {
1754+
f.end_location_link();
1755+
}
1756+
write!(f, "<")?;
1757+
if let Some(t) = item {
1758+
f.start_location_link(t.into());
1759+
}
1760+
write!(f, "Item")?;
1761+
if item.is_some() {
1762+
f.end_location_link();
1763+
}
1764+
write!(f, " = ")?;
1765+
yield_ty.hir_fmt(f)?;
1766+
write!(f, ">")?;
1767+
}
17101768
hir_def::hir::Expr::Closure {
17111769
closure_kind: hir_def::hir::ClosureKind::Coroutine(..),
17121770
..

crates/hir-ty/src/infer/closure.rs

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -147,9 +147,14 @@ impl<'db> InferenceContext<'_, 'db> {
147147

148148
(Ty::new_closure(interner, closure_id.into(), closure_args.args), None)
149149
}
150-
ClosureKind::Coroutine(_) | ClosureKind::AsyncBlock { .. } => {
150+
ClosureKind::Coroutine(_)
151+
| ClosureKind::AsyncBlock { .. }
152+
| ClosureKind::GenBlock { .. }
153+
| ClosureKind::AsyncGenBlock { .. } => {
151154
let yield_ty = match closure_kind {
152-
ClosureKind::Coroutine(_) => self.table.next_ty_var(),
155+
ClosureKind::Coroutine(_)
156+
| ClosureKind::GenBlock { .. }
157+
| ClosureKind::AsyncGenBlock { .. } => self.table.next_ty_var(),
153158
ClosureKind::AsyncBlock { .. } => self.types.types.unit,
154159
_ => unreachable!(),
155160
};
@@ -163,7 +168,8 @@ impl<'db> InferenceContext<'_, 'db> {
163168
// later during upvar analysis. Regular coroutines always have the kind
164169
// ty of `().`
165170
let kind_ty = match closure_kind {
166-
ClosureKind::AsyncBlock { source: CoroutineSource::Closure } => {
171+
ClosureKind::AsyncBlock { source: CoroutineSource::Closure }
172+
| ClosureKind::AsyncGenBlock { source: CoroutineSource::Closure } => {
167173
self.table.next_ty_var()
168174
}
169175
_ => self.types.types.unit,
@@ -397,7 +403,9 @@ impl<'db> InferenceContext<'_, 'db> {
397403
}
398404
ClosureKind::Coroutine(_)
399405
| ClosureKind::AsyncClosure
400-
| ClosureKind::AsyncBlock { .. } => (None, None),
406+
| ClosureKind::AsyncBlock { .. }
407+
| ClosureKind::GenBlock { .. }
408+
| ClosureKind::AsyncGenBlock { .. } => (None, None),
401409
},
402410
_ => (None, None),
403411
}

0 commit comments

Comments
 (0)