Skip to content

Commit 448b87f

Browse files
committed
ci: 16-core runner + parallel LOO workers (8x speedup)
1 parent 8f9c8f4 commit 448b87f

2 files changed

Lines changed: 20 additions & 6 deletions

File tree

.github/workflows/benchmark.yml

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -76,6 +76,7 @@ jobs:
7676
python benchmarks/loo_swebench.py \
7777
--limit ${{ inputs.instances }} \
7878
--budget ${{ inputs.budget }} \
79+
--workers 8 \
7980
--output results/loo_ppr.json \
8081
2>&1 | tee results/loo_ppr.txt
8182
@@ -85,6 +86,7 @@ jobs:
8586
python benchmarks/loo_swebench.py \
8687
--limit ${{ inputs.instances }} \
8788
--budget ${{ inputs.budget }} \
89+
--workers 8 \
8890
--output results/loo_ego.json \
8991
2>&1 | tee results/loo_ego.txt
9092
env:

benchmarks/loo_swebench.py

Lines changed: 18 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -142,6 +142,7 @@ def main():
142142
ap.add_argument("--dataset", default="Contextbench/ContextBench")
143143
ap.add_argument("--split", default="contextbench_verified")
144144
ap.add_argument("--output", type=str, default=None)
145+
ap.add_argument("--workers", type=int, default=1)
145146
args = ap.parse_args()
146147

147148
from datasets import load_dataset
@@ -162,19 +163,30 @@ def main():
162163
all_results: list[dict] = []
163164
t0 = time.time()
164165

165-
for i, inst in enumerate(multi_file, 1):
166+
def _run_one(idx_inst: tuple[int, dict]) -> list[dict]:
167+
i, inst = idx_inst
166168
iid = inst["instance_id"]
167169
n_files = len(patch_files(inst["patch"]))
168-
print(f"[{i}/{len(multi_file)}] {iid} ({n_files} files)")
169-
170+
print(f"[{i}/{len(multi_file)}] {iid} ({n_files} files)", flush=True)
170171
try:
171172
results = evaluate_loo(inst, args.budget)
172173
hits = sum(1 for r in results if r["found"])
173174
total = len(results)
174-
print(f" LOO: {hits}/{total} found ({100 * hits / max(1, total):.0f}%)")
175-
all_results.extend(results)
175+
print(f" LOO: {hits}/{total} found ({100 * hits / max(1, total):.0f}%)", flush=True)
176+
return results
176177
except Exception as e:
177-
print(f" ERROR: {type(e).__name__}: {e}")
178+
print(f" ERROR: {type(e).__name__}: {e}", flush=True)
179+
return []
180+
181+
if args.workers > 1:
182+
from concurrent.futures import ProcessPoolExecutor
183+
184+
with ProcessPoolExecutor(max_workers=args.workers) as pool:
185+
for results in pool.map(_run_one, enumerate(multi_file, 1)):
186+
all_results.extend(results)
187+
else:
188+
for i, inst in enumerate(multi_file, 1):
189+
all_results.extend(_run_one((i, inst)))
178190

179191
elapsed = time.time() - t0
180192
print()

0 commit comments

Comments
 (0)