Skip to content

Commit 37edd96

Browse files
author
LoCoBench Bot
committed
Add docgen quality variation sweep script
1 parent 8e3f19e commit 37edd96

File tree

1 file changed

+350
-0
lines changed

1 file changed

+350
-0
lines changed

scripts/docgen_quality_sweep.py

Lines changed: 350 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,350 @@
1+
#!/usr/bin/env python3
2+
"""Run a quality-variation sweep across all ccb_docgen tasks."""
3+
4+
from __future__ import annotations
5+
6+
import argparse
7+
import json
8+
import re
9+
from dataclasses import dataclass
10+
from pathlib import Path
11+
from typing import Callable
12+
13+
14+
# Repository root (parent of scripts/); resolved so the script works from any CWD.
ROOT = Path(__file__).resolve().parents[1]
# Location of the ccb_docgen benchmark tasks this sweep iterates over.
DOCGEN_DIR = ROOT / "benchmarks" / "ccb_docgen"
16+
17+
18+
def check_any(patterns: list[str], text: str) -> bool:
    """Return True if at least one pattern matches *text*.

    Each entry is tried as a case-insensitive, DOTALL regex; an entry that
    is not a valid regex degrades to a case-insensitive substring test.
    """
    lowered = text.lower()
    for pattern in patterns:
        try:
            matched = re.search(pattern, text, re.IGNORECASE | re.DOTALL) is not None
        except re.error:
            # Invalid regex: fall back to plain substring containment.
            matched = pattern.lower() in lowered
        if matched:
            return True
    return False
27+
28+
29+
def check_all(patterns: list[str], text: str) -> bool:
    """Return True only if every pattern matches *text*.

    Mirrors check_any's matching rules: case-insensitive DOTALL regex with
    a substring fallback for invalid regexes. An empty list vacuously passes.
    """
    lowered = text.lower()
    for pattern in patterns:
        try:
            if re.search(pattern, text, re.IGNORECASE | re.DOTALL) is None:
                return False
        except re.error:
            # Invalid regex: require plain substring containment instead.
            if pattern.lower() not in lowered:
                return False
    return True
38+
39+
40+
def witness_for_pattern(pattern: str) -> str:
    """Derive a short literal string that *pattern* will match.

    Tries, in order: the pattern with flag groups/anchors removed, a
    backslash-stripped copy of it, and a "cleaned" form with regex syntax
    scrubbed out. Returns the first candidate the pattern matches (regex,
    or substring when the regex is invalid); otherwise a generic phrase.
    """
    # Drop inline flag groups and anchors so the pattern text itself can
    # serve as its own witness.
    stripped = re.sub(r"\(\?[imsux-]*\)", "", pattern)
    stripped = stripped.replace("^", " ").replace("$", " ")

    # Progressively scrub regex syntax to leave plain words behind.
    cleaned = re.sub(r"\[(.)[^]]*\]", r"\1", stripped)  # character class -> its first char
    for old, new in ((".*", " "), (".+", " "), ("\\.", "."), ("|", " ")):
        cleaned = cleaned.replace(old, new)
    cleaned = re.sub(r"[?*+(){}]", " ", cleaned)
    cleaned = re.sub(r"\s+", " ", cleaned).strip()

    for candidate in (stripped, stripped.replace("\\", ""), cleaned):
        if not candidate:
            continue
        try:
            hit = re.search(pattern, candidate, re.IGNORECASE | re.DOTALL)
        except re.error:
            hit = pattern.lower() in candidate.lower()
        if hit:
            return candidate

    return "documented behavior"
68+
69+
70+
def text_for_any(patterns: list[str]) -> str:
    """Produce a text snippet that satisfies check_any(patterns, ...).

    For each pattern, the raw pattern and several decorated forms of its
    witness string are tried; the first that passes check_any wins. If
    nothing passes, the concatenated witnesses are returned as a best effort.
    """
    for p in patterns:
        w = witness_for_pattern(p)
        for candidate in (p, w, f"{w} example", f"#{w}", f"## {w}"):
            if check_any(patterns, candidate):
                return candidate
    return " ".join(witness_for_pattern(p) for p in patterns)
87+
88+
89+
def text_for_all(patterns: list[str]) -> str:
    """Produce text that satisfies check_all(patterns, ...), one line per pattern.

    For each pattern the first of (pattern, witness, "witness detail") that
    individually matches is used; when none match, the bare witness is kept
    as a best effort.
    """
    parts: list[str] = []
    for p in patterns:
        # Hoisted: witness_for_pattern is deterministic but the original
        # recomputed it up to three times per pattern.
        w = witness_for_pattern(p)
        chosen = next((c for c in (p, w, f"{w} detail") if check_any([p], c)), None)
        parts.append(chosen or w)
    return "\n".join(parts)
100+
101+
102+
@dataclass
class Item:
    """One weighted checklist entry within a scoring category."""

    category: str  # parent category key (e.g. "required_topics")
    weight: float  # relative weight of this item within its category
    mode: str  # matching mode: "any", "all", or "path" (substring match on a file path)
    values: list[str]  # patterns for any/all modes; a single file path for path mode
    name: str  # human-readable label used when rendering variant documents
109+
110+
111+
@dataclass
class Spec:
    """Parsed scoring specification for a single docgen task."""

    task_id: str
    category_weights: dict[str, float]
    categories: dict[str, list[Item]]
    hallucination_penalty: bool

    def _category_score(self, items: list[Item], text: str) -> float:
        """Fraction (0..1) of an item list's total weight satisfied by *text*."""
        total = sum(item.weight for item in items)
        if total <= 0:
            return 0.0
        hit = 0.0
        for item in items:
            if item.mode == "any":
                matched = check_any(item.values, text)
            elif item.mode == "all":
                matched = check_all(item.values, text)
            elif item.mode == "path":
                matched = item.values[0].lower() in text.lower()
            else:
                matched = False
            if matched:
                hit += item.weight
        return hit / total

    def score(self, text: str) -> float:
        """Score *text* in [0, 1] against this spec.

        Per-category scores are combined via category_weights; specs with
        hallucination_penalty additionally subtract a penalty for *.go paths
        that do not exist under ROOT.
        """
        base = sum(
            self._category_score(items, text) * self.category_weights.get(cat, 0.0)
            for cat, items in self.categories.items()
        )

        if not self.hallucination_penalty:
            return max(0.0, min(1.0, base))

        # Mirror k8s verifier behavior: penalize invalid *.go paths.
        go_paths = set(re.findall(r"(?:staging/src|pkg|cmd|api)/[A-Za-z0-9_./-]+\.go", text))
        penalty = 0.0
        if go_paths:
            missing = sum(1 for p in go_paths if not (ROOT / p).exists())
            penalty = min(0.35, (missing / len(go_paths)) * 0.5)
        return max(0.0, min(1.0, base - penalty))
154+
155+
156+
def _parse_k8s_spec(task_id: str, gt: dict) -> Spec:
    """Parse the architecture / k8s ground-truth format (top-level weights)."""
    category_weights = {
        "required_topics": float(gt["weights"]["required_topics"]),
        "file_references": float(gt["weights"]["file_references"]),
        "data_flow": float(gt["weights"]["data_flow"]),
        "extension_points": float(gt["weights"]["extension_points"]),
    }
    categories: dict[str, list[Item]] = {k: [] for k in category_weights}

    for raw in gt["required_topics"]:
        categories["required_topics"].append(
            Item("required_topics", float(raw["weight"]), "any", raw["patterns"], raw.get("id", "topic"))
        )
    for raw in gt["file_references"]:
        categories["file_references"].append(
            Item("file_references", float(raw["weight"]), "any", raw["patterns"], raw.get("id", "ref"))
        )
    for raw in gt["data_flow"]:
        # NOTE(review): the original chose "all" for both ordered and
        # unordered flows ('"all" if raw.get("ordered") else "all"'); the
        # degenerate conditional is collapsed to the constant until ordered
        # matching is actually implemented.
        categories["data_flow"].append(
            Item("data_flow", float(raw["weight"]), "all", raw["patterns"], raw.get("id", "flow"))
        )
    for raw in gt["extension_points"]:
        categories["extension_points"].append(
            Item("extension_points", float(raw["weight"]), "any", raw["patterns"], raw.get("id", "ext"))
        )

    return Spec(
        task_id=task_id,
        category_weights=category_weights,
        categories=categories,
        hallucination_penalty=task_id.startswith("docgen-k8s-"),
    )


def _parse_arch_spec(task_id: str, sc: dict) -> Spec:
    """Parse the architecture-category format (docgen-arch-003)."""
    category_weights = {
        "required_topics": float(sc["required_topics"]["weight"]),
        "file_references": float(sc["file_references"]["weight"]),
        "data_flow": float(sc["data_flow"]["weight"]),
        "extension_points": float(sc["extension_points"]["weight"]),
    }
    categories: dict[str, list[Item]] = {k: [] for k in category_weights}

    for raw in sc["required_topics"]["topics"]:
        categories["required_topics"].append(
            Item("required_topics", float(raw["weight"]), "any", raw["check_any_pattern"], raw["name"])
        )
    for raw in sc["file_references"]["files"]:
        categories["file_references"].append(
            Item("file_references", float(raw["weight"]), "path", [raw["path"]], raw["path"])
        )
    for raw in sc["data_flow"]["flows"]:
        categories["data_flow"].append(
            Item("data_flow", float(raw["weight"]), "all", raw["check_all_patterns"], raw["name"])
        )
    for raw in sc["extension_points"]["points"]:
        categories["extension_points"].append(
            Item("extension_points", float(raw["weight"]), "any", raw["check_any_pattern"], raw["name"])
        )
    return Spec(task_id=task_id, category_weights=category_weights, categories=categories, hallucination_penalty=False)


def _parse_generic_spec(task_id: str, sc: dict) -> Spec:
    """Parse the generic category/items format (api + migration tasks)."""
    category_weights = {k: float(v["weight"]) for k, v in sc.items()}
    categories = {
        cat: [
            Item(cat, float(it["weight"]), "any", it["patterns"], it.get("name", f"{cat}_item"))
            for it in data["items"]
        ]
        for cat, data in sc.items()
    }
    return Spec(task_id=task_id, category_weights=category_weights, categories=categories, hallucination_penalty=False)


def parse_task(task_dir: Path) -> Spec:
    """Load tests/ground_truth.json under *task_dir* and build a Spec.

    Dispatches on the shape of the ground-truth JSON: the k8s/architecture
    top-level format, the nested-topic format used by docgen-arch-003, or
    the generic category/items format.
    """
    task_id = task_dir.name
    gt = json.loads((task_dir / "tests" / "ground_truth.json").read_text())

    if "weights" in gt and "required_topics" in gt:
        return _parse_k8s_spec(task_id, gt)

    sc = gt["scoring_categories"]
    if "required_topics" in sc and "topics" in sc["required_topics"]:
        return _parse_arch_spec(task_id, sc)
    return _parse_generic_spec(task_id, sc)
234+
235+
236+
def build_variant(spec: Spec, keep_ratio: float, add_hallucination: bool, irrelevant: bool = False) -> str:
    """Render a synthetic documentation variant for *spec*.

    keep_ratio controls the fraction of items covered per category;
    add_hallucination appends a reference to a nonexistent Go file;
    irrelevant=True short-circuits to off-topic filler text.
    """
    if irrelevant:
        filler = (
            "This document discusses astronomy, music theory, ocean currents, and city planning. "
            "It intentionally avoids software implementation details, code paths, and API semantics. "
        )
        return (filler * 20).strip()

    lines: list[str] = [f"# Generated Documentation for {spec.task_id}", ""]
    for cat, items in spec.categories.items():
        lines.append(f"## {cat.replace('_', ' ').title()}")
        # Always keep at least one item so every category section is non-empty.
        kept = items[: max(1, int(round(len(items) * keep_ratio)))]
        for item in kept:
            if item.mode == "path":
                body = item.values[0]
            elif item.mode == "all":
                body = text_for_all(item.values)
            else:
                body = text_for_any(item.values) if item.values else item.name
            lines.append(f"- {item.name}: {body}")
        lines.append("")

    doc = "\n".join(lines)
    if add_hallucination:
        doc += "\n\n## Additional Notes\n- Refer to pkg/imaginary/notreal_controller.go for core algorithm details.\n"

    # Keep above minimum length for verifiers with short-doc guard.
    if len(doc) < 900:
        doc += "\n" + ("Context detail sentence. " * 80)
    return doc
266+
267+
268+
def sweep_task(spec: Spec) -> dict[str, float]:
    """Score the six standard quality variants for one task spec."""
    # (name, keep_ratio, add_hallucination, irrelevant) — order defines output order.
    configs = [
        ("canonical", 1.0, False, False),
        ("high", 0.8, False, False),
        ("medium", 0.55, False, False),
        ("low", 0.3, False, False),
        ("irrelevant", 0.0, False, True),
        ("high_hallucination", 0.8, True, False),
    ]
    scores: dict[str, float] = {}
    for name, ratio, hallucinate, off_topic in configs:
        text = build_variant(spec, keep_ratio=ratio, add_hallucination=hallucinate, irrelevant=off_topic)
        scores[name] = round(spec.score(text), 4)
    return scores
278+
279+
280+
def summarize(results: dict[str, dict[str, float]]) -> dict[str, int]:
    """Aggregate per-task sweep results into headline counters."""
    canonical_one = sum(1 for r in results.values() if abs(r["canonical"] - 1.0) < 1e-9)
    monotonic_ok = sum(
        1
        for r in results.values()
        if r["high"] >= r["medium"] >= r["low"] >= r["irrelevant"]
    )
    k8s_ids = [task for task in results if task.startswith("docgen-k8s-")]
    hallu_penalized = sum(
        1 for task in k8s_ids if results[task]["high_hallucination"] < results[task]["high"]
    )
    return {
        "tasks_total": len(results),
        "canonical_1_0": canonical_one,
        "monotonic_pass": monotonic_ok,
        "k8s_tasks": len(k8s_ids),
        "k8s_hallucination_penalized": hallu_penalized,
    }
301+
302+
303+
def render_markdown(results: dict[str, dict[str, float]], summary: dict[str, int]) -> str:
    """Render the sweep results as a markdown report: summary bullets + table."""
    out: list[str] = []
    out.append("# DocGen Quality Variation Sweep")
    out.append("")
    out.append(f"- Tasks: {summary['tasks_total']}")
    out.append(f"- Canonical scored 1.0: {summary['canonical_1_0']}/{summary['tasks_total']}")
    out.append(
        f"- Monotonic quality ordering (high>=medium>=low>=irrelevant): "
        f"{summary['monotonic_pass']}/{summary['tasks_total']}"
    )
    out.append(
        f"- K8s hallucination penalty triggered: "
        f"{summary['k8s_hallucination_penalized']}/{summary['k8s_tasks']}"
    )
    out.append("")
    out.append("| Task | Canonical | High | Medium | Low | Irrelevant | High+Hallucination |")
    out.append("|---|---:|---:|---:|---:|---:|---:|")
    for task, r in sorted(results.items()):
        cells = (r["canonical"], r["high"], r["medium"], r["low"], r["irrelevant"], r["high_hallucination"])
        out.append("| " + task + " | " + " | ".join(f"{v:.2f}" for v in cells) + " |")
    out.append("")
    return "\n".join(out)
323+
324+
325+
def main() -> None:
    """CLI entry point: sweep every docgen task, write JSON + markdown reports."""
    parser = argparse.ArgumentParser(description="Run quality variation sweep on ccb_docgen tasks.")
    parser.add_argument("--bench-dir", type=Path, default=DOCGEN_DIR, help="DocGen benchmark directory.")
    parser.add_argument("--json-out", type=Path, default=ROOT / "reports" / "docgen_quality_sweep.json")
    parser.add_argument("--md-out", type=Path, default=ROOT / "reports" / "docgen_quality_sweep.md")
    args = parser.parse_args()

    # Only directories that carry a ground-truth file count as sweepable tasks.
    specs = [
        parse_task(p)
        for p in sorted(args.bench_dir.iterdir())
        if (p / "tests" / "ground_truth.json").exists()
    ]
    results: dict[str, dict[str, float]] = {spec.task_id: sweep_task(spec) for spec in specs}

    summary = summarize(results)
    md = render_markdown(results, summary)

    for out_path in (args.json_out, args.md_out):
        out_path.parent.mkdir(parents=True, exist_ok=True)
    args.json_out.write_text(json.dumps({"summary": summary, "results": results}, indent=2) + "\n")
    args.md_out.write_text(md + "\n")

    print(md)
347+
348+
349+
# Allow the module to be imported without side effects; run only as a script.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)