@@ -961,27 +961,39 @@ async def search_memory(req: SearchRequest, request: Request, user: dict = Depen
961961 all_results : List [SourceRecord ] = []
962962 latency_ms : Dict [str , float ] = {}
963963 plan = pipeline .raw_retrieval_plan (req .domains , answer = req .answer )
964+ raw_tasks = []
964965
965966 if "profile" in plan :
966- results , elapsed = await _timed ("profile" , _search_profile , pipeline , user_id )
967- latency_ms ["profile" ] = elapsed
968- all_results .extend (results )
967+ raw_tasks .append ((
968+ "profile" ,
969+ _timed ("profile" , _search_profile , pipeline , user_id , threaded = True ),
970+ ))
969971 if "temporal" in plan :
970- results , elapsed = await _timed ("temporal" , _search_temporal , pipeline , req .query , user_id , req .top_k )
971- latency_ms ["temporal" ] = elapsed
972- all_results .extend (results )
972+ raw_tasks .append ((
973+ "temporal" ,
974+ _timed ("temporal" , _search_temporal , pipeline , req .query , user_id , req .top_k , threaded = True ),
975+ ))
973976 if "summary" in plan :
974- results , elapsed = await _timed ("summary" , _search_summary , pipeline , req .query , user_id , req .top_k )
975- latency_ms ["summary" ] = elapsed
976- all_results .extend (results )
977+ raw_tasks .append ((
978+ "summary" ,
979+ _timed ("summary" , _search_summary , pipeline , req .query , user_id , req .top_k ),
980+ ))
977981 if "snippet" in plan :
978- results , elapsed = await _timed ("snippet" , _search_snippet , pipeline , req .query , user_id , req .top_k )
979- latency_ms ["snippet" ] = elapsed
980- all_results .extend (results )
982+ raw_tasks .append ((
983+ "snippet" ,
984+ _timed ("snippet" , _search_snippet , pipeline , req .query , user_id , req .top_k ),
985+ ))
981986 if "code" in plan :
982- results , elapsed = await _timed ("code" , _search_code , pipeline , req .query , user_id , req .top_k )
983- latency_ms ["code" ] = elapsed
984- all_results .extend (results )
987+ raw_tasks .append ((
988+ "code" ,
989+ _timed ("code" , _search_code , pipeline , req .query , user_id , req .top_k ),
990+ ))
991+
992+ if raw_tasks :
993+ raw_results = await asyncio .gather (* (task for _ , task in raw_tasks ))
994+ for (domain , _ ), (results , elapsed ) in zip (raw_tasks , raw_results ):
995+ latency_ms [domain ] = elapsed
996+ all_results .extend (results )
985997
986998 all_results .sort (key = lambda record : record .score , reverse = True )
987999
0 commit comments