@@ -133,6 +133,182 @@ def _make_pool() -> ProcessPoolExecutor:
133133 return out
134134
135135
136+ EvalAllCellsFn = Callable [
137+ [BenchmarkInstance , list [RunParams ]],
138+ list [tuple [RunParams , EvalResult ]],
139+ ]
140+
141+
142+ def evaluate_grid_cached ( # noqa: C901 — pool teardown + per-cell demux + retry-on-BPP do not factor cleanly
143+ spec : GridSpec ,
144+ instances : list [BenchmarkInstance ],
145+ eval_all_cells_fn : EvalAllCellsFn ,
146+ workers : int = 1 ,
147+ on_trial : TrialCallback | None = None ,
148+ timeout_per_instance : float = 300.0 ,
149+ checkpoint_dir : Path | None = None ,
150+ ) -> list [TrialResult ]:
151+ """Inverted-loop calibration: outer = instance, inner = grid cells.
152+
153+ Each ProcessPool task computes the heavy `ScoredState` ONCE per
154+ instance, then runs all (`tau`, `core_budget_fraction`) cells against
155+ it cheaply. Cuts wall time by ~12x for a 12-cell grid because the
156+ expensive parse/fragment/discover/score work is no longer redone per
157+ cell. State never crosses the pickle boundary — only the resulting
158+ `EvalResult` list does — so ProcessPool is preserved and per-process
159+ memory pressure is bounded.
160+
161+ Per-cell checkpoint files (`<params.label()>.jsonl`) match the
162+ layout produced by `evaluate_grid` so the existing aggregator,
163+ `top_k_trials`, and `render_grid_report` work unchanged.
164+ """
165+ import multiprocessing as mp
166+ from concurrent .futures import ProcessPoolExecutor , as_completed
167+ from concurrent .futures .process import BrokenProcessPool
168+
169+ from benchmarks .adapters .runner import (
170+ _load_existing_results ,
171+ append_checkpoint ,
172+ read_checkpoint ,
173+ )
174+
175+ evaluator = UniversalEvaluator ()
176+ points = list (spec .points ())
177+
178+ ckpts : dict [RunParams , Path | None ] = {
179+ p : (checkpoint_dir / f"{ p .label ()} .jsonl" ) if checkpoint_dir is not None else None for p in points
180+ }
181+ done_ids : dict [RunParams , set [str ]] = {p : read_checkpoint (c ) if c is not None else set () for p , c in ckpts .items ()}
182+ results_by_cell : dict [RunParams , list [EvalResult ]] = {
183+ p : (_load_existing_results (c , done_ids [p ]) if c is not None else []) for p , c in ckpts .items ()
184+ }
185+
186+ pending : list [tuple [BenchmarkInstance , list [RunParams ]]] = []
187+ for inst in instances :
188+ needed = [p for p in points if inst .instance_id not in done_ids [p ]]
189+ if needed :
190+ pending .append ((inst , needed ))
191+
192+ def _make_pool () -> ProcessPoolExecutor :
193+ ctx = mp .get_context ("spawn" )
194+ p = ProcessPoolExecutor (max_workers = workers , mp_context = ctx , max_tasks_per_child = 50 )
195+ list (p .map (int , range (workers )))
196+ return p
197+
198+ def _record_per_cell (per_cell_results : list [tuple [RunParams , EvalResult ]]) -> None :
199+ for params , result in per_cell_results :
200+ ckpt = ckpts .get (params )
201+ if ckpt is not None :
202+ err = str ((result .extra or {}).get ("error" , "" ))
203+ if "BrokenProcessPool" not in err :
204+ append_checkpoint (ckpt , result )
205+ results_by_cell [params ].append (result )
206+
207+ def _drain (pool : ProcessPoolExecutor ) -> None :
208+ futures : dict = {}
209+ submit_failed : list [tuple [BenchmarkInstance , list [RunParams ]]] = []
210+ pool_broken = False
211+ for inst , params_list in pending :
212+ try :
213+ futures [pool .submit (eval_all_cells_fn , inst , params_list )] = (inst , params_list )
214+ except BrokenProcessPool :
215+ idx = pending .index ((inst , params_list ))
216+ submit_failed .extend (pending [idx :])
217+ pool_broken = True
218+ break
219+ outer_deadline = __import__ ("time" ).monotonic () + timeout_per_instance * len (points ) * max (
220+ 1 , (len (pending ) + workers - 1 ) // workers
221+ )
222+ completed : set [str ] = set ()
223+ try :
224+ for future in as_completed (
225+ futures ,
226+ timeout = max (0.0 , outer_deadline - __import__ ("time" ).monotonic ()),
227+ ):
228+ inst , params_list = futures [future ]
229+ try :
230+ per_cell = future .result (timeout = 0 )
231+ except BrokenProcessPool :
232+ pool_broken = True
233+ per_cell = [(p , _failure_eval (inst , p , "error" , "BrokenProcessPool: worker died" )) for p in params_list ]
234+ except Exception as e :
235+ per_cell = [(p , _failure_eval (inst , p , "error" , f"{ type (e ).__name__ } : { e } " )) for p in params_list ]
236+ completed .add (inst .instance_id )
237+ _record_per_cell (per_cell )
238+ except BrokenProcessPool :
239+ pool_broken = True
240+ for inst , params_list in submit_failed :
241+ _record_per_cell ([(p , _failure_eval (inst , p , "error" , "BrokenProcessPool: submit failed" )) for p in params_list ])
242+ if pool_broken :
243+ raise BrokenProcessPool ("pool degraded mid-grid" )
244+
245+ pool : ProcessPoolExecutor | None = _make_pool () if workers > 1 else None
246+ try :
247+ if pending and pool is not None :
248+ while True :
249+ try :
250+ _drain (pool )
251+ break
252+ except BrokenProcessPool :
253+ try :
254+ pool .shutdown (wait = False , cancel_futures = True )
255+ except Exception :
256+ pass
257+ pool = _make_pool ()
258+ # Recompute pending for the rebuild from current
259+ # checkpoint state — instances completed since last
260+ # rebuild should be skipped.
261+ done_ids_now = {p : read_checkpoint (c ) if c is not None else set () for p , c in ckpts .items ()}
262+ pending [:] = [
263+ (inst , [p for p in points if inst .instance_id not in done_ids_now [p ]])
264+ for inst , _ in pending
265+ if any (inst .instance_id not in done_ids_now [p ] for p in points )
266+ ]
267+ elif pending and pool is None :
268+ # workers == 1: serial fallback
269+ for inst , params_list in pending :
270+ try :
271+ per_cell = eval_all_cells_fn (inst , params_list )
272+ except Exception as e :
273+ per_cell = [(p , _failure_eval (inst , p , "error" , f"{ type (e ).__name__ } : { e } " )) for p in params_list ]
274+ _record_per_cell (per_cell )
275+ finally :
276+ if pool is not None :
277+ pool .shutdown (wait = False , cancel_futures = True )
278+
279+ out : list [TrialResult ] = []
280+ for i , params in enumerate (points ):
281+ agg = evaluator .aggregate_per_benchmark (results_by_cell [params ])
282+ trial = TrialResult (
283+ params = params ,
284+ per_benchmark = agg ,
285+ raw_results = tuple (results_by_cell [params ]),
286+ )
287+ out .append (trial )
288+ if on_trial is not None :
289+ on_trial (i , len (points ), trial )
290+ return out
291+
292+
293+ def _failure_eval (
294+ instance : BenchmarkInstance ,
295+ params : RunParams ,
296+ status : str ,
297+ error : str ,
298+ ) -> EvalResult :
299+ r = EvalResult (
300+ instance_id = instance .instance_id ,
301+ source_benchmark = instance .source_benchmark ,
302+ file_recall = 0.0 ,
303+ file_precision = 0.0 ,
304+ budget = params .budget ,
305+ )
306+ r .extra ["status" ] = status
307+ r .extra ["error" ] = error
308+ r .extra ["language" ] = instance .language
309+ return r
310+
311+
136312def top_k_trials (trials : Iterable [TrialResult ], k : int = 3 ) -> list [TrialResult ]:
137313 """Pick the k highest-score trials, breaking ties by lower mean tokens."""
138314
0 commit comments