Skip to content

Commit 47abf3e

Browse files
committed
perf(bench): split pipeline — compute_scored_state + select_with_params (12x)
1 parent 02fda1c commit 47abf3e

8 files changed

Lines changed: 792 additions & 138 deletions

File tree

benchmarks/adapters/calibrate.py

Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,182 @@ def _make_pool() -> ProcessPoolExecutor:
133133
return out
134134

135135

136+
EvalAllCellsFn = Callable[
137+
[BenchmarkInstance, list[RunParams]],
138+
list[tuple[RunParams, EvalResult]],
139+
]
140+
141+
142+
def evaluate_grid_cached( # noqa: C901 — pool teardown + per-cell demux + retry-on-BPP do not factor cleanly
143+
spec: GridSpec,
144+
instances: list[BenchmarkInstance],
145+
eval_all_cells_fn: EvalAllCellsFn,
146+
workers: int = 1,
147+
on_trial: TrialCallback | None = None,
148+
timeout_per_instance: float = 300.0,
149+
checkpoint_dir: Path | None = None,
150+
) -> list[TrialResult]:
151+
"""Inverted-loop calibration: outer = instance, inner = grid cells.
152+
153+
Each ProcessPool task computes the heavy `ScoredState` ONCE per
154+
instance, then runs all (`tau`, `core_budget_fraction`) cells against
155+
it cheaply. Cuts wall time by ~12x for a 12-cell grid because the
156+
expensive parse/fragment/discover/score work is no longer redone per
157+
cell. State never crosses the pickle boundary — only the resulting
158+
`EvalResult` list does — so ProcessPool is preserved and per-process
159+
memory pressure is bounded.
160+
161+
Per-cell checkpoint files (`<params.label()>.jsonl`) match the
162+
layout produced by `evaluate_grid` so the existing aggregator,
163+
`top_k_trials`, and `render_grid_report` work unchanged.
164+
"""
165+
import multiprocessing as mp
166+
from concurrent.futures import ProcessPoolExecutor, as_completed
167+
from concurrent.futures.process import BrokenProcessPool
168+
169+
from benchmarks.adapters.runner import (
170+
_load_existing_results,
171+
append_checkpoint,
172+
read_checkpoint,
173+
)
174+
175+
evaluator = UniversalEvaluator()
176+
points = list(spec.points())
177+
178+
ckpts: dict[RunParams, Path | None] = {
179+
p: (checkpoint_dir / f"{p.label()}.jsonl") if checkpoint_dir is not None else None for p in points
180+
}
181+
done_ids: dict[RunParams, set[str]] = {p: read_checkpoint(c) if c is not None else set() for p, c in ckpts.items()}
182+
results_by_cell: dict[RunParams, list[EvalResult]] = {
183+
p: (_load_existing_results(c, done_ids[p]) if c is not None else []) for p, c in ckpts.items()
184+
}
185+
186+
pending: list[tuple[BenchmarkInstance, list[RunParams]]] = []
187+
for inst in instances:
188+
needed = [p for p in points if inst.instance_id not in done_ids[p]]
189+
if needed:
190+
pending.append((inst, needed))
191+
192+
def _make_pool() -> ProcessPoolExecutor:
193+
ctx = mp.get_context("spawn")
194+
p = ProcessPoolExecutor(max_workers=workers, mp_context=ctx, max_tasks_per_child=50)
195+
list(p.map(int, range(workers)))
196+
return p
197+
198+
def _record_per_cell(per_cell_results: list[tuple[RunParams, EvalResult]]) -> None:
199+
for params, result in per_cell_results:
200+
ckpt = ckpts.get(params)
201+
if ckpt is not None:
202+
err = str((result.extra or {}).get("error", ""))
203+
if "BrokenProcessPool" not in err:
204+
append_checkpoint(ckpt, result)
205+
results_by_cell[params].append(result)
206+
207+
def _drain(pool: ProcessPoolExecutor) -> None:
208+
futures: dict = {}
209+
submit_failed: list[tuple[BenchmarkInstance, list[RunParams]]] = []
210+
pool_broken = False
211+
for inst, params_list in pending:
212+
try:
213+
futures[pool.submit(eval_all_cells_fn, inst, params_list)] = (inst, params_list)
214+
except BrokenProcessPool:
215+
idx = pending.index((inst, params_list))
216+
submit_failed.extend(pending[idx:])
217+
pool_broken = True
218+
break
219+
outer_deadline = __import__("time").monotonic() + timeout_per_instance * len(points) * max(
220+
1, (len(pending) + workers - 1) // workers
221+
)
222+
completed: set[str] = set()
223+
try:
224+
for future in as_completed(
225+
futures,
226+
timeout=max(0.0, outer_deadline - __import__("time").monotonic()),
227+
):
228+
inst, params_list = futures[future]
229+
try:
230+
per_cell = future.result(timeout=0)
231+
except BrokenProcessPool:
232+
pool_broken = True
233+
per_cell = [(p, _failure_eval(inst, p, "error", "BrokenProcessPool: worker died")) for p in params_list]
234+
except Exception as e:
235+
per_cell = [(p, _failure_eval(inst, p, "error", f"{type(e).__name__}: {e}")) for p in params_list]
236+
completed.add(inst.instance_id)
237+
_record_per_cell(per_cell)
238+
except BrokenProcessPool:
239+
pool_broken = True
240+
for inst, params_list in submit_failed:
241+
_record_per_cell([(p, _failure_eval(inst, p, "error", "BrokenProcessPool: submit failed")) for p in params_list])
242+
if pool_broken:
243+
raise BrokenProcessPool("pool degraded mid-grid")
244+
245+
pool: ProcessPoolExecutor | None = _make_pool() if workers > 1 else None
246+
try:
247+
if pending and pool is not None:
248+
while True:
249+
try:
250+
_drain(pool)
251+
break
252+
except BrokenProcessPool:
253+
try:
254+
pool.shutdown(wait=False, cancel_futures=True)
255+
except Exception:
256+
pass
257+
pool = _make_pool()
258+
# Recompute pending for the rebuild from current
259+
# checkpoint state — instances completed since last
260+
# rebuild should be skipped.
261+
done_ids_now = {p: read_checkpoint(c) if c is not None else set() for p, c in ckpts.items()}
262+
pending[:] = [
263+
(inst, [p for p in points if inst.instance_id not in done_ids_now[p]])
264+
for inst, _ in pending
265+
if any(inst.instance_id not in done_ids_now[p] for p in points)
266+
]
267+
elif pending and pool is None:
268+
# workers == 1: serial fallback
269+
for inst, params_list in pending:
270+
try:
271+
per_cell = eval_all_cells_fn(inst, params_list)
272+
except Exception as e:
273+
per_cell = [(p, _failure_eval(inst, p, "error", f"{type(e).__name__}: {e}")) for p in params_list]
274+
_record_per_cell(per_cell)
275+
finally:
276+
if pool is not None:
277+
pool.shutdown(wait=False, cancel_futures=True)
278+
279+
out: list[TrialResult] = []
280+
for i, params in enumerate(points):
281+
agg = evaluator.aggregate_per_benchmark(results_by_cell[params])
282+
trial = TrialResult(
283+
params=params,
284+
per_benchmark=agg,
285+
raw_results=tuple(results_by_cell[params]),
286+
)
287+
out.append(trial)
288+
if on_trial is not None:
289+
on_trial(i, len(points), trial)
290+
return out
291+
292+
293+
def _failure_eval(
294+
instance: BenchmarkInstance,
295+
params: RunParams,
296+
status: str,
297+
error: str,
298+
) -> EvalResult:
299+
r = EvalResult(
300+
instance_id=instance.instance_id,
301+
source_benchmark=instance.source_benchmark,
302+
file_recall=0.0,
303+
file_precision=0.0,
304+
budget=params.budget,
305+
)
306+
r.extra["status"] = status
307+
r.extra["error"] = error
308+
r.extra["language"] = instance.language
309+
return r
310+
311+
136312
def top_k_trials(trials: Iterable[TrialResult], k: int = 3) -> list[TrialResult]:
137313
"""Pick the k highest-score trials, breaking ties by lower mean tokens."""
138314

benchmarks/calibrate.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,17 @@
1818
from dataclasses import asdict
1919
from pathlib import Path
2020

21-
from benchmarks.adapters.calibrate import GridSpec, evaluate_grid, render_grid_report, top_k_trials
21+
from benchmarks.adapters.calibrate import (
22+
GridSpec,
23+
evaluate_grid_cached,
24+
render_grid_report,
25+
top_k_trials,
26+
)
2227
from benchmarks.adapters.runner import filter_instances_by_manifest, read_manifest
2328
from benchmarks.adapters.runtime_probe import probe_resources, report_and_maybe_exit
2429
from benchmarks.build_splits import default_calibration_pool_adapters, default_test_adapters
2530
from benchmarks.common import repos_dir as default_repos_dir
26-
from benchmarks.diffctx_eval_fn import make_diffctx_eval_fn
31+
from benchmarks.diffctx_eval_fn import make_diffctx_eval_all_cells_fn
2732

2833

2934
def _parse_floats(s: str) -> tuple[float, ...]:
@@ -108,7 +113,7 @@ def main() -> int:
108113

109114
_prewarm_bare_clones(instances)
110115

111-
eval_fn = make_diffctx_eval_fn(repo_root)
116+
eval_all_cells_fn = make_diffctx_eval_all_cells_fn(repo_root)
112117
args.out.mkdir(parents=True, exist_ok=True)
113118
checkpoint_dir = args.out / "checkpoints"
114119

@@ -119,10 +124,10 @@ def _on_trial(idx: int, total: int, trial) -> None:
119124
f"min(per_benchmark file_recall) = {trial.score:.4f}"
120125
)
121126

122-
trials = evaluate_grid(
127+
trials = evaluate_grid_cached(
123128
spec,
124129
instances,
125-
eval_fn,
130+
eval_all_cells_fn,
126131
workers=args.workers,
127132
on_trial=_on_trial,
128133
timeout_per_instance=args.timeout_per_instance,

benchmarks/diffctx_eval_fn.py

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,3 +149,151 @@ def _pool_eval(repos_dir_str: str, instance: BenchmarkInstance, params: RunParam
149149

150150
def make_diffctx_eval_fn(repos_dir: Path):
151151
return functools.partial(_pool_eval, str(repos_dir))
152+
153+
154+
def _build_eval_result_from_output(
155+
output: dict,
156+
instance: BenchmarkInstance,
157+
params: RunParams,
158+
elapsed: float,
159+
evaluator: UniversalEvaluator,
160+
) -> EvalResult:
161+
if output is None:
162+
result = EvalResult(
163+
instance_id=instance.instance_id,
164+
source_benchmark=instance.source_benchmark,
165+
file_recall=0.0,
166+
file_precision=0.0,
167+
budget=params.budget,
168+
elapsed_seconds=elapsed,
169+
)
170+
result.extra["status"] = "diffctx_fail"
171+
return result
172+
fragments = _output_fragments(output)
173+
used_tokens = _compute_used_tokens(output)
174+
selection = SelectionOutput(
175+
selected_files=_selected_files(fragments),
176+
selected_fragments=fragments,
177+
used_tokens=used_tokens,
178+
elapsed_seconds=elapsed,
179+
)
180+
result = evaluator.evaluate(instance, selection, budget=params.budget)
181+
result.used_tokens = used_tokens
182+
result.extra["status"] = "ok"
183+
result.extra["language"] = instance.language
184+
result.extra["fragment_count"] = len(fragments)
185+
latency = output.get("latency") or {}
186+
if latency:
187+
result.extra["latency_total_ms"] = latency.get("total_ms")
188+
result.extra["latency_breakdown"] = {k: v for k, v in latency.items() if k != "total_ms"}
189+
return result
190+
191+
192+
def pool_eval_all_cells(
193+
repos_dir_str: str,
194+
instance: BenchmarkInstance,
195+
params_list: list[RunParams],
196+
) -> list[tuple[RunParams, EvalResult]]:
197+
"""Compute the heavy `ScoredState` ONCE for the instance, then run
198+
every (`tau`, `core_budget_fraction`) cell against it cheaply.
199+
200+
Returns one (params, result) tuple per input params. The orchestrator
201+
demuxes these into per-cell checkpoints. This is the ProcessPool
202+
worker entry point — the entire ScoredState lives only inside this
203+
process and is dropped before return; only EvalResults cross the
204+
pickle boundary.
205+
"""
206+
from benchmarks.common import apply_as_commit, ensure_repo, reset_to_parent
207+
from treemapper.diffctx.pipeline import compute_scored_state, select_with_params
208+
209+
if not params_list:
210+
return []
211+
212+
worktree_dir, evaluator = _ensure_worker_state(repos_dir_str)
213+
214+
repo_url = str(instance.extra.get("repo_url") or f"https://github.com/{instance.repo}")
215+
repo_dir = ensure_repo(repo_url, instance.repo, instance.base_commit, worktree_dir)
216+
if repo_dir is None:
217+
return [
218+
(
219+
p,
220+
_failure_result(instance, p, "clone_fail", "ensure_repo returned None"),
221+
)
222+
for p in params_list
223+
]
224+
225+
# All params in a sweep share scoring_mode (BM25/PPR/Ego/Hybrid is a
226+
# discovery-and-scoring choice, not a (τ, cbf) one). Use the first.
227+
scoring_mode = params_list[0].scoring
228+
229+
out: list[tuple[RunParams, EvalResult]] = []
230+
try:
231+
apply_as_commit(repo_dir, instance.gold_patch, "diffctx-eval-gold")
232+
233+
t_heavy_start = time.perf_counter()
234+
try:
235+
state = compute_scored_state(
236+
repo_dir,
237+
"HEAD~1..HEAD",
238+
scoring_mode=scoring_mode,
239+
)
240+
except Exception as e:
241+
err = f"{type(e).__name__}: {e}"
242+
return [(p, _failure_result(instance, p, "diffctx_fail", err)) for p in params_list]
243+
heavy_elapsed = time.perf_counter() - t_heavy_start
244+
245+
for params in params_list:
246+
prior_env = {k: os.environ.get(k) for k in params.to_env()}
247+
try:
248+
for k, v in params.to_env().items():
249+
os.environ[k] = v
250+
t_select_start = time.perf_counter()
251+
output = select_with_params(
252+
state,
253+
budget_tokens=params.budget,
254+
tau=params.tau,
255+
)
256+
select_elapsed = time.perf_counter() - t_select_start
257+
# Charge the heavy cost to the first cell only — subsequent
258+
# cells reuse the cached state, so they only pay select cost.
259+
charged = heavy_elapsed + select_elapsed if not out else select_elapsed
260+
result = _build_eval_result_from_output(output, instance, params, charged, evaluator)
261+
out.append((params, result))
262+
finally:
263+
for k, v in prior_env.items():
264+
if v is None:
265+
os.environ.pop(k, None)
266+
else:
267+
os.environ[k] = v
268+
finally:
269+
try:
270+
reset_to_parent(repo_dir)
271+
except Exception:
272+
pass
273+
274+
return out
275+
276+
277+
def _failure_result(
278+
instance: BenchmarkInstance,
279+
params: RunParams,
280+
status: str,
281+
error: str,
282+
) -> EvalResult:
283+
r = EvalResult(
284+
instance_id=instance.instance_id,
285+
source_benchmark=instance.source_benchmark,
286+
file_recall=0.0,
287+
file_precision=0.0,
288+
budget=params.budget,
289+
)
290+
r.extra["status"] = status
291+
r.extra["error"] = error
292+
r.extra["language"] = instance.language
293+
return r
294+
295+
296+
def make_diffctx_eval_all_cells_fn(repos_dir: Path):
297+
"""Sibling of `make_diffctx_eval_fn` for the inverted orchestrator
298+
(one task = one instance × N cells)."""
299+
return functools.partial(pool_eval_all_cells, str(repos_dir))
Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,7 @@
11
from ._diffctx import * # noqa: F403
2-
from ._diffctx import GitError # noqa: F401 # exception class is not picked up by `import *`
2+
from ._diffctx import ( # noqa: F401 # not picked up by `import *`
3+
GitError,
4+
PyScoredState,
5+
compute_scored_state,
6+
select_with_params,
7+
)

0 commit comments

Comments
 (0)