Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion crates/hir-def/src/expr_store/body.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,15 @@ impl Body {
let mut params = None;

let mut is_async_fn = false;
let mut is_gen_fn = false;
let InFile { file_id, value: body } = {
match def {
DefWithBodyId::FunctionId(f) => {
let f = f.lookup(db);
let src = f.source(db);
params = src.value.param_list();
is_async_fn = src.value.async_token().is_some();
is_gen_fn = src.value.gen_token().is_some();
src.map(|it| it.body().map(ast::Expr::from))
}
DefWithBodyId::ConstId(c) => {
Expand All @@ -101,7 +103,8 @@ impl Body {
}
};
let module = def.module(db);
let (body, source_map) = lower_body(db, def, file_id, module, params, body, is_async_fn);
let (body, source_map) =
lower_body(db, def, file_id, module, params, body, is_async_fn, is_gen_fn);

(Arc::new(body), source_map)
}
Expand Down
125 changes: 94 additions & 31 deletions crates/hir-def/src/expr_store/lower.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ use crate::{
},
hir::{
Array, Binding, BindingAnnotation, BindingId, BindingProblems, CaptureBy, ClosureKind,
CoroutineSource, Expr, ExprId, Item, Label, LabelId, Literal, MatchArm, Movability,
OffsetOf, Pat, PatId, RecordFieldPat, RecordLitField, RecordSpread, Statement,
CoroutineKind, CoroutineSource, Expr, ExprId, Item, Label, LabelId, Literal, MatchArm,
Movability, OffsetOf, Pat, PatId, RecordFieldPat, RecordLitField, RecordSpread, Statement,
generics::GenericParams,
},
item_scope::BuiltinShadowMode,
Expand All @@ -72,6 +72,7 @@ pub(super) fn lower_body(
parameters: Option<ast::ParamList>,
body: Option<ast::Expr>,
is_async_fn: bool,
is_gen_fn: bool,
) -> (Body, BodySourceMap) {
// We cannot leave the root span map empty and let any identifier from it be treated as root,
// because when inside nested macros `SyntaxContextId`s from the outer macro will be interleaved
Expand Down Expand Up @@ -176,6 +177,8 @@ pub(super) fn lower_body(
DefWithBodyId::VariantId(..) => Awaitable::No("enum variant"),
}
},
is_async_fn,
is_gen_fn,
);
collector.store.inference_roots = Some(smallvec![(body_expr, RootExprOrigin::BodyRoot)]);

Expand Down Expand Up @@ -376,12 +379,20 @@ pub(crate) fn lower_function(
expr_collector.lower_type_ref_opt(ret_type.ty(), &mut ExprCollector::impl_trait_allocator)
});

let return_type = if fn_.value.async_token().is_some() {
let path = hir_expand::mod_path::path![core::future::Future];
let return_type = if fn_.value.async_token().is_some() || fn_.value.gen_token().is_some() {
let (path, assoc_name) =
match (fn_.value.async_token().is_some(), fn_.value.gen_token().is_some()) {
(true, true) => {
(hir_expand::mod_path::path![core::async_iter::AsyncIterator], sym::Item)
}
(true, false) => (hir_expand::mod_path::path![core::future::Future], sym::Output),
(false, true) => (hir_expand::mod_path::path![core::iter::Iterator], sym::Item),
(false, false) => unreachable!(),
};
let mut generic_args: Vec<_> =
std::iter::repeat_n(None, path.segments().len() - 1).collect();
let binding = AssociatedTypeBinding {
name: Name::new_symbol_root(sym::Output),
name: Name::new_symbol_root(assoc_name),
args: None,
type_ref: Some(
return_type
Expand Down Expand Up @@ -950,10 +961,11 @@ impl<'db> ExprCollector<'db> {
/// into the body. This is to make sure that the future actually owns the
/// arguments that are passed to the function, and to ensure things like
/// drop order are stable.
fn lower_async_block_with_moved_arguments(
fn lower_coroutine_with_moved_arguments(
&mut self,
params: &mut [PatId],
body: ExprId,
kind: CoroutineKind,
coroutine_source: CoroutineSource,
) -> ExprId {
let mut statements = Vec::new();
Expand Down Expand Up @@ -989,7 +1001,8 @@ impl<'db> ExprCollector<'db> {
*param = pat_id;
}

let async_ = self.async_block(
let coroutine = self.desugared_coroutine_expr(
kind,
coroutine_source,
// The default capture mode here is by-ref. Later on during upvar analysis,
// we will force the captured arguments to by-move, but for async closures,
Expand All @@ -1001,11 +1014,12 @@ impl<'db> ExprCollector<'db> {
Some(body),
);
// It's important that this comes last, see the lowering of async closures for why.
self.alloc_expr_desugared(async_)
self.alloc_expr_desugared(coroutine)
}

fn async_block(
fn desugared_coroutine_expr(
&mut self,
kind: CoroutineKind,
source: CoroutineSource,
capture_by: CaptureBy,
id: Option<BlockId>,
Expand All @@ -1018,7 +1032,7 @@ impl<'db> ExprCollector<'db> {
arg_types: Box::default(),
ret_type: None,
body: block,
closure_kind: ClosureKind::AsyncBlock { source },
closure_kind: ClosureKind::Coroutine { kind, source },
capture_by,
}
}
Expand All @@ -1028,12 +1042,20 @@ impl<'db> ExprCollector<'db> {
params: &mut [PatId],
expr: Option<ast::Expr>,
awaitable: Awaitable,
is_async_fn: bool,
is_gen_fn: bool,
) -> ExprId {
self.awaitable_context.replace(awaitable);
self.with_label_rib(RibKind::Closure, |this| {
let body = this.collect_expr_opt(expr);
if awaitable == Awaitable::Yes {
this.lower_async_block_with_moved_arguments(params, body, CoroutineSource::Fn)
if is_async_fn || is_gen_fn {
let kind = match (is_async_fn, is_gen_fn) {
(true, true) => CoroutineKind::AsyncGen,
(true, false) => CoroutineKind::Async,
(false, true) => CoroutineKind::Gen,
(false, false) => unreachable!(),
};
this.lower_coroutine_with_moved_arguments(params, body, kind, CoroutineSource::Fn)
} else {
body
}
Expand Down Expand Up @@ -1192,7 +1214,44 @@ impl<'db> ExprCollector<'db> {
self.with_label_rib(RibKind::Closure, |this| {
this.with_awaitable_block(Awaitable::Yes, |this| {
this.collect_block_(e, |this, id, statements, tail| {
this.async_block(
this.desugared_coroutine_expr(
CoroutineKind::Async,
CoroutineSource::Block,
capture_by,
id,
statements,
tail,
)
})
})
})
}
Some(ast::BlockModifier::Gen(_)) => {
let capture_by =
if e.move_token().is_some() { CaptureBy::Value } else { CaptureBy::Ref };
self.with_label_rib(RibKind::Closure, |this| {
this.with_awaitable_block(Awaitable::No("non-async gen block"), |this| {
this.collect_block_(e, |this, id, statements, tail| {
this.desugared_coroutine_expr(
CoroutineKind::Gen,
CoroutineSource::Block,
capture_by,
id,
statements,
tail,
)
})
})
})
}
Some(ast::BlockModifier::AsyncGen(_)) => {
let capture_by =
if e.move_token().is_some() { CaptureBy::Value } else { CaptureBy::Ref };
self.with_label_rib(RibKind::Closure, |this| {
this.with_awaitable_block(Awaitable::Yes, |this| {
this.collect_block_(e, |this, id, statements, tail| {
this.desugared_coroutine_expr(
CoroutineKind::AsyncGen,
CoroutineSource::Block,
capture_by,
id,
Expand All @@ -1213,14 +1272,6 @@ impl<'db> ExprCollector<'db> {
})
})
}
// FIXME
Some(ast::BlockModifier::AsyncGen(_)) => {
self.with_awaitable_block(Awaitable::Yes, |this| this.collect_block(e))
}
Some(ast::BlockModifier::Gen(_)) => self
.with_awaitable_block(Awaitable::No("non-async gen block"), |this| {
this.collect_block(e)
}),
None => self.collect_block(e),
},
ast::Expr::LoopExpr(e) => {
Expand Down Expand Up @@ -1460,25 +1511,37 @@ impl<'db> ExprCollector<'db> {
};
let mut body = this
.with_awaitable_block(awaitable, |this| this.collect_expr_opt(e.body()));

let closure_kind = if this.is_lowering_coroutine {
let movability = if e.static_token().is_some() {
Movability::Static
let kind = {
if e.async_token().is_some() && e.gen_token().is_some() {
Some(CoroutineKind::AsyncGen)
} else if e.async_token().is_some() {
Some(CoroutineKind::Async)
} else if e.gen_token().is_some() {
Some(CoroutineKind::Gen)
} else {
Movability::Movable
};
ClosureKind::Coroutine(movability)
} else if e.async_token().is_some() {
None
}
};

let closure_kind = if let Some(kind) = kind {
// It's important that this expr is allocated immediately before the closure.
// We rely on it for `coroutine_for_closure()`.
body = this.lower_async_block_with_moved_arguments(
body = this.lower_coroutine_with_moved_arguments(
&mut args,
body,
kind,
CoroutineSource::Closure,
);
body_is_bindings_owner = true;

ClosureKind::AsyncClosure
ClosureKind::CoroutineClosure(kind)
} else if this.is_lowering_coroutine {
let movability = if e.static_token().is_some() {
Movability::Static
} else {
Movability::Movable
};
ClosureKind::OldCoroutine(movability)
} else {
ClosureKind::Closure
};
Expand Down
28 changes: 18 additions & 10 deletions crates/hir-def/src/expr_store/pretty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ use crate::{
attrs::AttrFlags,
expr_store::path::{GenericArg, GenericArgs},
hir::{
Array, BindingAnnotation, CaptureBy, ClosureKind, Literal, Movability, RecordSpread,
Statement,
Array, BindingAnnotation, CaptureBy, ClosureKind, CoroutineKind, Literal, Movability,
RecordSpread, Statement,
generics::{GenericParams, WherePredicate},
},
lang_item::LangItemTarget,
Expand Down Expand Up @@ -761,28 +761,36 @@ impl Printer<'_> {
let mut body = *body;
let mut print_pipes = true;
match closure_kind {
ClosureKind::Coroutine(Movability::Static) => {
ClosureKind::OldCoroutine(Movability::Static) => {
w!(self, "static ");
}
ClosureKind::AsyncClosure => {
ClosureKind::CoroutineClosure(kind) => {
if let Expr::Closure {
body: inner_body,
closure_kind: ClosureKind::AsyncBlock { .. },
closure_kind: ClosureKind::Coroutine { .. },
..
} = self.store[body]
{
body = inner_body;
} else {
never!("async closure should always have an async block body");
never!("coroutine closure should always have a coroutine body");
}

w!(self, "async ");
match kind {
CoroutineKind::Async => w!(self, "async "),
CoroutineKind::Gen => w!(self, "gen "),
CoroutineKind::AsyncGen => w!(self, "async gen "),
}
}
ClosureKind::AsyncBlock { .. } => {
w!(self, "async ");
ClosureKind::Coroutine { kind, .. } => {
match kind {
CoroutineKind::Async => w!(self, "async "),
CoroutineKind::Gen => w!(self, "gen "),
CoroutineKind::AsyncGen => w!(self, "async gen "),
}
print_pipes = false;
}
ClosureKind::Closure | ClosureKind::Coroutine(Movability::Movable) => (),
ClosureKind::Closure | ClosureKind::OldCoroutine(Movability::Movable) => (),
}
match capture_by {
CaptureBy::Value => {
Expand Down
13 changes: 10 additions & 3 deletions crates/hir-def/src/hir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -524,12 +524,19 @@ pub enum InlineAsmRegOrRegClass {
RegClass(Symbol),
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CoroutineKind {
Async,
Gen,
AsyncGen,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ClosureKind {
Closure,
Coroutine(Movability),
AsyncBlock { source: CoroutineSource },
AsyncClosure,
OldCoroutine(Movability),
Coroutine { kind: CoroutineKind, source: CoroutineSource },
CoroutineClosure(CoroutineKind),
}

/// In the case of a coroutine created as part of an async/gen construct,
Expand Down
2 changes: 1 addition & 1 deletion crates/hir-def/src/lang_item.rs
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,7 @@ language_item_table! { LangItems =>
FnOnceOutput, sym::fn_once_output, TypeAliasId;

Future, sym::future_trait, TraitId;
AsyncIterator, sym::async_iterator, TraitId;
CoroutineState, sym::coroutine_state, EnumId;
Coroutine, sym::coroutine, TraitId;
CoroutineReturn, sym::coroutine_return, TypeAliasId;
Expand Down Expand Up @@ -522,7 +523,6 @@ language_item_table! { LangItems =>
IteratorNext, sym::next, FunctionId;
Iterator, sym::iterator, TraitId;
FusedIterator, sym::fused_iterator, TraitId;
AsyncIterator, sym::async_iterator, TraitId;

PinNewUnchecked, sym::new_unchecked, FunctionId;

Expand Down
Loading