Commit 29e595e ("Better docs, encoding")

1 parent bc25f7d, 5 files changed: 254 additions & 44 deletions

README.md: 57 additions & 18 deletions
@@ -17,10 +17,11 @@ LLM coding agents waste **80-95% of context tokens** on irrelevant tool output.
 
 Squeez trains small models to identify and extract only the lines that matter for the task at hand — compressing tool output by ~86% on average.
 
-Two approaches are available:
+Three approaches are available:
 
 - **Generative** (Qwen 3.5 2B + LoRA) — high-quality extraction via XML-wrapped verbatim output
-- **Encoder** (mmBERT 307M) — fast line-level binary classification, sliding window over long outputs
+- **Pooled encoder** (ModernBERT / ettin) — single-pass encoder with line-level mean-pool classification, works with any HuggingFace encoder
+- **Token encoder** (mmBERT) — per-token binary classification with sliding window
 
 ## Example
 
@@ -144,7 +145,7 @@ For generative model training (Qwen + LoRA):
 pip install -r requirements-train.txt
 ```
 
-For encoder model training (mmBERT):
+For encoder model training:
 
 ```bash
 pip install -r requirements-encoder.txt
@@ -176,8 +177,8 @@ extractor = ToolOutputExtractor()
 # Or load a generative model locally
 extractor = ToolOutputExtractor(model_path="./output/squeez_qwen")
 
-# Or load an encoder model (auto-detected from config.json)
-extractor = ToolOutputExtractor(model_path="./output/squeez_encoder")
+# Or load an encoder model (pooled or token, auto-detected from config.json)
+extractor = ToolOutputExtractor(model_path="./output/squeez_pooled")
 
 # Or connect to a server explicitly
 extractor = ToolOutputExtractor(base_url="http://localhost:8000/v1", model_name="squeez")
@@ -259,37 +260,75 @@ squeez train \
 
 Default: Qwen 3.5 2B with LoRA (r=16, alpha=32). See `configs/default.yaml` for all hyperparameters.
 
-### 2b. Train encoder model (mmBERT)
+### 2b. Train pooled encoder (recommended)
 
 ```bash
-# Prepare encoder-format data from the downloaded splits
-python scripts/prepare_encoder_data.py --data-dir data
+python -m squeez.encoder.train \
+  --classifier-type pooled \
+  --train-file data/encoder_train.jsonl \
+  --eval-file data/encoder_dev.jsonl \
+  --base-model answerdotai/ModernBERT-base \
+  --output-dir output/squeez_pooled \
+  --batch-size 96 \
+  --gradient-accumulation-steps 2 \
+  --max-length 4096 \
+  --learning-rate 2e-5 \
+  --num-epochs 4
+```
+
+The pooled encoder runs a single forward pass over the full input, mean-pools hidden states per line, and classifies each line as relevant/irrelevant. Works with any HuggingFace encoder model (ModernBERT, ettin, DeBERTa, etc.) and uses sliding windows for outputs longer than `--max-length`.
+
+After training, the model can be loaded standalone without squeez installed:
+
+```python
+from transformers import AutoModel, AutoTokenizer
+
+model = AutoModel.from_pretrained("output/squeez_pooled", trust_remote_code=True)
+tokenizer = AutoTokenizer.from_pretrained("output/squeez_pooled")
 
-# Train the encoder
+result = model.process(
+    task="Find the traceback that shows the import error",
+    tool_output=open("output.log").read(),
+    tokenizer=tokenizer,
+)
+print(result["highlighted_lines"])
+```
+
+### 2c. Train token encoder (alternative)
+
+```bash
 python -m squeez.encoder.train \
+  --classifier-type token \
   --train-file data/encoder_train.jsonl \
   --eval-file data/encoder_dev.jsonl \
-  --base-model jhu-clsp/mmBERT-base \
+  --base-model answerdotai/ModernBERT-base \
   --output-dir output/squeez_encoder
 ```
 
-The encoder is a 307M parameter mmBERT with a token classification head. It classifies each line as relevant/irrelevant and uses sliding windows to handle outputs longer than the 8K context.
-
 ### 3. Evaluate
 
 ```bash
-# Generative model
+# Generative model (local)
 squeez eval \
   --extractor-model output/squeez_qwen \
-  --eval-file data/test.jsonl
+  --eval-file data/test.jsonl \
+  --max-new-tokens 4096
+
+# Generative model (remote vLLM server)
+squeez eval \
+  --server-url http://localhost:8000/v1 \
+  --eval-file data/test.jsonl \
+  --max-new-tokens 4096 \
+  --request-concurrency 8
 
-# Encoder model
+# Encoder model (pooled or token, auto-detected)
 python -m squeez.encoder.evaluate \
-  --model-path output/squeez_encoder \
-  --eval-file data/encoder_test.jsonl
+  --model-path output/squeez_pooled \
+  --eval-file data/encoder_test.jsonl \
+  --examples-output eval_examples_pooled.json
 ```
 
-Both produce the same metrics format (strict and fuzzy line overlap, ROUGE-L, compression ratio) for direct comparison.
+All produce the same metrics format (strict and fuzzy line overlap, ROUGE-L, compression ratio) for direct comparison.
 
 ## Dataset
 
TRAINING.md: 168 additions & 0 deletions (new file)

@@ -0,0 +1,168 @@
+# Training & Evaluation Commands
+
+## 1. Download data
+
+```bash
+python scripts/download_data.py
+```
+
+Downloads from [HuggingFace](https://huggingface.co/datasets/KRLabsOrg/tool-output-extraction-swebench) to `data/`:
+- `train.jsonl`, `dev.jsonl`, `test.jsonl` (generative format)
+- `encoder_train.jsonl`, `encoder_dev.jsonl`, `encoder_test.jsonl` (encoder format)
+- `canonical_train.jsonl`, `canonical_dev.jsonl`, `canonical_test.jsonl` (span-based ground truth)
+
+## 2. Train
+
+### Pooled encoder (recommended)
+
+Single-pass encoder + line-level mean-pool classifier. Works with any HuggingFace encoder.
+
+```bash
+# ModernBERT-base on A100 (~75 min)
+python -m squeez.encoder.train \
+  --classifier-type pooled \
+  --train-file data/encoder_train.jsonl \
+  --eval-file data/encoder_dev.jsonl \
+  --base-model answerdotai/ModernBERT-base \
+  --output-dir output/squeez_pooled \
+  --batch-size 96 \
+  --gradient-accumulation-steps 2 \
+  --max-length 4096 \
+  --learning-rate 2e-5 \
+  --num-epochs 4
+
+# ModernBERT-large (higher capacity, slower)
+python -m squeez.encoder.train \
+  --classifier-type pooled \
+  --train-file data/encoder_train.jsonl \
+  --eval-file data/encoder_dev.jsonl \
+  --base-model answerdotai/ModernBERT-large \
+  --output-dir output/squeez_pooled_large \
+  --batch-size 24 \
+  --gradient-accumulation-steps 4 \
+  --max-length 4096 \
+  --learning-rate 2e-5 \
+  --num-epochs 4
+
+# Other encoder models work too
+# --base-model jhu-clsp/ettin-encoder-32m
+# --base-model microsoft/deberta-v3-large
+# --base-model BAAI/bge-large-en-v1.5
+```
+
+### Token encoder
+
+Per-token binary classification (alternative approach).
+
+```bash
+python -m squeez.encoder.train \
+  --classifier-type token \
+  --train-file data/encoder_train.jsonl \
+  --eval-file data/encoder_dev.jsonl \
+  --base-model answerdotai/ModernBERT-base \
+  --output-dir output/squeez_encoder \
+  --batch-size 2 \
+  --max-length 8192
+```
+
+### Generative model (Qwen + LoRA)
+
+```bash
+squeez train \
+  --train-file data/train.jsonl \
+  --eval-file data/dev.jsonl \
+  --output-dir output/squeez_qwen
+```
+
+To merge LoRA weights and serve:
+
+```bash
+# Merge
+python scripts/merge_lora.py \
+  --checkpoint output/squeez_qwen/checkpoint-500 \
+  --output output/squeez_qwen_merged
+
+# Serve with vLLM
+vllm serve output/squeez_qwen_merged \
+  --max-model-len 32768 \
+  --trust-remote-code
+```
+
+## 3. Evaluate
+
+### Encoder (pooled or token, auto-detected)
+
+```bash
+python -m squeez.encoder.evaluate \
+  --model-path output/squeez_pooled \
+  --eval-file data/encoder_test.jsonl \
+  --examples-output eval_examples_pooled.json
+```
+
+Optional flags:
+- `--threshold 0.5` — relevance probability cutoff (default 0.5)
+- `--max-samples 100` — evaluate on a subset
+
+### Generative (local model)
+
+```bash
+squeez eval \
+  --extractor-model output/squeez_qwen_merged \
+  --eval-file data/test.jsonl \
+  --max-new-tokens 4096 \
+  --examples-output eval_examples.json
+```
+
+### Generative (remote vLLM server)
+
+```bash
+squeez eval \
+  --server-url http://localhost:8000/v1 \
+  --eval-file data/test.jsonl \
+  --max-new-tokens 4096 \
+  --request-concurrency 8 \
+  --examples-output eval_examples.json
+```
+
+## 4. Standalone inference (no squeez install)
+
+After training the pooled encoder, the output directory contains `modeling_squeez_pooled.py` so `AutoModel` works directly:
+
+```python
+from transformers import AutoModel, AutoTokenizer
+
+model = AutoModel.from_pretrained("output/squeez_pooled", trust_remote_code=True)
+tokenizer = AutoTokenizer.from_pretrained("output/squeez_pooled")
+
+result = model.process(
+    task="Find the traceback that shows the import error",
+    tool_output=open("output.log").read(),
+    tokenizer=tokenizer,
+    threshold=0.5,
+    return_line_probabilities=True,
+)
+print(result["highlighted_lines"])
+print(result["highlighted_indices"])
+```
+
+## 5. Upload to HuggingFace
+
+### Dataset
+
+```bash
+python scripts/upload_to_hf.py --data-dir data/v3
+```
+
+### Model
+
+Push the trained model directory (includes `modeling_squeez_pooled.py` for standalone loading):
+
+```python
+from huggingface_hub import HfApi
+api = HfApi()
+api.upload_folder(
+    folder_path="output/squeez_pooled",
+    repo_id="KRLabsOrg/squeez-pooled-modernbert",
+    repo_type="model",
+)
+```
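The evaluate commands in TRAINING.md report strict and fuzzy line-overlap metrics. As intuition for what a strict line-overlap score measures, here is a plausible set-based F1 over predicted vs. gold line indices; this is a reconstruction for illustration only, and the actual metric definitions live in squeez's eval code.

```python
# Illustrative strict line-overlap F1: treat predicted and gold relevant
# line indices as sets and score their agreement.

def line_overlap_f1(predicted, gold):
    pred, ref = set(predicted), set(gold)
    if not pred or not ref:
        return 0.0
    tp = len(pred & ref)  # lines both extracted and labeled relevant
    precision = tp / len(pred)
    recall = tp / len(ref)
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)

score = line_overlap_f1([3, 4, 5, 9], [4, 5, 6])
# two correct lines out of four predicted (P=0.5) and three gold (R=2/3)
```

A "fuzzy" variant would additionally credit near-miss lines (e.g. off-by-one indices or high textual similarity); the strict version demands exact index matches.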

squeez/encoder/modeling_squeez_pooled.py: 1 addition & 4 deletions
@@ -232,10 +232,7 @@ def _pool_lines(
    if max_lines == 0:
        max_lines = 1
 
-    flat_idx = (
-        torch.arange(batch_size, device=device).unsqueeze(1) * max_lines
-        + segment_ids
-    )
+    flat_idx = torch.arange(batch_size, device=device).unsqueeze(1) * max_lines + segment_ids
    flat_idx = flat_idx * valid_token.long()
 
    pooled_flat = torch.zeros(batch_size * max_lines, hidden, device=device)
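The `flat_idx` expression being collapsed onto one line above gives every (batch item, line) pair a unique bucket id, so a single scatter-add over `batch_size * max_lines` buckets can pool all lines at once. A minimal pure-Python sketch of that indexing (illustrative stand-in for the torch code; the function name is made up):

```python
# Bucket id = batch_index * max_lines + per-token line (segment) id,
# so tokens from different batch items can never collide in one bucket.
# Mirrors: flat_idx = arange(batch).unsqueeze(1) * max_lines + segment_ids

def flat_bucket_ids(segment_ids, max_lines):
    """segment_ids: per-batch-item list of per-token line ids."""
    return [
        [b * max_lines + seg for seg in row]
        for b, row in enumerate(segment_ids)
    ]

buckets = flat_bucket_ids([[0, 0, 1], [0, 1, 1]], max_lines=2)
# batch 0 tokens land in buckets 0, 0, 1; batch 1 tokens in 2, 3, 3
```

With these ids, summing hidden states into a flat `[batch * max_lines, hidden]` buffer and reshaping recovers per-(batch, line) sums in one vectorized pass.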

squeez/encoder/sentence.py: 13 additions & 22 deletions
@@ -216,8 +216,7 @@ def _pool_lines(
    # Use scatter_add to sum hidden states per (batch, segment)
    # Flatten to [batch * max_lines] buckets
    flat_idx = (
-        torch.arange(batch_size, device=device).unsqueeze(1) * max_lines
-        + segment_ids
+        torch.arange(batch_size, device=device).unsqueeze(1) * max_lines + segment_ids
    )  # [batch, seq_len]
 
    # Zero out invalid positions
@@ -559,32 +558,24 @@ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
 
 def collate_pooled_lines(batch: list[dict]) -> dict[str, torch.Tensor]:
     """Custom collator: pad input_ids and line_labels separately."""
+    batch_size = len(batch)
     max_seq_len = max(b["input_ids"].shape[0] for b in batch)
     max_lines = max(b["line_labels"].shape[0] for b in batch)
 
-    input_ids = []
-    attention_mask = []
-    line_labels = []
+    # Pre-allocate padded tensors
+    input_ids = torch.zeros(batch_size, max_seq_len, dtype=torch.long)
+    attention_mask = torch.zeros(batch_size, max_seq_len, dtype=torch.long)
+    line_labels = torch.full((batch_size, max_lines), -100, dtype=torch.long)
 
-    for b in batch:
+    for i, b in enumerate(batch):
         seq_len = b["input_ids"].shape[0]
         n_lines = b["line_labels"].shape[0]
-
-        # Pad sequences
-        pad_len = max_seq_len - seq_len
-        input_ids.append(torch.cat([b["input_ids"], torch.zeros(pad_len, dtype=torch.long)]))
-        attention_mask.append(
-            torch.cat([b["attention_mask"], torch.zeros(pad_len, dtype=torch.long)])
-        )
-
-        # Pad line labels with -100
-        label_pad = max_lines - n_lines
-        line_labels.append(
-            torch.cat([b["line_labels"], torch.full((label_pad,), -100, dtype=torch.long)])
-        )
+        input_ids[i, :seq_len] = b["input_ids"]
+        attention_mask[i, :seq_len] = b["attention_mask"]
+        line_labels[i, :n_lines] = b["line_labels"]
 
     return {
-        "input_ids": torch.stack(input_ids),
-        "attention_mask": torch.stack(attention_mask),
-        "line_labels": torch.stack(line_labels),
+        "input_ids": input_ids,
+        "attention_mask": attention_mask,
+        "line_labels": line_labels,
    }
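The rewritten collator in sentence.py pre-allocates padded tensors and copies each example into a slice, instead of concatenating per-example pads and stacking. The padding scheme itself (sequences right-padded with 0, line labels with the -100 ignore index) can be shown in plain Python; this is an illustrative stand-in, not the torch collator:

```python
# Plain-Python sketch of the collator's padding scheme: pad token ids
# with 0 and line labels with -100 (the loss-ignore index) to the
# batch maxima, so every example in the batch has the same shape.

def collate(batch, pad_id=0, ignore_label=-100):
    max_seq = max(len(b["input_ids"]) for b in batch)
    max_lines = max(len(b["line_labels"]) for b in batch)
    return {
        "input_ids": [
            b["input_ids"] + [pad_id] * (max_seq - len(b["input_ids"]))
            for b in batch
        ],
        "line_labels": [
            b["line_labels"] + [ignore_label] * (max_lines - len(b["line_labels"]))
            for b in batch
        ],
    }

out = collate([
    {"input_ids": [5, 6, 7], "line_labels": [1]},
    {"input_ids": [8], "line_labels": [0, 1]},
])
# short sequence is padded with 0s; short label row is padded with -100
```

Pre-allocating and slicing, as the diff does with `torch.zeros`/`torch.full`, avoids building an intermediate list of padded tensors before stacking, but produces the same padded batch.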
