@@ -163,17 +163,36 @@ def _record(r: EvalResult) -> None:
163163 if checkpoint_path is not None :
164164 append_checkpoint (checkpoint_path , r )
165165
166- if not pending :
167- return results
166+ if pending :
167+ if workers <= 1 or len (pending ) <= 1 :
168+ _run_serial (pending , eval_fn , params , _record )
169+ else :
170+ _run_parallel (pending , eval_fn , params , workers , timeout_per_instance , _record )
168171
169- if workers <= 1 or len (pending ) <= 1 :
170- for inst in pending :
171- try :
172- _record (eval_fn (inst , params ))
173- except Exception as e :
174- _record (_failure_result (inst , params , "error" , f"{ type (e ).__name__ } : { e } " ))
175- return results
172+ return results
173+
174+
175+ def _run_serial (
176+ pending : list [BenchmarkInstance ],
177+ eval_fn : EvalFn ,
178+ params : RunParams ,
179+ record : Callable [[EvalResult ], None ],
180+ ) -> None :
181+ for inst in pending :
182+ try :
183+ record (eval_fn (inst , params ))
184+ except Exception as e :
185+ record (_failure_result (inst , params , "error" , f"{ type (e ).__name__ } : { e } " ))
176186
187+
188+ def _run_parallel (
189+ pending : list [BenchmarkInstance ],
190+ eval_fn : EvalFn ,
191+ params : RunParams ,
192+ workers : int ,
193+ timeout_per_instance : float ,
194+ record : Callable [[EvalResult ], None ],
195+ ) -> None :
177196 from concurrent .futures import (
178197 ThreadPoolExecutor ,
179198 as_completed ,
@@ -184,8 +203,6 @@ def _record(r: EvalResult) -> None:
184203
185204 with ThreadPoolExecutor (max_workers = workers ) as pool :
186205 futures = {pool .submit (eval_fn , inst , params ): inst for inst in pending }
187- # Generous outer deadline: timeout * ceil(len/workers) covers the
188- # serialised case if workers all hang together.
189206 outer_deadline = time .monotonic () + timeout_per_instance * max (1 , (len (pending ) + workers - 1 ) // workers )
190207 completed : set [str ] = set ()
191208 try :
@@ -198,12 +215,9 @@ def _record(r: EvalResult) -> None:
198215 except Exception as e :
199216 r = _failure_result (inst , params , "error" , f"{ type (e ).__name__ } : { e } " )
200217 completed .add (inst .instance_id )
201- _record (r )
218+ record (r )
202219 except FuturesTimeoutError :
203220 for inst in futures .values ():
204221 if inst .instance_id not in completed :
205- _record (_failure_result (inst , params , "timeout" , "exceeded global deadline" ))
206- # Cancel any pending futures; running threads are abandoned.
222+ record (_failure_result (inst , params , "timeout" , "exceeded global deadline" ))
207223 pool .shutdown (wait = False , cancel_futures = True )
208-
209- return results
0 commit comments