
Commit 87964fc ("Existing methods")
1 parent c23c096

3 files changed: 55 additions & 27 deletions


README.md

Lines changed: 3 additions & 0 deletions
@@ -15,6 +15,8 @@
 - Fine-tuned Qwen 3.5 2B, 0.79 F1, ~91% compression
 - CLI pipe, Python library, or vLLM server
 
+Existing context pruning tools ([SWE-Pruner](https://github.com/Ayanami1314/swe-pruner), [Zilliz Semantic Highlight](https://huggingface.co/zilliz/semantic-highlight-bilingual-v1), [Provence](https://arxiv.org/abs/2501.16214)) are built for source code or document paragraphs. They don't handle the mixed, unstructured format of tool output (stack traces interleaved with passing tests, grep matches with context lines, build logs with timestamps). Squeez is trained specifically on 14 types of tool output from real SWE-bench workflows.
+
 ```bash
 pip install squeez
 python -m pytest tests/ -v 2>&1 | squeez "find the test failure related to authentication"
@@ -128,6 +130,7 @@ Evaluated on 617 held-out test samples from SWE-bench, across 14 tool types:
 |-------|-----------|--------|------|-------------|
 | **Squeez-2B** | **0.8043** | **0.8624** | **0.7895** | 0.9150 |
 | Qwen 3.5 35B A3B (zero-shot) | 0.7402 | 0.7498 | 0.7000 | 0.9177 |
+| Kimi K2 (zero-shot) | 0.6128 | 0.5286 | 0.5344 | 0.9425 |
 | Qwen 3.5 2B (untrained) | 0.4154 | 0.5299 | 0.4075 | 0.8197 |
 | BM25 (10%) | 0.1277 | 0.2172 | 0.1314 | 0.9036 |
 | Random (10%) | 0.0738 | 0.1009 | 0.0697 | 0.9067 |
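
The table's header row is above this hunk, so the column labels are not visible here. For orientation only, here is a minimal sketch of how line-level precision/recall/F1 and a compression ratio are commonly computed for this kind of pruning evaluation; the exact metric definitions Squeez uses are not shown in this diff.

```python
# Illustrative sketch, not the project's actual metric code: line-level
# precision/recall/F1 between predicted and gold kept lines, plus a simple
# compression ratio over the original tool output.
def line_metrics(kept_lines: list[str], gold_lines: list[str], original: str) -> dict:
    pred, gold = set(kept_lines), set(gold_lines)
    tp = len(pred & gold)
    precision = tp / len(pred) if pred else 0.0
    recall = tp / len(gold) if gold else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    # Fraction of the original output that was pruned away.
    compression = 1 - len("\n".join(kept_lines)) / max(len(original), 1)
    return {"precision": precision, "recall": recall, "f1": f1, "compression": compression}
```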

scripts/evaluate_baselines.py

Lines changed: 44 additions & 22 deletions
@@ -232,24 +232,19 @@ def baseline_swe_pruner(model, task: str, tool_output: str, threshold: float = 0
 
 def _load_zilliz():
     """Load Zilliz semantic-highlight (needs: pip install transformers torch)."""
-    import torch
     from transformers import AutoModel
 
     model_name = "zilliz/semantic-highlight-bilingual-v1"
-    model = AutoModel.from_pretrained(model_name, trust_remote_code=True, dtype=torch.float16)
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    model = model.to(device)
+    model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
     model.eval()
     return model
 
 
 def baseline_zilliz(model, task: str, tool_output: str, threshold: float = 0.5) -> list[str]:
     """Run Zilliz semantic-highlight via get_raw_predictions().
 
-    Uses the low-level API to avoid the broken process() path in
-    transformers 5.2 (build_inputs_with_special_tokens removed).
-    Each line is passed as a separate context, and per-token pruning
-    probabilities are averaged per line.
+    Uses per-line contexts since process() does nltk sentence splitting
+    which doesn't handle tool output lines well.
     """
     import torch
 
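
The body of baseline_zilliz continues outside this hunk. Going only by the docstrings above (per-line contexts, a get_raw_predictions() entry point, and the removed note about averaging per-token probabilities per line), a rough sketch of that scoring loop might look like the following; the get_raw_predictions() call signature and return shape are assumptions, not taken from this diff.

```python
# Hypothetical sketch of the per-line scoring described in the docstrings.
# The get_raw_predictions() signature below is assumed, not shown in the commit.
def baseline_zilliz_sketch(model, task: str, tool_output: str, threshold: float = 0.5) -> list[str]:
    kept = []
    for line in tool_output.split("\n"):
        if not line.strip():
            continue
        # Assumed API: per-token keep probabilities for a (question, context) pair.
        token_probs = model.get_raw_predictions(question=task, contexts=[line])[0]
        line_score = sum(token_probs) / max(len(token_probs), 1)  # average per line
        if line_score >= threshold:
            kept.append(line)
    return kept
```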
@@ -286,24 +281,48 @@ def _load_gliner2():
 
 
 def baseline_gliner2(model, task: str, tool_output: str) -> list[str]:
-    """Run GLiNER2 span extraction with 'relevant' as the entity label.
+    """Run GLiNER2 span extraction — keep any line containing an extracted entity.
 
-    Uses the task description as the label description to guide extraction.
-    Extracted spans are mapped back to line numbers.
+    Uses a fixed set of NER-style labels to guide entity extraction.
+    Any line that overlaps with an extracted span is kept.
     """
     lines = tool_output.split("\n")
     if not lines:
         return []
 
-    # Use the task as the entity description for guided extraction
-    result = model.extract_entities(
-        tool_output,
-        {"relevant": f"Text relevant to: {task}"},
-        include_spans=True,
-    )
+    # GLiNER2 has a max input length; truncate if needed
+    max_chars = 10000
+    text = tool_output[:max_chars] if len(tool_output) > max_chars else tool_output
+
+    # GLiNER2 works best with NER-style labels, not query descriptions.
+    # Use a fixed set of labels covering common relevant patterns in tool output.
+    labels = [
+        "error message",
+        "failed test",
+        "stack trace",
+        "warning",
+        "relevant code",
+        "file path",
+        "configuration",
+    ]
+
+    try:
+        result = model.extract_entities(
+            text,
+            labels,
+            include_spans=True,
+        )
+    except Exception as e:
+        logger.debug(f"GLiNER2 extract_entities failed: {e}")
+        return []
 
-    entities = result.get("entities", {}).get("relevant", [])
-    if not entities:
+    # result = {'entities': {'label': [{'text': ..., 'start': N, 'end': N}, ...]}}
+    all_entities = []
+    for label_entities in result.get("entities", {}).values():
+        if isinstance(label_entities, list):
+            all_entities.extend(label_entities)
+
+    if not all_entities:
         return []
 
     # Build line offset map
@@ -315,9 +334,12 @@ def baseline_gliner2(model, task: str, tool_output: str) -> list[str]:
 
     # Map character spans to line indices
     kept_indices = set()
-    for entity in entities:
-        span_start = entity.get("start", 0)
-        span_end = entity.get("end", 0)
+    for entity in all_entities:
+        if isinstance(entity, dict):
+            span_start = entity.get("start", 0)
+            span_end = entity.get("end", 0)
+        else:
+            continue
         for i, (lo, hi) in enumerate(line_offsets):
             if span_start < hi and span_end > lo:
                 kept_indices.add(i)
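
The "# Build line offset map" step itself sits outside these hunks. For reference, here is a small self-contained sketch of an offset map compatible with the overlap test visible above (the `(lo, hi)` tuples and the `span_start < hi and span_end > lo` check); the repository's actual helper may differ.

```python
# Sketch of a character-offset map per line, matching the overlap test above.
# Line i covers [lo, hi) in the original string; hi excludes the trailing "\n".
def build_line_offsets(text: str) -> list[tuple[int, int]]:
    offsets = []
    pos = 0
    for line in text.split("\n"):
        offsets.append((pos, pos + len(line)))
        pos += len(line) + 1  # skip past the newline that split() removed
    return offsets

# A span [start, end) overlaps line i exactly when start < hi and end > lo.
line_offsets = build_line_offsets("a\nbb\nccc")
assert line_offsets == [(0, 1), (2, 4), (5, 8)]
```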

squeez/training/evaluate.py

Lines changed: 8 additions & 5 deletions
@@ -490,11 +490,14 @@ def record_result(result: dict) -> None:
         record_result(result)
 
         if (i + 1) % 10 == 0:
-            logger.info(
-                f" [{i + 1}/{len(samples)}] "
-                f"F1={result['span']['f1']:.3f} EM={result['span']['exact_match']:.0f} "
-                f"ROUGE-L={result['rouge']:.3f}"
-            )
+            if "error" not in result:
+                logger.info(
+                    f" [{i + 1}/{len(samples)}] "
+                    f"F1={result['span']['f1']:.3f} EM={result['span']['exact_match']:.0f} "
+                    f"ROUGE-L={result['rouge']:.3f}"
+                )
+            else:
+                logger.info(f" [{i + 1}/{len(samples)}] (last sample errored)")
 
     # Aggregate
     results = {}
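
The logging guard implies the per-sample result shape: successful samples carry `span` (with `f1`, `exact_match`) and `rouge` keys, while failed samples carry an `error` key. The aggregation itself falls outside this hunk; a hedged sketch of what it might compute over such dicts follows (field names come from the f-strings above, everything else is an assumption).

```python
# Illustrative aggregation over the result dicts implied by the logging above.
# Only the key names are taken from the diff; the structure here is assumed.
def aggregate(all_results: list[dict]) -> dict:
    ok = [r for r in all_results if "error" not in r]
    if not ok:
        return {"num_samples": 0, "num_errors": len(all_results)}
    return {
        "num_samples": len(ok),
        "num_errors": len(all_results) - len(ok),
        "span_f1": sum(r["span"]["f1"] for r in ok) / len(ok),
        "exact_match": sum(r["span"]["exact_match"] for r in ok) / len(ok),
        "rouge_l": sum(r["rouge"] for r in ok) / len(ok),
    }
```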
