 from torch.utils.data import Dataset
 from transformers import PreTrainedTokenizer
 
+from squeez.encoder.chunking import chunk_output_lines, encode_text
 from squeez.encoder.model import LINE_SEP_TOKEN
 
 logger = logging.getLogger(__name__)
@@ -106,6 +107,7 @@ def __init__(
         # covering one window of lines that fits within max_length.
         self._windows: list[tuple[list[int], list[list[int]], list[bool]]] = []
         n_expanded = 0
+        n_skipped_empty = 0
 
         for sample in raw_samples:
             task = sample["task"]
@@ -116,23 +118,20 @@ def __init__(
             line_labels = _match_lines(output_lines, relevant_lines)
 
             # Tokenize task, cap at half of max_length
-            task_ids = tokenizer.encode(
+            task_ids = encode_text(
+                tokenizer,
                 task,
-                add_special_tokens=False,
                 truncation=True,
                 max_length=self._max_task_tokens,
             )
 
-            # Tokenize each line
-            line_token_ids = [
-                tokenizer.encode(
-                    ln,
-                    add_special_tokens=False,
-                    truncation=True,
-                    max_length=self._max_line_tokens,
-                )
-                for ln in output_lines
-            ]
+            # Tokenize each line, chunking only pathologically long lines.
+            line_token_ids, chunk_to_line = chunk_output_lines(
+                tokenizer,
+                output_lines,
+                max_tokens_per_chunk=self._max_line_tokens,
+            )
+            chunk_labels = [line_labels[line_idx] for line_idx in chunk_to_line]
 
             # overhead = [CLS] + task + [SEP] + ... + [SEP]
             prefix_len = 1 + len(task_ids) + 1
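
Neither `encode_text` nor `chunk_output_lines` appears in this diff (they live in the new `squeez.encoder.chunking` module), so the bodies below are a minimal sketch inferred purely from the call sites above — signatures and behavior are assumptions, not the actual implementation:

```python
from transformers import PreTrainedTokenizer


def encode_text(tokenizer: PreTrainedTokenizer, text: str, **kwargs) -> list[int]:
    # Thin wrapper that always drops special tokens; truncation/max_length
    # are forwarded to the underlying tokenizer (sketch).
    return tokenizer.encode(text, add_special_tokens=False, **kwargs)


def chunk_output_lines(
    tokenizer: PreTrainedTokenizer,
    lines: list[str],
    max_tokens_per_chunk: int,
) -> tuple[list[list[int]], list[int]]:
    # Tokenize each line; a line over the cap is split into several chunks
    # rather than truncated. chunk_to_line maps each chunk back to its
    # source line so per-line labels can be propagated (as chunk_labels
    # does above). An empty line still yields one empty chunk, keeping
    # chunks aligned with lines.
    chunks: list[list[int]] = []
    chunk_to_line: list[int] = []
    for line_idx, line in enumerate(lines):
        ids = encode_text(tokenizer, line)
        for start in range(0, max(len(ids), 1), max_tokens_per_chunk):
            chunks.append(ids[start : start + max_tokens_per_chunk])
            chunk_to_line.append(line_idx)
    return chunks, chunk_to_line
```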
@@ -142,11 +141,15 @@ def __init__(
             windows = self._build_windows(line_token_ids, budget)
 
             for start, end in windows:
+                window_line_token_ids = line_token_ids[start:end]
+                if not any(window_line_token_ids):
+                    n_skipped_empty += 1
+                    continue
                 self._windows.append(
                     (
                         task_ids,
-                        line_token_ids[start:end],
-                        line_labels[start:end],
+                        window_line_token_ids,
+                        chunk_labels[start:end],
                     )
                 )
 
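
One subtlety in the guard above: `any` over a list of token-id lists is falsy only when every chunk in the window is empty, so a window is dropped only when it carries no tokens at all:

```python
# A window is skipped only when every one of its chunks is empty:
assert not any([[], [], []])    # all chunks empty → window skipped
assert any([[], [101, 2023]])   # one non-empty chunk → window kept
```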
@@ -156,7 +159,8 @@ def __init__(
         logger.info(
             f"Loaded {len(raw_samples)} samples from {data_path} → "
             f"{len(self._windows)} windows "
-            f"({n_expanded} extra from sliding, max_length={max_length})"
+            f"({n_expanded} extra from sliding, {n_skipped_empty} empty windows skipped, "
+            f"max_length={max_length})"
         )
 
         # ------------------------------------------------------------------
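
`_build_windows` is also not shown in this diff; judging by the `budget` derived from `prefix_len` and the `LINE_SEP_TOKEN` import, it presumably packs consecutive chunks into `(start, end)` spans whose token cost (one separator per chunk included) fits the budget. A greedy sketch under that assumption — the real method evidently also emits the overlapping windows counted by `n_expanded`, which this sketch omits:

```python
def _build_windows(
    self, line_token_ids: list[list[int]], budget: int
) -> list[tuple[int, int]]:
    # Greedy packing sketch: each chunk costs its tokens plus one
    # LINE_SEP_TOKEN. When the next chunk would overflow the budget,
    # close the current window and start a new one.
    windows: list[tuple[int, int]] = []
    start, used = 0, 0
    for idx, ids in enumerate(line_token_ids):
        cost = len(ids) + 1
        if used and used + cost > budget:
            windows.append((start, idx))
            start, used = idx, 0
        used += cost
    if start < len(line_token_ids):
        windows.append((start, len(line_token_ids)))
    return windows
```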