@@ -71,6 +71,7 @@ pub(super) fn lower_body(
7171 parameters : Option < ast:: ParamList > ,
7272 body : Option < ast:: Expr > ,
7373 is_async_fn : bool ,
74+ is_gen_fn : bool ,
7475) -> ( Body , BodySourceMap ) {
7576 // We cannot leave the root span map empty and let any identifier from it be treated as root,
7677 // because when inside nested macros `SyntaxContextId`s from the outer macro will be interleaved
@@ -175,6 +176,8 @@ pub(super) fn lower_body(
175176 DefWithBodyId :: VariantId ( ..) => Awaitable :: No ( "enum variant" ) ,
176177 }
177178 } ,
179+ is_async_fn,
180+ is_gen_fn,
178181 ) ;
179182 collector. store . inference_roots = Some ( smallvec ! [ ( body_expr, RootExprOrigin :: BodyRoot ) ] ) ;
180183
@@ -375,12 +378,20 @@ pub(crate) fn lower_function(
375378 expr_collector. lower_type_ref_opt ( ret_type. ty ( ) , & mut ExprCollector :: impl_trait_allocator)
376379 } ) ;
377380
378- let return_type = if fn_. value . async_token ( ) . is_some ( ) {
379- let path = hir_expand:: mod_path:: path![ core:: future:: Future ] ;
381+ let return_type = if fn_. value . async_token ( ) . is_some ( ) || fn_. value . gen_token ( ) . is_some ( ) {
382+ let ( path, assoc_name) =
383+ match ( fn_. value . async_token ( ) . is_some ( ) , fn_. value . gen_token ( ) . is_some ( ) ) {
384+ ( true , true ) => {
385+ ( hir_expand:: mod_path:: path![ core:: async_iter:: AsyncIterator ] , sym:: Item )
386+ }
387+ ( true , false ) => ( hir_expand:: mod_path:: path![ core:: future:: Future ] , sym:: Output ) ,
388+ ( false , true ) => ( hir_expand:: mod_path:: path![ core:: iter:: Iterator ] , sym:: Item ) ,
389+ ( false , false ) => unreachable ! ( ) ,
390+ } ;
380391 let mut generic_args: Vec < _ > =
381392 std:: iter:: repeat_n ( None , path. segments ( ) . len ( ) - 1 ) . collect ( ) ;
382393 let binding = AssociatedTypeBinding {
383- name : Name :: new_symbol_root ( sym :: Output ) ,
394+ name : Name :: new_symbol_root ( assoc_name ) ,
384395 args : None ,
385396 type_ref : Some (
386397 return_type
@@ -945,9 +956,15 @@ impl<'db> ExprCollector<'db> {
945956 } )
946957 }
947958
948- /// An `async fn` needs to capture all parameters in the generated `async` block, even if they have
959+ /// Coroutine-like functions need to capture all parameters in the generated block, even if they have
949960 /// non-captured patterns such as wildcards (to ensure consistent drop order).
950- fn lower_async_fn ( & mut self , params : & mut Vec < PatId > , body : ExprId ) -> ExprId {
961+ fn lower_coroutine_fn (
962+ & mut self ,
963+ params : & mut Vec < PatId > ,
964+ body : ExprId ,
965+ is_async_fn : bool ,
966+ is_gen_fn : bool ,
967+ ) -> ExprId {
951968 let mut statements = Vec :: new ( ) ;
952969 for param in params {
953970 let name = match self . store . pats [ * param] {
@@ -979,19 +996,25 @@ impl<'db> ExprCollector<'db> {
979996 * param = pat_id;
980997 }
981998
982- let async_ = self . async_block (
983- CoroutineSource :: Fn ,
999+ let closure_kind = match ( is_async_fn, is_gen_fn) {
1000+ ( true , true ) => ClosureKind :: AsyncGenBlock { source : CoroutineSource :: Fn } ,
1001+ ( true , false ) => ClosureKind :: AsyncBlock { source : CoroutineSource :: Fn } ,
1002+ ( false , true ) => ClosureKind :: GenBlock { source : CoroutineSource :: Fn } ,
1003+ ( false , false ) => unreachable ! ( ) ,
1004+ } ;
1005+ let coroutine = self . coroutine_block (
1006+ closure_kind,
9841007 CaptureBy :: Value ,
9851008 None ,
9861009 statements. into_boxed_slice ( ) ,
9871010 Some ( body) ,
9881011 ) ;
989- self . alloc_expr_desugared ( async_ )
1012+ self . alloc_expr_desugared ( coroutine )
9901013 }
9911014
992- fn async_block (
1015+ fn coroutine_block (
9931016 & mut self ,
994- source : CoroutineSource ,
1017+ closure_kind : ClosureKind ,
9951018 capture_by : CaptureBy ,
9961019 id : Option < BlockId > ,
9971020 statements : Box < [ Statement ] > ,
@@ -1003,7 +1026,7 @@ impl<'db> ExprCollector<'db> {
10031026 arg_types : Box :: default ( ) ,
10041027 ret_type : None ,
10051028 body : block,
1006- closure_kind : ClosureKind :: AsyncBlock { source } ,
1029+ closure_kind,
10071030 capture_by,
10081031 }
10091032 }
@@ -1013,11 +1036,17 @@ impl<'db> ExprCollector<'db> {
10131036 params : & mut Vec < PatId > ,
10141037 expr : Option < ast:: Expr > ,
10151038 awaitable : Awaitable ,
1039+ is_async_fn : bool ,
1040+ is_gen_fn : bool ,
10161041 ) -> ExprId {
10171042 self . awaitable_context . replace ( awaitable) ;
10181043 self . with_label_rib ( RibKind :: Closure , |this| {
10191044 let body = this. collect_expr_opt ( expr) ;
1020- if awaitable == Awaitable :: Yes { this. lower_async_fn ( params, body) } else { body }
1045+ if is_async_fn || is_gen_fn {
1046+ this. lower_coroutine_fn ( params, body, is_async_fn, is_gen_fn)
1047+ } else {
1048+ body
1049+ }
10211050 } )
10221051 }
10231052
@@ -1173,8 +1202,44 @@ impl<'db> ExprCollector<'db> {
11731202 self . with_label_rib ( RibKind :: Closure , |this| {
11741203 this. with_awaitable_block ( Awaitable :: Yes , |this| {
11751204 this. collect_block_ ( e, |this, id, statements, tail| {
1176- this. async_block (
1177- CoroutineSource :: Block ,
1205+ this. coroutine_block (
1206+ ClosureKind :: AsyncBlock { source : CoroutineSource :: Block } ,
1207+ capture_by,
1208+ id,
1209+ statements,
1210+ tail,
1211+ )
1212+ } )
1213+ } )
1214+ } )
1215+ }
1216+ Some ( ast:: BlockModifier :: Gen ( _) ) => {
1217+ let capture_by =
1218+ if e. move_token ( ) . is_some ( ) { CaptureBy :: Value } else { CaptureBy :: Ref } ;
1219+ self . with_label_rib ( RibKind :: Closure , |this| {
1220+ this. with_awaitable_block ( Awaitable :: No ( "non-async gen block" ) , |this| {
1221+ this. collect_block_ ( e, |this, id, statements, tail| {
1222+ this. coroutine_block (
1223+ ClosureKind :: GenBlock { source : CoroutineSource :: Block } ,
1224+ capture_by,
1225+ id,
1226+ statements,
1227+ tail,
1228+ )
1229+ } )
1230+ } )
1231+ } )
1232+ }
1233+ Some ( ast:: BlockModifier :: AsyncGen ( _) ) => {
1234+ let capture_by =
1235+ if e. move_token ( ) . is_some ( ) { CaptureBy :: Value } else { CaptureBy :: Ref } ;
1236+ self . with_label_rib ( RibKind :: Closure , |this| {
1237+ this. with_awaitable_block ( Awaitable :: Yes , |this| {
1238+ this. collect_block_ ( e, |this, id, statements, tail| {
1239+ this. coroutine_block (
1240+ ClosureKind :: AsyncGenBlock {
1241+ source : CoroutineSource :: Block ,
1242+ } ,
11781243 capture_by,
11791244 id,
11801245 statements,
@@ -1194,14 +1259,6 @@ impl<'db> ExprCollector<'db> {
11941259 } )
11951260 } )
11961261 }
1197- // FIXME
1198- Some ( ast:: BlockModifier :: AsyncGen ( _) ) => {
1199- self . with_awaitable_block ( Awaitable :: Yes , |this| this. collect_block ( e) )
1200- }
1201- Some ( ast:: BlockModifier :: Gen ( _) ) => self
1202- . with_awaitable_block ( Awaitable :: No ( "non-async gen block" ) , |this| {
1203- this. collect_block ( e)
1204- } ) ,
12051262 None => self . collect_block ( e) ,
12061263 } ,
12071264 ast:: Expr :: LoopExpr ( e) => {
0 commit comments