Skip to content

Commit de2ba0b

Browse files
committed
fix: reduce cognitive complexity in benchmarks and scripts (SonarCloud S3776)
1 parent d8fa233 commit de2ba0b

6 files changed

Lines changed: 197 additions & 144 deletions

File tree

QA.md

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55
- CI matrix: Linux/macOS/Windows × Python 3.10-3.14
66
(15 test matrices + 3 lint/arch jobs)
77
- Windows jobs slowest — typically 7-10min vs 2-4min for Linux/macOS
8-
- Windows runners occasionally hang indefinitely (>1hr) —
9-
cancel and rerun with `gh run rerun <id> --failed`
108
- All jobs must pass; no flaky CI tolerance
119

1210
## SonarCloud
@@ -20,18 +18,23 @@
2018
merge with `or` when bodies identical
2119
- Test YAML fixture "Password" triggers false positive VULNERABILITY
2220
(yaml:S2068) — expected
23-
- `dataclasses.replace()` return type: mypy may flag as OK or error
24-
depending on version — check before adding `# type: ignore`
25-
- Lambda capturing loop variable in immediate-use context (e.g. `max()`)
26-
— SonarCloud flags it; suppress with `# noqa: B023` if mypy rejects
27-
the default-arg workaround
21+
- `dataclasses.replace()` return type: mypy (modern) infers correctly —
22+
do NOT add `cast(T, replace(...))`, mypy will flag as redundant-cast;
23+
remove cast and let type inference work
24+
- Lambda capturing loop variable (S1515): use `dict.__getitem__` instead
25+
of `lambda k: d[k]` — simpler and avoids the flag
26+
- S3776 cognitive complexity: SonarCloud counts boolean operators (`and`,
27+
`or`) as separate increments — extracting complex conditions into named
28+
helpers reduces complexity even without deep nesting changes
29+
- `cast` import removal: after removing casts, remove the `typing.cast`
30+
import too or ruff/mypy will flag unused imports
2831

2932
## Test Suite
3033

3134
- Run `python -m pytest --tb=no -q` for quick status
3235
- test_graph.py separate from test_yaml_diff.py — check both
33-
- 87 xfails currently — all strict=False, bidirectional discovery
34-
precision tradeoff
36+
- Many xfails (strict=False) for bidirectional discovery precision tradeoff —
37+
check count trends, not absolute numbers
3538

3639
## Code Review
3740

benchmarks/contextbench_diffctx.py

Lines changed: 51 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,32 @@ def line_overlap(
216216
}
217217

218218

219+
def _collect_instance_diagnostics(
    frag_count: int,
    lo_all: dict,
    file_recall: float,
    gf: set,
    sel_files: set,
    nontrivial_recall: float,
    output: dict,
) -> list[str]:
    """Build the WARN/DIAG messages for a single evaluated instance.

    Args:
        frag_count: Number of fragments reported for the instance.
        lo_all: Mapping that must contain a ``"line_recall"`` entry.
        file_recall: File-level recall for the instance.
        gf: Files listed under the ``gold_files`` label when file recall is zero.
        sel_files: Files listed under the ``selected`` label in the same case.
        nontrivial_recall: Recall used for the "patch-adjacent" diagnostic.
        output: Mapping whose optional ``"fragments"`` list is scanned for
            entries whose ``"lines"`` value ``parse_lines_field`` maps to ``None``.

    Returns:
        The messages, in the order the checks are performed; empty when
        nothing looks suspicious.
    """
    messages: list[str] = []
    has_fragments = frag_count > 0

    if frag_count == 0:
        messages.append("WARN: diffctx returned 0 fragments")

    if lo_all["line_recall"] < 1e-9 and has_fragments:
        messages.append("DIAG: line_recall=0 with fragments>0 — possible line parse bug or no file overlap")

    if file_recall < 1e-9 and has_fragments:
        # Show at most five files from each side to keep the report short.
        messages.extend(
            [
                "DIAG: file_recall=0 with fragments>0 — selected files don't overlap gold at all",
                f" gold_files: {sorted(gf)[:5]}",
                f" selected: {sorted(sel_files)[:5]}",
            ]
        )

    if nontrivial_recall < 1e-9 and frag_count > 5:
        messages.append("DIAG: nontrivial_recall=0 — diffctx may only be selecting patch-adjacent files")

    unparsed = len(
        [frag for frag in output.get("fragments", []) if parse_lines_field(frag.get("lines", "")) is None]
    )
    if unparsed:
        messages.append(f"DIAG: {unparsed}/{frag_count} fragments have unparseable 'lines' field")

    return messages
243+
244+
219245
def evaluate_instance(
220246
inst: dict,
221247
budget: int = 8000,
@@ -301,26 +327,7 @@ def evaluate_instance(
301327
},
302328
}
303329

304-
diagnostics = []
305-
306-
if frag_count == 0:
307-
diagnostics.append("WARN: diffctx returned 0 fragments")
308-
309-
if lo_all["line_recall"] < 1e-9 and frag_count > 0:
310-
diagnostics.append("DIAG: line_recall=0 with fragments>0 — possible line parse bug or no file overlap")
311-
312-
if file_recall < 1e-9 and frag_count > 0:
313-
diagnostics.append("DIAG: file_recall=0 with fragments>0 — selected files don't overlap gold at all")
314-
diagnostics.append(f" gold_files: {sorted(gf)[:5]}")
315-
diagnostics.append(f" selected: {sorted(sel_files)[:5]}")
316-
317-
if nontrivial_recall < 1e-9 and frag_count > 5:
318-
diagnostics.append("DIAG: nontrivial_recall=0 — diffctx may only be selecting patch-adjacent files")
319-
320-
unparsed = sum(1 for f in output.get("fragments", []) if parse_lines_field(f.get("lines", "")) is None)
321-
if unparsed:
322-
diagnostics.append(f"DIAG: {unparsed}/{frag_count} fragments have unparseable 'lines' field")
323-
330+
diagnostics = _collect_instance_diagnostics(frag_count, lo_all, file_recall, gf, sel_files, nontrivial_recall, output)
324331
result["diagnostics"] = diagnostics
325332

326333
print(f"Fragments: {frag_count} | Time: {elapsed:.1f}s")
@@ -333,6 +340,28 @@ def evaluate_instance(
333340
return result
334341

335342

343+
def _print_per_language_breakdown(ok: list[dict], by_lang: dict) -> None:
    """Print mean recall metrics for each language, ordered by language name.

    Nothing is printed when ``by_lang`` holds one language or fewer.

    Args:
        ok: Accepted for the caller's convenience; not read by this function.
        by_lang: Maps a language name to its non-empty list of result dicts,
            each carrying ``file_recall``, ``nontrivial_file_recall`` and
            ``line_recall``.
    """
    if len(by_lang) <= 1:
        return

    def mean(rows: list[dict], metric: str) -> float:
        return sum(row[metric] for row in rows) / len(rows)

    print("\nPer-language breakdown:")
    for lang in sorted(by_lang):
        rows = by_lang[lang]
        file_avg = mean(rows, "file_recall")
        nontrivial_avg = mean(rows, "nontrivial_file_recall")
        line_avg = mean(rows, "line_recall")
        print(
            f" {lang:12s} (n={len(rows):3d}): file_recall={file_avg:.3f} "
            f"nontrivial={nontrivial_avg:.3f} line_recall={line_avg:.3f}"
        )
353+
354+
355+
def _print_per_repo_breakdown(ok: list[dict], by_repo: dict) -> None:
    """Print mean ``nontrivial_file_recall`` per repo, largest groups first.

    Nothing is printed when ``by_repo`` holds one repo or fewer.

    Args:
        ok: Accepted for the caller's convenience; not read by this function.
        by_repo: Maps a repo name to its non-empty list of result dicts.
    """
    if len(by_repo) <= 1:
        return

    print("\nPer-repo breakdown:")
    # Stable sort: repos with the same number of results keep mapping order.
    ranked = sorted(by_repo.items(), key=lambda item: -len(item[1]))
    for repo, rows in ranked:
        count = len(rows)
        nontrivial_avg = sum(row["nontrivial_file_recall"] for row in rows) / count
        print(f" {repo:30s} (n={count:3d}): nontrivial_recall={nontrivial_avg:.3f}")
363+
364+
336365
def aggregate(results: list[dict]) -> None:
337366
ok = [r for r in results if r["status"] == "ok"]
338367
if not ok:
@@ -360,26 +389,12 @@ def aggregate(results: list[dict]) -> None:
360389
by_lang: dict[str, list[dict]] = defaultdict(list)
361390
for r in ok:
362391
by_lang[r["language"]].append(r)
363-
364-
if len(by_lang) > 1:
365-
print("\nPer-language breakdown:")
366-
for lang in sorted(by_lang):
367-
lr = by_lang[lang]
368-
avg_fr = sum(r["file_recall"] for r in lr) / len(lr)
369-
avg_ntr = sum(r["nontrivial_file_recall"] for r in lr) / len(lr)
370-
avg_lr = sum(r["line_recall"] for r in lr) / len(lr)
371-
print(f" {lang:12s} (n={len(lr):3d}): file_recall={avg_fr:.3f} nontrivial={avg_ntr:.3f} line_recall={avg_lr:.3f}")
392+
_print_per_language_breakdown(ok, by_lang)
372393

373394
by_repo: dict[str, list[dict]] = defaultdict(list)
374395
for r in ok:
375396
by_repo[r["repo"]].append(r)
376-
377-
if len(by_repo) > 1:
378-
print("\nPer-repo breakdown:")
379-
for repo in sorted(by_repo, key=lambda r: -len(by_repo[r])):
380-
rr = by_repo[repo]
381-
avg_ntr = sum(r["nontrivial_file_recall"] for r in rr) / len(rr)
382-
print(f" {repo:30s} (n={len(rr):3d}): nontrivial_recall={avg_ntr:.3f}")
397+
_print_per_repo_breakdown(ok, by_repo)
383398

384399
zero_frag = sum(1 for r in ok if r["fragments"] == 0)
385400
zero_line = sum(1 for r in ok if r["line_recall"] < 1e-9 and r["fragments"] > 0)

benchmarks/forensic_contextbench.py

Lines changed: 53 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,22 @@ def _print_nontrivial_report(
219219
print(f" HIT: {f} max_ppr={max_score:.6f}")
220220

221221

222+
def _classify_nontrivial_stages(
    nontrivial: set,
    selected: set,
    sel_dump: set,
    fragmented: set,
    universe: set,
) -> dict[str, str]:
    """Label every file in ``nontrivial`` with a stage name.

    A file present in ``selected`` is labelled ``"selected"``. Any other file
    gets whatever ``_classify_failure_stage`` returns for it, given
    ``sel_dump``, ``fragmented`` and ``universe``.

    Returns:
        A mapping from each file in ``nontrivial`` to its stage label.
    """
    return {
        path: (
            "selected"
            if path in selected
            else _classify_failure_stage(path, sel_dump, fragmented, universe)
        )
        for path in nontrivial
    }
236+
237+
222238
def evaluate_one(inst: dict, budget: int) -> dict:
223239
iid = inst["instance_id"]
224240
print("\n" + "=" * 78)
@@ -293,12 +309,7 @@ def evaluate_one(inst: dict, budget: int) -> dict:
293309
nt_recall = len(nontrivial_hits) / len(nontrivial) if nontrivial else 0.0
294310
patch_coverage = len(p_set & selected) / len(p_set) if p_set else 0.0
295311

296-
stage_per_file: dict[str, str] = {}
297-
for f in nontrivial:
298-
if f in selected:
299-
stage_per_file[f] = "selected"
300-
else:
301-
stage_per_file[f] = _classify_failure_stage(f, sel_dump, fragmented, universe)
312+
stage_per_file = _classify_nontrivial_stages(nontrivial, selected, sel_dump, fragmented, universe)
302313

303314
return {
304315
"id": iid,
@@ -328,6 +339,40 @@ def _print_threshold_sanity_check():
328339
print(f"diffctx _LOW_RELEVANCE_THRESHOLD = {v}", file=sys.stderr)
329340

330341

342+
def _filter_nontrivial_instances(insts: list) -> list:
    """Keep instances whose gold context names a file the patch does not touch.

    For each instance, ``gold_context`` is JSON-decoded when it is a string.
    Each entry's ``"file"`` is passed through ``normalize_gold_path``. The
    instance is kept when at least one of those paths is absent from the
    added, deleted and modified files reported by ``patch_files_detailed``
    for the instance's ``"patch"``.

    Returns:
        The kept instances, in their original order.
    """

    def has_gold_outside_patch(inst) -> bool:
        raw_gold = inst["gold_context"]
        gold_blocks = json.loads(raw_gold) if isinstance(raw_gold, str) else raw_gold
        gold_files = {normalize_gold_path(entry["file"]) for entry in gold_blocks}
        added, deleted, modified = patch_files_detailed(inst["patch"])
        touched = added | deleted | modified
        return bool(gold_files - touched)

    return [inst for inst in insts if has_gold_outside_patch(inst)]
351+
352+
353+
def _print_ok_summary(ok: list[dict]) -> None:
    """Print aggregate metrics and a stage-wise breakdown for successful runs.

    Prints the mean ``patch_coverage``, ``file_recall`` and ``nt_recall``, the
    summed ``n_deleted_in_patch``, a count of every stage label found in the
    optional ``stage_per_file`` mappings (most frequent first), and two
    closing lines on how to read ``patch_coverage``.

    Args:
        ok: Result dicts for instances that completed successfully. An empty
            list prints nothing, because the averages would otherwise divide
            by zero.
    """
    if not ok:
        return

    n = len(ok)
    print(f"\nAvg patch_coverage: {sum(r['patch_coverage'] for r in ok) / n:.3f}")
    print(f"Avg file_recall: {sum(r['file_recall'] for r in ok) / n:.3f}")
    print(f"Avg nontrivial: {sum(r['nt_recall'] for r in ok) / n:.3f}")
    total_deleted = sum(r["n_deleted_in_patch"] for r in ok)
    print(f"Total deleted files across all instances: {total_deleted}")

    # Tally how many nontrivial gold files ended up at each stage.
    stages: dict[str, int] = {}
    for r in ok:
        for stage in r.get("stage_per_file", {}).values():
            stages[stage] = stages.get(stage, 0) + 1
    total_nt = sum(stages.values())

    if total_nt:
        print(f"\nStage-wise breakdown ({total_nt} nontrivial gold files):")
        # Stable sort: stages with equal counts stay in first-seen order.
        for stage in sorted(stages, key=lambda s: -stages[s]):
            pct = 100 * stages[stage] / total_nt
            print(f" {stage:50s}: {stages[stage]:4d} ({pct:5.1f}%)")

    print("\nIf patch_coverage < 0.95: BUG — diffctx is losing files from its own diff input.")
    print("If patch_coverage > 0.95: not a patch-loss bug, look elsewhere.")
374+
375+
331376
def main():
332377
_print_threshold_sanity_check()
333378

@@ -342,14 +387,7 @@ def main():
342387
ds = load_dataset("Contextbench/ContextBench", "contextbench_verified", split="train")
343388
insts = list(ds)
344389
if args.nontrivial_only:
345-
kept = []
346-
for i in insts:
347-
gb = json.loads(i["gold_context"]) if isinstance(i["gold_context"], str) else i["gold_context"]
348-
gold = {normalize_gold_path(g["file"]) for g in gb}
349-
added, deleted, modified = patch_files_detailed(i["patch"])
350-
if gold - (added | deleted | modified):
351-
kept.append(i)
352-
insts = kept
390+
insts = _filter_nontrivial_instances(insts)
353391
insts = insts[: args.limit]
354392

355393
print(f"Diagnosing {len(insts)} nontrivial instances at budget={args.budget}\n")
@@ -370,26 +408,7 @@ def main():
370408
for r in fail:
371409
print(f" FAIL [{r['status']}]: {r['id']}")
372410
if ok:
373-
print(f"\nAvg patch_coverage: {sum(r['patch_coverage'] for r in ok)/len(ok):.3f}")
374-
print(f"Avg file_recall: {sum(r['file_recall'] for r in ok)/len(ok):.3f}")
375-
print(f"Avg nontrivial: {sum(r['nt_recall'] for r in ok)/len(ok):.3f}")
376-
total_deleted = sum(r["n_deleted_in_patch"] for r in ok)
377-
print(f"Total deleted files across all instances: {total_deleted}")
378-
379-
stages: dict[str, int] = {}
380-
total_nt = 0
381-
for r in ok:
382-
for _f, stage in r.get("stage_per_file", {}).items():
383-
stages[stage] = stages.get(stage, 0) + 1
384-
total_nt += 1
385-
if total_nt:
386-
print(f"\nStage-wise breakdown ({total_nt} nontrivial gold files):")
387-
for stage in sorted(stages, key=lambda s: -stages[s]):
388-
pct = 100 * stages[stage] / total_nt
389-
print(f" {stage:50s}: {stages[stage]:4d} ({pct:5.1f}%)")
390-
391-
print("\nIf patch_coverage < 0.95: BUG — diffctx is losing files from its own diff input.")
392-
print("If patch_coverage > 0.95: not a patch-loss bug, look elsewhere.")
411+
_print_ok_summary(ok)
393412

394413

395414
if __name__ == "__main__":

benchmarks/loo_swebench.py

Lines changed: 37 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,11 @@
2626
def strip_file_from_patch(patch_text: str, file_to_hide: str) -> str:
2727
import re
2828

29-
pattern = re.compile(r"^diff --git\s.+?(?=(?:^diff --git\s)|\Z)", re.MULTILINE | re.DOTALL)
3029
hidden_markers = {f"a/{file_to_hide}", f"b/{file_to_hide}"}
3130
kept = []
32-
for m in pattern.finditer(patch_text):
33-
block = m.group()
31+
for block in re.split(r"(?=^diff --git\s)", patch_text, flags=re.MULTILINE):
32+
if not block.startswith("diff --git"):
33+
continue
3434
first_line = block.split("\n", 1)[0]
3535
parts = first_line.split()
3636
if not any(p.strip('"') in hidden_markers for p in parts[2:]):
@@ -168,6 +168,38 @@ def _timeout_handler(_sig: int, _frame: object) -> None:
168168
return []
169169

170170

171+
def _filter_multi_file(insts: list) -> list:
    """Select instances suitable for multi-file leave-one-out trials.

    An instance is kept when all of the following hold for its ``"patch"``:

    * ``patch_files`` reports at least two files,
    * ``is_mechanical_change`` is false for the patch,
    * ``is_vendor_or_generated`` is false for every reported file.

    Returns:
        The kept instances, in their original order.
    """
    kept = []
    for inst in insts:
        patch = inst["patch"]
        # Parse the file list once and reuse it for both the size check and
        # the vendor/generated check.
        # NOTE(review): assumes patch_files returns a re-iterable collection
        # and is deterministic for a given patch text — confirm.
        files = patch_files(patch)
        if len(files) < 2:
            continue
        if is_mechanical_change(patch):
            continue
        if any(is_vendor_or_generated(path) for path in files):
            continue
        kept.append(inst)
    return kept
179+
180+
181+
def _print_hit_rate_by(all_results: list[dict], key: str, header: str, width: int) -> None:
    """Group trials by ``key`` and print found/total per group, largest first."""
    groups: dict[str, list[dict]] = defaultdict(list)
    for r in all_results:
        groups[r[key]].append(r)

    print(header)
    # reverse=True keeps the sort stable, so equal-sized groups stay in
    # first-seen order.
    for name in sorted(groups, key=lambda g: len(groups[g]), reverse=True):
        trials = groups[name]
        hits = sum(1 for t in trials if t["found"])
        print(f" {name:{width}s} {hits}/{len(trials):3d} ({100 * hits / len(trials):.0f}%)")


def _print_loo_breakdowns(all_results: list[dict]) -> None:
    """Print per-repo and per-language hit rates for leave-one-out trials.

    Args:
        all_results: Trial dicts, each carrying ``repo``, ``language`` and a
            truthy/falsy ``found`` entry. Each section lists its groups by
            descending trial count; with no trials only the headers print.
    """
    _print_hit_rate_by(all_results, "repo", "Per-repo breakdown:", 40)
    _print_hit_rate_by(all_results, "language", "\nPer-language breakdown:", 20)
201+
202+
171203
def main():
172204
ap = argparse.ArgumentParser()
173205
ap.add_argument("--limit", type=int, default=50)
@@ -187,13 +219,7 @@ def main():
187219
ds = load_dataset(args.dataset, args.split, split="train")
188220
insts = list(ds)
189221

190-
multi_file = [
191-
i
192-
for i in insts
193-
if len(patch_files(i["patch"])) >= 2
194-
and not is_mechanical_change(i["patch"])
195-
and not any(is_vendor_or_generated(f) for f in patch_files(i["patch"]))
196-
]
222+
multi_file = _filter_multi_file(insts)
197223
print(f"Total instances: {len(insts)}, multi-file (filtered): {len(multi_file)}")
198224

199225
warm_cache(multi_file)
@@ -235,25 +261,7 @@ def main():
235261
print(f"Found distractor: {distractor_found}/{distractor_total} ({100 * distractor_found / distractor_total:.1f}%)")
236262
print()
237263

238-
by_repo: dict[str, list[dict]] = defaultdict(list)
239-
for r in all_results:
240-
by_repo[r["repo"]].append(r)
241-
242-
print("Per-repo breakdown:")
243-
for repo in sorted(by_repo, key=lambda r: len(by_repo[r]), reverse=True):
244-
trials = by_repo[repo]
245-
h = sum(1 for t in trials if t["found"])
246-
print(f" {repo:40s} {h}/{len(trials):3d} ({100 * h / len(trials):.0f}%)")
247-
248-
by_lang: dict[str, list[dict]] = defaultdict(list)
249-
for r in all_results:
250-
by_lang[r["language"]].append(r)
251-
252-
print("\nPer-language breakdown:")
253-
for lang in sorted(by_lang, key=lambda la: len(by_lang[la]), reverse=True):
254-
trials = by_lang[lang]
255-
h = sum(1 for t in trials if t["found"])
256-
print(f" {lang:20s} {h}/{len(trials):3d} ({100 * h / len(trials):.0f}%)")
264+
_print_loo_breakdowns(all_results)
257265

258266
if len(seeds) == 1:
259267
tag = f"loo_{args.scoring}_n{args.limit}_b{args.budget}"

0 commit comments

Comments
 (0)