Skip to content

Commit 6d5e8e5

Browse files
committed
Add data generation pipeline
SWE-bench based pipeline: loader, source fetcher, tool call generator/executor, auto labeler, LLM distiller, sample assembler, and validator.
1 parent e24f3e2 commit 6d5e8e5

12 files changed

Lines changed: 2723 additions & 3 deletions

.gitignore

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@ env/
2525
.DS_Store
2626
Thumbs.db
2727

28-
# Project-specific
29-
data/
30-
output/
28+
# Project-specific (only root-level data/, not squeez/data/)
29+
/data/
30+
/output/
3131
*.jsonl
3232
wandb/
3333
runs/

squeez/data/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Data generation pipeline for tool output extraction training data."""

squeez/data/auto_labeler.py

Lines changed: 357 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,357 @@
1+
"""Phase 5: Auto-label tool output lines using patch ground truth.
2+
3+
Uses the patch diff to determine which lines in each tool output
4+
are relevant to the bug/fix. Applies quality filters to reject
5+
samples that are too easy or too hard.
6+
"""
7+
8+
import json
9+
import logging
10+
import re
11+
from pathlib import Path
12+
13+
from squeez.data.config import (
14+
MAX_RELEVANT_RATIO,
15+
MIN_RELEVANT_LINES,
16+
MIN_RELEVANT_RATIO,
17+
MIN_TOTAL_LINES,
18+
PipelineConfig,
19+
)
20+
from squeez.data.swebench_loader import parse_patch_files, parse_patch_hunks
21+
22+
logger = logging.getLogger(__name__)
23+
24+
25+
def _find_enclosing_scope(content: str, target_line: int) -> list[int]:
26+
"""Find the enclosing function/class definition lines for a target line.
27+
28+
Simple heuristic: walk backwards from target_line looking for
29+
def/class lines with less indentation.
30+
"""
31+
lines = content.split("\n")
32+
if target_line < 1 or target_line > len(lines):
33+
return []
34+
35+
target_indent = len(lines[target_line - 1]) - len(lines[target_line - 1].lstrip())
36+
scope_lines = []
37+
38+
for i in range(target_line - 2, -1, -1):
39+
line = lines[i]
40+
stripped = line.lstrip()
41+
if not stripped:
42+
continue
43+
indent = len(line) - len(stripped)
44+
if indent < target_indent and (
45+
stripped.startswith("def ") or stripped.startswith("class ")
46+
):
47+
scope_lines.append(i + 1) # 1-indexed
48+
target_indent = indent
49+
if indent == 0:
50+
break
51+
52+
return scope_lines
53+
54+
55+
def _find_imports_for_names(content: str, names: set[str]) -> list[int]:
56+
"""Find import lines that reference any of the given names."""
57+
import_lines = []
58+
for i, line in enumerate(content.split("\n"), 1):
59+
stripped = line.strip()
60+
if stripped.startswith("import ") or stripped.startswith("from "):
61+
for name in names:
62+
if name in stripped:
63+
import_lines.append(i)
64+
break
65+
return import_lines
66+
67+
68+
def label_read_file(
    output: str, instance: dict, tool_output: dict
) -> dict[int, bool] | None:
    """Label each line of a read_file output against the instance's patch.

    A line is labelled relevant when its position is:
      - inside one of the patch hunk ranges reported for the file,
      - a def/class header enclosing such a line, or
      - an import line mentioning a function/class name found in the patch.

    Args:
        output: Tool output text; a leading ``"<digits>: "`` prefix on a line
            is stripped before the scope and import analysis.
        instance: Instance dict; only ``instance["patch"]`` is read.
        tool_output: Tool output dict; ``"command"`` is used as the file path
            and ``"is_patch_file"`` marks files known to be in the patch.

    Returns:
        Mapping of 1-indexed output line position to a relevance flag.

    NOTE(review): hunk ranges are compared directly with output line
    positions, which presumably assumes the output shows the file from its
    first line — confirm against the tool executor.
    """
    path = tool_output.get("command", "")
    hunk_ranges = parse_patch_hunks(instance["patch"]).get(path, [])

    # No hunks and not flagged as a patch file: this is a decoy, nothing in
    # it is relevant.
    if not hunk_ranges and not tool_output.get("is_patch_file", False):
        return {pos: False for pos in range(1, output.count("\n") + 2)}

    # Every line covered by a hunk (ranges are treated as inclusive).
    relevant: set[int] = {
        pos for first, last in hunk_ranges for pos in range(first, last + 1)
    }

    # Rebuild the file text without the "N: " prefixes.
    numbered = re.compile(r"^\d+: (.*)$")
    source_lines = []
    for raw in output.split("\n"):
        hit = numbered.match(raw)
        source_lines.append(hit.group(1) if hit else raw)
    source_text = "\n".join(source_lines)

    # Headers of the scopes that contain the hunk lines.
    enclosing: set[int] = {
        header
        for pos in relevant
        for header in _find_enclosing_scope(source_text, pos)
    }
    relevant |= enclosing

    # Imports mentioning any function or class named in the patch.
    patch_text = instance["patch"]
    named = set(re.findall(r"def (\w+)", patch_text))
    named |= set(re.findall(r"class (\w+)", patch_text))
    relevant.update(_find_imports_for_names(source_text, named))

    return {pos: pos in relevant for pos in range(1, len(source_lines) + 1)}
126+
127+
128+
def label_grep(output: str, instance: dict) -> dict[int, bool] | None:
    """Label grep output lines by their proximity to the patch.

    Lines are expected in the ``file:line_no:`` form. A match is relevant
    when its file is one of the patch files and either its line number lies
    within 10 lines of one of that file's hunks, or no hunks are known for
    that file. Lines that do not parse are irrelevant.

    Args:
        output: Raw grep output, one match per line.
        instance: Instance dict; only ``instance["patch"]`` is read.

    Returns:
        Mapping of 1-indexed output line position to a relevance flag.
    """
    touched_files = set(parse_patch_files(instance["patch"]))
    hunks_by_file = parse_patch_hunks(instance["patch"])
    # Grep output format: file:line_no: content
    match_prefix = re.compile(r"^([^:]+):(\d+):")

    def _near_patch(path: str, line_no: int) -> bool:
        """Report whether a grep hit falls on or near patched code."""
        if path not in touched_files:
            return False
        ranges = hunks_by_file.get(path, [])
        if not ranges:
            # File is in the patch but has no hunk info: still relevant.
            return True
        return any(first - 10 <= line_no <= last + 10 for first, last in ranges)

    labels = {}
    for pos, row in enumerate(output.split("\n"), 1):
        hit = match_prefix.match(row)
        if hit is None:
            labels[pos] = False
        else:
            labels[pos] = _near_patch(hit.group(1), int(hit.group(2)))
    return labels
161+
162+
163+
def label_git_log(output: str, instance: dict) -> dict[int, bool] | None:
164+
"""Label git log output. All lines marked as mildly relevant (simulated)."""
165+
labels = {}
166+
lines = output.split("\n")
167+
identifiers = set(re.findall(r"def (\w+)", instance["patch"]))
168+
identifiers.update(re.findall(r"class (\w+)", instance["patch"]))
169+
170+
for i, line in enumerate(lines, 1):
171+
# Mark commits mentioning relevant identifiers as relevant
172+
is_relevant = any(ident.lower() in line.lower() for ident in identifiers)
173+
labels[i] = is_relevant
174+
175+
return labels
176+
177+
178+
def label_test_output(output: str, instance: dict) -> dict[int, bool] | None:
179+
"""Label test output. FAIL lines and tracebacks are relevant."""
180+
labels = {}
181+
in_failure = False
182+
183+
for i, line in enumerate(output.split("\n"), 1):
184+
if line.startswith("FAIL:") or line.startswith("ERROR:"):
185+
in_failure = True
186+
labels[i] = True
187+
elif line.startswith("---") and in_failure:
188+
labels[i] = True
189+
in_failure = False
190+
elif in_failure:
191+
labels[i] = True
192+
elif "FAILED" in line or "Error" in line:
193+
labels[i] = True
194+
else:
195+
labels[i] = False
196+
197+
return labels
198+
199+
200+
def label_git_diff(output: str, instance: dict) -> dict[int, bool] | None:
201+
"""Label git diff output. Changed lines (+/-) are relevant."""
202+
labels = {}
203+
for i, line in enumerate(output.split("\n"), 1):
204+
if line.startswith("+") or line.startswith("-"):
205+
labels[i] = not line.startswith("+++") and not line.startswith("---")
206+
elif line.startswith("@@"):
207+
labels[i] = True
208+
elif line.startswith("diff --git"):
209+
labels[i] = True
210+
else:
211+
labels[i] = False
212+
return labels
213+
214+
215+
def label_ls(output: str, instance: dict) -> dict[int, bool] | None:
    """Label ls output lines that mention a file touched by the patch.

    Only the final path component of each patch file is compared, and the
    comparison is a plain substring test against the whole output line.

    Args:
        output: Raw ls output.
        instance: Instance dict; only ``instance["patch"]`` is read.

    Returns:
        Mapping of 1-indexed output line position to a relevance flag.
    """
    basenames = {Path(path).name for path in parse_patch_files(instance["patch"])}

    labels = {}
    for pos, entry in enumerate(output.split("\n"), 1):
        mentioned = False
        for basename in basenames:
            if basename in entry:
                mentioned = True
                break
        labels[pos] = mentioned
    return labels
225+
226+
227+
def label_generic(output: str, instance: dict) -> dict[int, bool] | None:
    """Fallback labeler used for lint, blame and build output.

    A line is relevant when it contains, as a substring, either the path of
    a file touched by the patch or the name of a function/class that appears
    in the patch text.

    Args:
        output: Raw tool output.
        instance: Instance dict; only ``instance["patch"]`` is read.

    Returns:
        Mapping of 1-indexed output line position to a relevance flag.
    """
    patch_text = instance["patch"]
    markers = set(parse_patch_files(patch_text))
    markers |= set(re.findall(r"def (\w+)", patch_text))
    markers |= set(re.findall(r"class (\w+)", patch_text))

    return {
        pos: any(marker in row for marker in markers)
        for pos, row in enumerate(output.split("\n"), 1)
    }
242+
243+
244+
def auto_label_output(
    tool_output: dict, instance: dict
) -> dict | None:
    """Label one tool output and apply the quality filters.

    The labeler is chosen from ``tool_output["tool_type"]``. The sample is
    dropped (``None``) when the output has fewer than ``MIN_TOTAL_LINES``
    lines, when the tool type has no labeler, when the labeler returns
    ``None``, or when the number/ratio of relevant lines falls outside the
    ``MIN_RELEVANT_LINES`` / ``MIN_RELEVANT_RATIO`` / ``MAX_RELEVANT_RATIO``
    bounds.

    Args:
        tool_output: Dict with at least ``"output"``, ``"tool_type"`` and
            ``"instance_id"``; ``"command"`` is optional.
        instance: Instance dict handed through to the labeler.

    Returns:
        A sample dict with the labels (keys converted to strings) and line
        statistics, or ``None`` if the sample was rejected.
    """
    output = tool_output["output"]
    tool_type = tool_output["tool_type"]
    total_lines = len(output.split("\n"))

    # Skip if too short
    if total_lines < MIN_TOTAL_LINES:
        return None

    # Labelers that only need the output text and the instance.
    simple_labelers = {
        "grep": label_grep,
        "git_log": label_git_log,
        "test_output": label_test_output,
        "git_diff": label_git_diff,
        "git_blame": label_generic,
        "ls": label_ls,
        "lint_output": label_generic,
        "build_output": label_generic,
    }

    if tool_type == "read_file":
        # read_file additionally needs the tool output's own metadata.
        labels = label_read_file(output, instance, tool_output)
    elif tool_type in simple_labelers:
        labels = simple_labelers[tool_type](output, instance)
    else:
        return None

    if labels is None:
        return None

    # Compute stats
    n_relevant = sum(1 for flag in labels.values() if flag)
    if total_lines > 0:
        relevant_ratio = n_relevant / total_lines
    else:
        relevant_ratio = 0

    # Quality filters
    if n_relevant < MIN_RELEVANT_LINES:
        logger.debug(f"Rejected: too few relevant lines ({n_relevant})")
        return None
    if relevant_ratio > MAX_RELEVANT_RATIO:
        logger.debug(f"Rejected: too many relevant lines ({relevant_ratio:.2%})")
        return None
    if relevant_ratio < MIN_RELEVANT_RATIO:
        logger.debug(f"Rejected: too few relevant ratio ({relevant_ratio:.2%})")
        return None

    string_keyed = {str(pos): flag for pos, flag in labels.items()}
    return {
        "instance_id": tool_output["instance_id"],
        "tool_type": tool_type,
        "command": tool_output.get("command", ""),
        "output": output,
        "labels": string_keyed,
        "num_total_lines": total_lines,
        "num_relevant_lines": n_relevant,
        "relevant_ratio": round(relevant_ratio, 4),
    }
305+
306+
307+
def auto_label_all(
    tool_outputs: list[dict],
    instances: list[dict],
    config: PipelineConfig,
) -> list[dict]:
    """Auto-label all tool outputs, caching the result as JSON Lines.

    If ``<config.output_dir>/auto_labels.jsonl`` already exists it is loaded
    and returned without relabeling. Otherwise every tool output whose
    ``instance_id`` matches a known instance is labeled, the accepted
    samples are written to that file, and the list is returned. Tool outputs
    with no matching instance are skipped and not counted as rejected.

    Args:
        tool_outputs: List of executed tool output dicts
        instances: List of SWE-bench instance dicts
        config: Pipeline config; only ``config.output_dir`` is read.

    Returns:
        List of labeled sample dicts
    """
    output_path = config.output_dir / "auto_labels.jsonl"

    # Return cached
    if output_path.exists():
        logger.info(f"Loading cached auto-labels from {output_path}")
        # Explicit UTF-8 keeps reads independent of the platform locale.
        # Blank lines (e.g. a trailing newline added by hand) are skipped
        # instead of crashing json.loads.
        with open(output_path, encoding="utf-8") as f:
            return [json.loads(line) for line in f if line.strip()]

    instance_map = {inst["instance_id"]: inst for inst in instances}

    labeled = []
    rejected = 0
    for tool_output in tool_outputs:
        instance = instance_map.get(tool_output["instance_id"])
        if not instance:
            continue

        result = auto_label_output(tool_output, instance)
        if result:
            labeled.append(result)
        else:
            rejected += 1

    # Write to disk. Create the directory first so a missing output dir does
    # not throw away the labeling work that was just done.
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(item) + "\n" for item in labeled)

    logger.info(
        f"Auto-labeled {len(labeled)} samples ({rejected} rejected by quality filters)"
    )
    return labeled

0 commit comments

Comments
 (0)