@@ -139,30 +139,28 @@ async def _fetch_and_group_related_data(db: Database, investigation_ids: list[st
139139 """Fetch related data in bulk and group by investigation ID."""
140140 logger .info ("Fetching related data (studies, assays, contacts, etc.)..." )
141141
142- async def collect (gen : AsyncGenerator [dict [str , Any ], None ]) -> list [dict [str , Any ]]:
143- return [row async for row in gen ]
144-
145- # TODO: also here we're using lists, so generators or cursors
146- study_rows = await collect (db .stream_studies (investigation_ids ))
147- assay_rows = await collect (db .stream_assays (investigation_ids ))
148- contact_rows = await collect (db .stream_contacts (investigation_ids ))
149- pub_rows = await collect (db .stream_publications (investigation_ids ))
150- ann_rows = await collect (db .stream_annotation_tables (investigation_ids ))
151-
152- def group (rows : list [dict [str , Any ]]) -> dict [str , list [dict [str , Any ]]]:
142+ async def group_stream (gen : AsyncGenerator [dict [str , Any ], None ]) -> tuple [dict [str , list [dict [str , Any ]]], int ]:
153143 m = defaultdict (list )
154- for r in rows :
144+ count = 0
145+ async for r in gen :
155146 m [str (r ["investigation_ref" ])].append (r )
156- return dict (m )
147+ count += 1
148+ return dict (m ), count
149+
150+ studies_by_inv , study_count = await group_stream (db .stream_studies (investigation_ids ))
151+ assays_by_inv , assay_count = await group_stream (db .stream_assays (investigation_ids ))
152+ contacts_by_inv , _ = await group_stream (db .stream_contacts (investigation_ids ))
153+ pubs_by_inv , _ = await group_stream (db .stream_publications (investigation_ids ))
154+ anns_by_inv , _ = await group_stream (db .stream_annotation_tables (investigation_ids ))
157155
158156 return RelatedDataBatch (
159- studies_by_inv = group ( study_rows ) ,
160- assays_by_inv = group ( assay_rows ) ,
161- contacts_by_inv = group ( contact_rows ) ,
162- pubs_by_inv = group ( pub_rows ) ,
163- anns_by_inv = group ( ann_rows ) ,
164- study_count = len ( study_rows ) ,
165- assay_count = len ( assay_rows ) ,
157+ studies_by_inv = studies_by_inv ,
158+ assays_by_inv = assays_by_inv ,
159+ contacts_by_inv = contacts_by_inv ,
160+ pubs_by_inv = pubs_by_inv ,
161+ anns_by_inv = anns_by_inv ,
162+ study_count = study_count ,
163+ assay_count = assay_count ,
166164 )
167165
168166
0 commit comments