
Commit b884e89

Eval NaN bug
1 parent 2c0b3ed commit b884e89

5 files changed

Lines changed: 141 additions & 30 deletions

File tree

squeez/encoder/chunking.py
squeez/encoder/dataset.py
squeez/encoder/model.py
squeez/encoder/train.py
tests/test_encoder_chunking.py

squeez/encoder/chunking.py

Lines changed: 64 additions & 0 deletions
@@ -0,0 +1,64 @@
+"""Helpers for splitting pathologically long lines into token chunks.
+
+The encoder task stays line-level by default. Only lines whose tokenized
+length exceeds the configured per-line budget are split into chunked
+"pseudo-lines". Training assigns the original line label to every chunk.
+Inference aggregates chunk scores back to the original line index.
+"""
+
+from __future__ import annotations
+
+from transformers import PreTrainedTokenizer
+
+
+def encode_text(
+    tokenizer: PreTrainedTokenizer,
+    text: str,
+    truncation: bool = False,
+    max_length: int | None = None,
+) -> list[int]:
+    """Tokenize a single text span without special tokens or warning spam."""
+    encoded = tokenizer(
+        text,
+        add_special_tokens=False,
+        truncation=truncation,
+        max_length=max_length,
+        return_attention_mask=False,
+        return_token_type_ids=False,
+        verbose=False,
+    )
+    input_ids = encoded["input_ids"]
+    return input_ids if isinstance(input_ids, list) else list(input_ids)
+
+
+def chunk_output_lines(
+    tokenizer: PreTrainedTokenizer,
+    output_lines: list[str],
+    max_tokens_per_chunk: int,
+) -> tuple[list[list[int]], list[int]]:
+    """Tokenize output lines, splitting only oversized lines into chunks.
+
+    Returns:
+        chunk_token_ids: token ids for each pseudo-line/chunk
+        chunk_to_line: mapping from chunk index back to original line index
+    """
+    chunk_token_ids: list[list[int]] = []
+    chunk_to_line: list[int] = []
+
+    for line_idx, line in enumerate(output_lines):
+        token_ids = encode_text(tokenizer, line)
+        if not token_ids:
+            continue
+
+        if len(token_ids) <= max_tokens_per_chunk:
+            chunk_token_ids.append(token_ids)
+            chunk_to_line.append(line_idx)
+            continue
+
+        for start in range(0, len(token_ids), max_tokens_per_chunk):
+            chunk = token_ids[start : start + max_tokens_per_chunk]
+            if chunk:
+                chunk_token_ids.append(chunk)
+                chunk_to_line.append(line_idx)
+
+    return chunk_token_ids, chunk_to_line
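
The chunk_to_line mapping is what keeps the rest of the pipeline line-level. A minimal usage sketch of chunk_output_lines (the bert-base-uncased checkpoint and the sample lines are illustrative assumptions, not part of this commit):

from transformers import AutoTokenizer

from squeez.encoder.chunking import chunk_output_lines

# Assumption: any HF tokenizer works; bert-base-uncased is just an example.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
lines = [
    "def f(x):",                          # short line: stays a single chunk
    "return " + " + ".join(["x"] * 100),  # pathologically long: gets split
]

chunks, chunk_to_line = chunk_output_lines(tokenizer, lines, max_tokens_per_chunk=32)

# Each chunk remembers its source line, so scores can be folded back later.
assert chunk_to_line[0] == 0
assert all(idx == 1 for idx in chunk_to_line[1:])
assert all(len(chunk) <= 32 for chunk in chunks)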

squeez/encoder/dataset.py

Lines changed: 19 additions & 15 deletions
@@ -23,6 +23,7 @@
 from torch.utils.data import Dataset
 from transformers import PreTrainedTokenizer

+from squeez.encoder.chunking import chunk_output_lines, encode_text
 from squeez.encoder.model import LINE_SEP_TOKEN

 logger = logging.getLogger(__name__)
@@ -106,6 +107,7 @@ def __init__(
         # covering one window of lines that fits within max_length.
         self._windows: list[tuple[list[int], list[list[int]], list[bool]]] = []
         n_expanded = 0
+        n_skipped_empty = 0

         for sample in raw_samples:
             task = sample["task"]
@@ -116,23 +118,20 @@ def __init__(
             line_labels = _match_lines(output_lines, relevant_lines)

             # Tokenize task, cap at half of max_length
-            task_ids = tokenizer.encode(
+            task_ids = encode_text(
+                tokenizer,
                 task,
-                add_special_tokens=False,
                 truncation=True,
                 max_length=self._max_task_tokens,
             )

-            # Tokenize each line
-            line_token_ids = [
-                tokenizer.encode(
-                    ln,
-                    add_special_tokens=False,
-                    truncation=True,
-                    max_length=self._max_line_tokens,
-                )
-                for ln in output_lines
-            ]
+            # Tokenize each line, chunking only pathologically long lines.
+            line_token_ids, chunk_to_line = chunk_output_lines(
+                tokenizer,
+                output_lines,
+                max_tokens_per_chunk=self._max_line_tokens,
+            )
+            chunk_labels = [line_labels[line_idx] for line_idx in chunk_to_line]

             # overhead = [CLS] + task + [SEP] + ... + [SEP]
             prefix_len = 1 + len(task_ids) + 1
@@ -142,11 +141,15 @@ def __init__(
             windows = self._build_windows(line_token_ids, budget)

             for start, end in windows:
+                window_line_token_ids = line_token_ids[start:end]
+                if not any(window_line_token_ids):
+                    n_skipped_empty += 1
+                    continue
                 self._windows.append(
                     (
                         task_ids,
-                        line_token_ids[start:end],
-                        line_labels[start:end],
+                        window_line_token_ids,
+                        chunk_labels[start:end],
                     )
                 )
@@ -156,7 +159,8 @@ def __init__(
        logger.info(
            f"Loaded {len(raw_samples)} samples from {data_path} → "
            f"{len(self._windows)} windows "
-           f"({n_expanded} extra from sliding, max_length={max_length})"
+           f"({n_expanded} extra from sliding, {n_skipped_empty} empty windows skipped, "
+           f"max_length={max_length})"
        )

        # ------------------------------------------------------------------
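
To see the new label bookkeeping in isolation, a toy sketch (the labels and the chunk mapping are invented for illustration):

# Toy values, for illustration only.
line_labels = [True, False, True]
chunk_to_line = [0, 1, 1, 1, 2]  # line 1 was split into three chunks
chunk_labels = [line_labels[i] for i in chunk_to_line]
assert chunk_labels == [True, False, False, False, True]

# A window whose pseudo-lines all tokenized to nothing carries no training
# signal; any() over a list of empty token lists is False, so it is skipped.
window_line_token_ids = [[], [], []]
assert not any(window_line_token_ids)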

squeez/encoder/model.py

Lines changed: 15 additions & 14 deletions
@@ -18,6 +18,8 @@
 from transformers import AutoConfig, AutoModel, AutoTokenizer, PretrainedConfig, PreTrainedModel
 from transformers.modeling_outputs import TokenClassifierOutput

+from squeez.encoder.chunking import chunk_output_lines, encode_text
+
 logger = logging.getLogger(__name__)

 LINE_SEP_TOKEN = "[LINE_SEP]"
@@ -163,9 +165,9 @@ def extract(
         sep_id = tokenizer.sep_token_id

         # Tokenize task prefix (will be reused for every window)
-        task_ids = tokenizer.encode(
+        task_ids = encode_text(
+            tokenizer,
             task,
-            add_special_tokens=False,
             truncation=True,
             max_length=max(max_len - 3 - _MIN_LINE_BUDGET, 0),
         )
@@ -175,16 +177,14 @@ def extract(
         suffix_len = 1  # final SEP
         budget = max_len - prefix_len - suffix_len

-        # Tokenize each line
-        line_token_ids: list[list[int]] = []
-        for line in lines:
-            ids = tokenizer.encode(
-                line,
-                add_special_tokens=False,
-                truncation=True,
-                max_length=max(max_len - 4, 1),
-            )
-            line_token_ids.append(ids)
+        # Tokenize lines, chunking only pathologically long lines.
+        line_token_ids, chunk_to_line = chunk_output_lines(
+            tokenizer,
+            lines,
+            max_tokens_per_chunk=max(max_len - 4, 1),
+        )
+        if not line_token_ids:
+            return []

         # Build windows
         windows = self._build_windows(line_token_ids, budget, window_overlap)
@@ -199,8 +199,9 @@ def extract(

            scores = self._predict_window(input_ids, attention_mask, line_sep_positions, sep_id)
            for i, score in enumerate(scores):
-                global_idx = start_idx + i
-                line_scores[global_idx] = max(line_scores[global_idx], score)
+                chunk_idx = start_idx + i
+                line_idx = chunk_to_line[chunk_idx]
+                line_scores[line_idx] = max(line_scores[line_idx], score)

        return [line for line, score in zip(lines, line_scores) if score >= threshold]
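
The inference-side change is the re-indexing through chunk_to_line: a line's score becomes the max over the scores of its chunks, mirroring the dataset-side mapping. A toy sketch of the aggregation (the scores are made up, not real model output):

chunk_to_line = [0, 1, 1, 2]         # line 1 produced two chunks
chunk_scores = [0.2, 0.9, 0.4, 0.1]  # one made-up score per chunk

line_scores = [0.0] * 3
for chunk_idx, score in enumerate(chunk_scores):
    line_idx = chunk_to_line[chunk_idx]
    line_scores[line_idx] = max(line_scores[line_idx], score)

assert line_scores == [0.2, 0.9, 0.1]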

squeez/encoder/train.py

Lines changed: 0 additions & 1 deletion
@@ -98,7 +98,6 @@ def train(
     data_collator = DataCollatorForTokenClassification(
         tokenizer=tokenizer,
         padding=True,
-        max_length=max_length,
     )

     # Training arguments
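
With max_length gone, the collator simply pads each batch to its longest example; labels are padded with -100 so the loss ignores the padding. A minimal sketch of that behaviour (the checkpoint name and token ids are illustrative assumptions):

from transformers import AutoTokenizer, DataCollatorForTokenClassification

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # assumed example checkpoint
collator = DataCollatorForTokenClassification(tokenizer=tokenizer, padding=True)

batch = collator([
    {"input_ids": [101, 7592, 102], "labels": [0, 1, 0]},
    {"input_ids": [101, 7592, 2088, 999, 102], "labels": [0, 1, 1, 0, 0]},
])
# Both tensors are padded to the batch's longest sequence (length 5);
# label positions added by padding hold -100.
assert batch["labels"].shape == batch["input_ids"].shape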

tests/test_encoder_chunking.py

Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
+from squeez.encoder.chunking import chunk_output_lines
+
+
+class FakeTokenizer:
+    def __call__(
+        self,
+        text,
+        add_special_tokens=False,
+        truncation=False,
+        max_length=None,
+        return_attention_mask=False,
+        return_token_type_ids=False,
+        verbose=False,
+    ):
+        # Tokenize on whitespace for predictable chunk sizes in tests.
+        tokens = [len(part) for part in text.split() if part]
+        if truncation and max_length is not None:
+            tokens = tokens[:max_length]
+        return {"input_ids": tokens}
+
+
+def test_chunk_output_lines_splits_only_overlong_lines():
+    tokenizer = FakeTokenizer()
+    lines = [
+        "short line",
+        "a b c d e f g",
+        "tiny",
+    ]
+
+    chunks, chunk_to_line = chunk_output_lines(
+        tokenizer,
+        lines,
+        max_tokens_per_chunk=3,
+    )
+
+    assert chunks == [
+        [5, 4],      # short line
+        [1, 1, 1],   # first chunk of long line
+        [1, 1, 1],   # second chunk of long line
+        [1],         # third chunk of long line
+        [4],         # tiny
+    ]
+    assert chunk_to_line == [0, 1, 1, 1, 2]
