@@ -142,6 +142,7 @@ def main():
142142 ap .add_argument ("--dataset" , default = "Contextbench/ContextBench" )
143143 ap .add_argument ("--split" , default = "contextbench_verified" )
144144 ap .add_argument ("--output" , type = str , default = None )
145+ ap .add_argument ("--workers" , type = int , default = 1 )
145146 args = ap .parse_args ()
146147
147148 from datasets import load_dataset
@@ -162,19 +163,30 @@ def main():
162163 all_results : list [dict ] = []
163164 t0 = time .time ()
164165
165- for i , inst in enumerate (multi_file , 1 ):
166+ def _run_one (idx_inst : tuple [int , dict ]) -> list [dict ]:
167+ i , inst = idx_inst
166168 iid = inst ["instance_id" ]
167169 n_files = len (patch_files (inst ["patch" ]))
168- print (f"[{ i } /{ len (multi_file )} ] { iid } ({ n_files } files)" )
169-
170+ print (f"[{ i } /{ len (multi_file )} ] { iid } ({ n_files } files)" , flush = True )
170171 try :
171172 results = evaluate_loo (inst , args .budget )
172173 hits = sum (1 for r in results if r ["found" ])
173174 total = len (results )
174- print (f" LOO: { hits } /{ total } found ({ 100 * hits / max (1 , total ):.0f} %)" )
175- all_results . extend ( results )
175+ print (f" LOO: { hits } /{ total } found ({ 100 * hits / max (1 , total ):.0f} %)" , flush = True )
176+ return results
176177 except Exception as e :
177- print (f" ERROR: { type (e ).__name__ } : { e } " )
178+ print (f" ERROR: { type (e ).__name__ } : { e } " , flush = True )
179+ return []
180+
# Fan the instances out across workers when requested; otherwise run serially.
if args.workers > 1:
    # _run_one is a closure local to main(). ProcessPoolExecutor has to pickle
    # the callable for every task, and local functions cannot be pickled, so
    # each task would fail with "Can't pickle local object". A thread pool
    # invokes the closure directly, with no pickling step.
    # NOTE(review): assumes evaluate_loo is safe to call concurrently and is
    # dominated by I/O or subprocess work -- confirm. If it is CPU-bound pure
    # Python, hoist the worker to module level and return to a process pool.
    from concurrent.futures import ThreadPoolExecutor

    with ThreadPoolExecutor(max_workers=args.workers) as pool:
        # map() yields results in submission order, so all_results keeps the
        # same ordering as the serial path.
        for results in pool.map(_run_one, enumerate(multi_file, 1)):
            all_results.extend(results)
else:
    for i, inst in enumerate(multi_file, 1):
        all_results.extend(_run_one((i, inst)))
178190
179191 elapsed = time .time () - t0
180192 print ()
0 commit comments