Skip to content

Commit 170a66c

Browse files
committed
feat(benchmarks): add LOO leave-one-out benchmark for SWE-bench
1 parent 3eb7b79 commit 170a66c

1 file changed

Lines changed: 221 additions & 0 deletions

File tree

benchmarks/loo_swebench.py

Lines changed: 221 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,221 @@
1+
#!/usr/bin/env python3
2+
from __future__ import annotations
3+
4+
import argparse
5+
import json
6+
import os
7+
import random
8+
import subprocess
9+
import tempfile
10+
import time
11+
from collections import defaultdict
12+
from pathlib import Path
13+
14+
REPOS_DIR = Path(tempfile.gettempdir()) / "contextbench_repos"
15+
REPOS_DIR.mkdir(exist_ok=True)
16+
17+
18+
def run_cmd(cmd, cwd=None, check=True, timeout=120):
19+
return subprocess.run(cmd, cwd=cwd, text=True, capture_output=True, check=check, timeout=timeout)
20+
21+
22+
def patch_files(patch: str) -> set[str]:
23+
files: set[str] = set()
24+
for line in patch.splitlines():
25+
if line.startswith("+++ "):
26+
p = line[4:]
27+
if p != "/dev/null":
28+
files.add(p[2:] if p.startswith("b/") else p)
29+
elif line.startswith("--- "):
30+
p = line[4:]
31+
if p != "/dev/null":
32+
files.add(p[2:] if p.startswith("a/") else p)
33+
return files
34+
35+
36+
def strip_file_from_patch(patch_text: str, file_to_hide: str) -> str:
37+
lines = patch_text.split("\n")
38+
result = []
39+
skip = False
40+
for line in lines:
41+
if line.startswith("diff --git "):
42+
skip = f"a/{file_to_hide}" in line or f"b/{file_to_hide}" in line
43+
if not skip:
44+
result.append(line)
45+
return "\n".join(result)
46+
47+
48+
def ensure_repo(repo_url: str, repo_name: str, base_commit: str) -> Path | None:
49+
repo_dir = REPOS_DIR / repo_name.replace("/", "__")
50+
if not repo_dir.exists():
51+
r = run_cmd(["git", "clone", "--quiet", repo_url, str(repo_dir)], check=False, timeout=600)
52+
if r.returncode != 0:
53+
print(f" CLONE FAIL: {r.stderr[:200]}")
54+
return None
55+
r = run_cmd(["git", "-C", str(repo_dir), "checkout", "--force", base_commit], check=False)
56+
if r.returncode != 0:
57+
run_cmd(["git", "-C", str(repo_dir), "fetch", "--all", "--quiet"], check=False, timeout=600)
58+
r = run_cmd(["git", "-C", str(repo_dir), "checkout", "--force", base_commit], check=False)
59+
if r.returncode != 0:
60+
print(f" CHECKOUT FAIL: {r.stderr[:200]}")
61+
return None
62+
return repo_dir
63+
64+
65+
def apply_partial_patch(repo_dir: Path, partial_patch: str) -> bool:
66+
with tempfile.NamedTemporaryFile(mode="w", suffix=".patch", delete=False) as f:
67+
f.write(partial_patch)
68+
patch_path = f.name
69+
try:
70+
r = run_cmd(["git", "-C", str(repo_dir), "apply", "--index", patch_path], check=False)
71+
if r.returncode != 0:
72+
r = run_cmd(["git", "-C", str(repo_dir), "apply", "--index", "--3way", patch_path], check=False)
73+
if r.returncode != 0:
74+
return False
75+
run_cmd(["git", "-C", str(repo_dir), "commit", "-m", "partial", "--allow-empty", "--no-verify"], check=False)
76+
return True
77+
finally:
78+
os.unlink(patch_path)
79+
80+
81+
def run_diffctx(repo_dir: Path, budget: int) -> set[str]:
82+
from treemapper.diffctx.pipeline import build_diff_context
83+
84+
try:
85+
output = build_diff_context(repo_dir, "HEAD~1..HEAD", budget_tokens=budget)
86+
return {f["path"] for f in output.get("fragments", [])}
87+
except Exception:
88+
return set()
89+
90+
91+
def evaluate_loo(inst: dict, budget: int) -> list[dict]:
92+
iid = inst["instance_id"]
93+
all_patch_files = patch_files(inst["patch"])
94+
95+
if len(all_patch_files) < 2:
96+
return []
97+
98+
repo_url = inst.get("repo_url") or f"https://github.com/{inst['repo']}.git"
99+
repo_dir = ensure_repo(repo_url, inst["repo"], inst["base_commit"])
100+
if not repo_dir:
101+
return []
102+
103+
results = []
104+
for hidden in sorted(all_patch_files):
105+
partial = strip_file_from_patch(inst["patch"], hidden)
106+
remaining_files = patch_files(partial)
107+
if not remaining_files:
108+
continue
109+
110+
run_cmd(["git", "-C", str(repo_dir), "checkout", "--force", inst["base_commit"]], check=False)
111+
run_cmd(["git", "-C", str(repo_dir), "clean", "-fd"], check=False)
112+
113+
if not apply_partial_patch(repo_dir, partial):
114+
continue
115+
116+
selected = run_diffctx(repo_dir, budget)
117+
found = hidden in selected
118+
119+
results.append(
120+
{
121+
"instance_id": iid,
122+
"hidden_file": hidden,
123+
"found": found,
124+
"n_patch_files": len(all_patch_files),
125+
"n_remaining": len(remaining_files),
126+
"n_selected": len(selected),
127+
"language": inst.get("language", "unknown"),
128+
"repo": inst["repo"],
129+
}
130+
)
131+
132+
run_cmd(["git", "-C", str(repo_dir), "checkout", "--force", inst["base_commit"]], check=False)
133+
run_cmd(["git", "-C", str(repo_dir), "clean", "-fd"], check=False)
134+
return results
135+
136+
137+
def main():
138+
ap = argparse.ArgumentParser()
139+
ap.add_argument("--limit", type=int, default=50)
140+
ap.add_argument("--budget", type=int, default=8000)
141+
ap.add_argument("--seed", type=int, default=42)
142+
ap.add_argument("--dataset", default="Contextbench/ContextBench")
143+
ap.add_argument("--split", default="contextbench_verified")
144+
ap.add_argument("--output", type=str, default=None)
145+
args = ap.parse_args()
146+
147+
from datasets import load_dataset
148+
149+
ds = load_dataset(args.dataset, args.split, split="train")
150+
insts = list(ds)
151+
152+
multi_file = [i for i in insts if len(patch_files(i["patch"])) >= 2]
153+
print(f"Total instances: {len(insts)}, multi-file: {len(multi_file)}")
154+
155+
rng = random.Random(args.seed)
156+
rng.shuffle(multi_file)
157+
multi_file = multi_file[: args.limit]
158+
159+
print(f"Evaluating LOO on {len(multi_file)} instances (budget={args.budget})")
160+
print()
161+
162+
all_results: list[dict] = []
163+
t0 = time.time()
164+
165+
for i, inst in enumerate(multi_file, 1):
166+
iid = inst["instance_id"]
167+
n_files = len(patch_files(inst["patch"]))
168+
print(f"[{i}/{len(multi_file)}] {iid} ({n_files} files)")
169+
170+
try:
171+
results = evaluate_loo(inst, args.budget)
172+
hits = sum(1 for r in results if r["found"])
173+
total = len(results)
174+
print(f" LOO: {hits}/{total} found ({100 * hits / max(1, total):.0f}%)")
175+
all_results.extend(results)
176+
except Exception as e:
177+
print(f" ERROR: {type(e).__name__}: {e}")
178+
179+
elapsed = time.time() - t0
180+
print()
181+
print("=" * 70)
182+
print(f"LOO RESULTS ({elapsed:.0f}s)")
183+
print("=" * 70)
184+
185+
if not all_results:
186+
print("No results.")
187+
return
188+
189+
total = len(all_results)
190+
found = sum(1 for r in all_results if r["found"])
191+
print(f"Total LOO trials: {total}")
192+
print(f"Found hidden file: {found}/{total} ({100 * found / total:.1f}%)")
193+
print()
194+
195+
by_repo: dict[str, list[dict]] = defaultdict(list)
196+
for r in all_results:
197+
by_repo[r["repo"]].append(r)
198+
199+
print("Per-repo breakdown:")
200+
for repo in sorted(by_repo, key=lambda r: len(by_repo[r]), reverse=True):
201+
trials = by_repo[repo]
202+
h = sum(1 for t in trials if t["found"])
203+
print(f" {repo:40s} {h}/{len(trials):3d} ({100 * h / len(trials):.0f}%)")
204+
205+
by_lang: dict[str, list[dict]] = defaultdict(list)
206+
for r in all_results:
207+
by_lang[r["language"]].append(r)
208+
209+
print("\nPer-language breakdown:")
210+
for lang in sorted(by_lang, key=lambda la: len(by_lang[la]), reverse=True):
211+
trials = by_lang[lang]
212+
h = sum(1 for t in trials if t["found"])
213+
print(f" {lang:20s} {h}/{len(trials):3d} ({100 * h / len(trials):.0f}%)")
214+
215+
if args.output:
216+
Path(args.output).write_text(json.dumps(all_results, indent=2))
217+
print(f"\nResults saved to {args.output}")
218+
219+
220+
if __name__ == "__main__":
221+
main()

0 commit comments

Comments
 (0)