Skip to content

Commit bc25f7d

Browse files
committed
Faster training
1 parent 8356ec9 commit bc25f7d

2 files changed

Lines changed: 154 additions & 101 deletions

File tree

squeez/encoder/modeling_squeez_pooled.py

Lines changed: 63 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -194,62 +194,78 @@ def _pool_lines(
194194
line_sep_id: int,
195195
sep_token_id: int,
196196
) -> tuple[torch.Tensor, torch.Tensor]:
197-
"""Mean-pool hidden states per line for each sample in the batch."""
197+
"""Mean-pool hidden states per line for each sample in the batch.
198+
199+
Vectorized: uses torch ops to find boundaries and scatter_add for pooling.
200+
"""
198201
batch_size, seq_len, hidden = hidden_states.shape
199202
device = hidden_states.device
200203

201-
all_pooled: list[list[torch.Tensor]] = []
202-
max_lines = 0
204+
is_sep = input_ids == sep_token_id
205+
is_line_sep = input_ids == line_sep_id
203206

204-
for b in range(batch_size):
205-
ids = input_ids[b].tolist()
206-
207-
first_sep = -1
208-
for i, t in enumerate(ids):
209-
if i > 0 and t == sep_token_id:
210-
first_sep = i
211-
break
212-
if first_sep < 0:
213-
all_pooled.append([])
214-
continue
215-
216-
final_sep = seq_len - 1
217-
for i in range(seq_len - 1, first_sep, -1):
218-
if ids[i] == sep_token_id:
219-
final_sep = i
220-
break
221-
222-
sep_positions = []
223-
for i in range(first_sep + 1, final_sep):
224-
if ids[i] == line_sep_id:
225-
sep_positions.append(i)
226-
227-
boundaries = [first_sep + 1] + sep_positions + [final_sep]
228-
line_vectors: list[torch.Tensor] = []
229-
230-
for i in range(len(boundaries) - 1):
231-
start = boundaries[i]
232-
end = boundaries[i + 1]
233-
if i > 0:
234-
start += 1
235-
if start >= end:
236-
line_vectors.append(torch.zeros(hidden, device=device))
237-
continue
238-
line_vectors.append(hidden_states[b, start:end].mean(dim=0))
239-
240-
all_pooled.append(line_vectors)
241-
max_lines = max(max_lines, len(line_vectors))
207+
is_sep_no_cls = is_sep.clone()
208+
is_sep_no_cls[:, 0] = False
209+
210+
has_sep = is_sep_no_cls.any(dim=1)
211+
first_sep = is_sep_no_cls.float().argmax(dim=1)
212+
213+
pos = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)
214+
215+
is_sep_flipped = is_sep.flip(dims=[1])
216+
last_sep_from_end = is_sep_flipped.float().argmax(dim=1)
217+
final_sep = seq_len - 1 - last_sep_from_end
242218

219+
in_lines = (pos > first_sep.unsqueeze(1)) & (pos < final_sep.unsqueeze(1))
220+
in_lines = in_lines & has_sep.unsqueeze(1)
221+
222+
line_sep_in_region = is_line_sep & in_lines
223+
segment_ids = line_sep_in_region.long().cumsum(dim=1)
224+
225+
valid_token = in_lines & ~is_line_sep & ~is_sep
226+
valid_token = valid_token & (pos != first_sep.unsqueeze(1))
227+
228+
n_lines_per_sample = segment_ids.max(dim=1).values + 1
229+
n_lines_per_sample = n_lines_per_sample.clamp(min=0)
230+
n_lines_per_sample[~has_sep] = 0
231+
max_lines = int(n_lines_per_sample.max().item())
243232
if max_lines == 0:
244233
max_lines = 1
245234

246-
pooled = torch.zeros(batch_size, max_lines, hidden, device=device)
247-
line_mask = torch.zeros(batch_size, max_lines, dtype=torch.bool, device=device)
235+
flat_idx = (
236+
torch.arange(batch_size, device=device).unsqueeze(1) * max_lines
237+
+ segment_ids
238+
)
239+
flat_idx = flat_idx * valid_token.long()
248240

249-
for b, vectors in enumerate(all_pooled):
250-
for i, vec in enumerate(vectors):
251-
pooled[b, i] = vec
252-
line_mask[b, i] = True
241+
pooled_flat = torch.zeros(batch_size * max_lines, hidden, device=device)
242+
counts_flat = torch.zeros(batch_size * max_lines, device=device)
243+
244+
flat_idx_expanded = flat_idx.view(-1).unsqueeze(1).expand(-1, hidden)
245+
valid_flat = valid_token.view(-1)
246+
247+
hidden_flat = hidden_states.view(-1, hidden)
248+
249+
valid_hidden = hidden_flat[valid_flat]
250+
valid_idx = flat_idx_expanded[valid_flat]
251+
252+
pooled_flat.scatter_add_(0, valid_idx, valid_hidden)
253+
counts_flat.scatter_add_(
254+
0,
255+
flat_idx.view(-1)[valid_flat],
256+
torch.ones(valid_flat.sum(), device=device),
257+
)
258+
259+
counts_flat = counts_flat.clamp(min=1)
260+
pooled_flat = pooled_flat / counts_flat.unsqueeze(1)
261+
262+
pooled = pooled_flat.view(batch_size, max_lines, hidden)
263+
264+
line_mask = torch.zeros(batch_size, max_lines, dtype=torch.bool, device=device)
265+
for b in range(batch_size):
266+
n = int(n_lines_per_sample[b].item())
267+
if n > 0:
268+
line_mask[b, :n] = True
253269

254270
return pooled, line_mask
255271

squeez/encoder/sentence.py

Lines changed: 91 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -152,74 +152,111 @@ def _pool_lines(
152152
) -> tuple[torch.Tensor, torch.Tensor]:
153153
"""Mean-pool hidden states per line for each sample in the batch.
154154
155+
Vectorized: uses torch ops to find boundaries and scatter_add for pooling.
156+
155157
Returns:
156158
pooled: [batch, max_lines, hidden] — pooled line representations
157159
line_mask: [batch, max_lines] — True for real lines, False for padding
158160
"""
159161
batch_size, seq_len, hidden = hidden_states.shape
160162
device = hidden_states.device
161163

162-
# Find line boundaries per sample
163-
all_pooled: list[list[torch.Tensor]] = []
164-
max_lines = 0
164+
# Masks for separator tokens [batch, seq_len]
165+
is_sep = input_ids == sep_token_id
166+
is_line_sep = input_ids == line_sep_id
167+
168+
# Find first SEP per sample (skip position 0 which is CLS)
169+
# Set position 0 to False to avoid matching CLS
170+
is_sep_no_cls = is_sep.clone()
171+
is_sep_no_cls[:, 0] = False
172+
173+
# first_sep: first SEP after CLS (end of task)
174+
# Use argmax on the mask — returns first True position
175+
has_sep = is_sep_no_cls.any(dim=1)
176+
first_sep = is_sep_no_cls.float().argmax(dim=1) # [batch]
177+
178+
# Build a segment ID for each token: which line does it belong to?
179+
# Tokens before first_sep+1 get segment -1 (task region, excluded)
180+
# LINE_SEP tokens increment the segment counter
181+
# Tokens at/after final SEP get segment -1
182+
183+
# Create position indices [batch, seq_len]
184+
pos = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)
185+
186+
# Mask for tokens in the lines region (after first SEP, before padding/final SEP)
187+
# We need to find the last SEP per sample
188+
# Flip and argmax to find last SEP
189+
is_sep_flipped = is_sep.flip(dims=[1])
190+
last_sep_from_end = is_sep_flipped.float().argmax(dim=1) # [batch]
191+
final_sep = seq_len - 1 - last_sep_from_end # [batch]
192+
193+
# Lines region: first_sep+1 <= pos < final_sep
194+
in_lines = (pos > first_sep.unsqueeze(1)) & (pos < final_sep.unsqueeze(1))
195+
in_lines = in_lines & has_sep.unsqueeze(1)
196+
197+
# Compute segment IDs via cumsum of LINE_SEP tokens in the lines region
198+
# Each LINE_SEP increments the line counter
199+
line_sep_in_region = is_line_sep & in_lines
200+
segment_ids = line_sep_in_region.long().cumsum(dim=1) # [batch, seq_len]
201+
202+
# Exclude tokens outside lines region and LINE_SEP tokens themselves
203+
valid_token = in_lines & ~is_line_sep & ~is_sep
204+
# Also exclude the SEP tokens that bound the region
205+
valid_token = valid_token & (pos != first_sep.unsqueeze(1))
206+
207+
# Number of lines per sample
208+
n_lines_per_sample = segment_ids.max(dim=1).values + 1 # [batch]
209+
n_lines_per_sample = n_lines_per_sample.clamp(min=0)
210+
# For samples with no valid tokens, set to 0
211+
n_lines_per_sample[~has_sep] = 0
212+
max_lines = int(n_lines_per_sample.max().item())
213+
if max_lines == 0:
214+
max_lines = 1
165215

166-
for b in range(batch_size):
167-
ids = input_ids[b].tolist()
216+
# Use scatter_add to sum hidden states per (batch, segment)
217+
# Flatten to [batch * max_lines] buckets
218+
flat_idx = (
219+
torch.arange(batch_size, device=device).unsqueeze(1) * max_lines
220+
+ segment_ids
221+
) # [batch, seq_len]
168222

169-
# Find first SEP (end of task) — lines start after it
170-
first_sep = -1
171-
for i, t in enumerate(ids):
172-
if i > 0 and t == sep_token_id:
173-
first_sep = i
174-
break
175-
if first_sep < 0:
176-
all_pooled.append([])
177-
continue
178-
179-
# Collect line boundaries: segments between LINE_SEP tokens
180-
# Lines region: first_sep+1 ... final_sep-1
181-
# Find final SEP (end of lines) — last non-pad token
182-
final_sep = seq_len - 1
183-
for i in range(seq_len - 1, first_sep, -1):
184-
if ids[i] == sep_token_id:
185-
final_sep = i
186-
break
223+
# Zero out invalid positions
224+
flat_idx = flat_idx * valid_token.long() # invalid -> bucket 0 (will be masked)
187225

188-
# Collect LINE_SEP positions within lines region
189-
sep_positions = []
190-
for i in range(first_sep + 1, final_sep):
191-
if ids[i] == line_sep_id:
192-
sep_positions.append(i)
193-
194-
# Build line segments
195-
boundaries = [first_sep + 1] + sep_positions + [final_sep]
196-
line_vectors: list[torch.Tensor] = []
197-
198-
for i in range(len(boundaries) - 1):
199-
start = boundaries[i]
200-
end = boundaries[i + 1]
201-
# Skip the LINE_SEP token itself
202-
if i > 0:
203-
start += 1
204-
if start >= end:
205-
line_vectors.append(torch.zeros(hidden, device=device))
206-
continue
207-
line_vectors.append(hidden_states[b, start:end].mean(dim=0))
226+
# Sum hidden states into buckets
227+
pooled_flat = torch.zeros(batch_size * max_lines, hidden, device=device)
228+
counts_flat = torch.zeros(batch_size * max_lines, device=device)
208229

209-
all_pooled.append(line_vectors)
210-
max_lines = max(max_lines, len(line_vectors))
230+
# Expand flat_idx for hidden dim
231+
flat_idx_expanded = flat_idx.view(-1).unsqueeze(1).expand(-1, hidden)
232+
valid_flat = valid_token.view(-1)
211233

212-
if max_lines == 0:
213-
max_lines = 1
234+
hidden_flat = hidden_states.view(-1, hidden)
214235

215-
# Pad to [batch, max_lines, hidden]
216-
pooled = torch.zeros(batch_size, max_lines, hidden, device=device)
217-
line_mask = torch.zeros(batch_size, max_lines, dtype=torch.bool, device=device)
236+
# Only scatter valid tokens
237+
valid_hidden = hidden_flat[valid_flat]
238+
valid_idx = flat_idx_expanded[valid_flat]
239+
240+
pooled_flat.scatter_add_(0, valid_idx, valid_hidden)
241+
counts_flat.scatter_add_(
242+
0,
243+
flat_idx.view(-1)[valid_flat],
244+
torch.ones(valid_flat.sum(), device=device),
245+
)
218246

219-
for b, vectors in enumerate(all_pooled):
220-
for i, vec in enumerate(vectors):
221-
pooled[b, i] = vec
222-
line_mask[b, i] = True
247+
# Mean pool: divide by counts
248+
counts_flat = counts_flat.clamp(min=1)
249+
pooled_flat = pooled_flat / counts_flat.unsqueeze(1)
250+
251+
# Reshape to [batch, max_lines, hidden]
252+
pooled = pooled_flat.view(batch_size, max_lines, hidden)
253+
254+
# Line mask: True where we have actual lines
255+
line_mask = torch.zeros(batch_size, max_lines, dtype=torch.bool, device=device)
256+
for b in range(batch_size):
257+
n = int(n_lines_per_sample[b].item())
258+
if n > 0:
259+
line_mask[b, :n] = True
223260

224261
return pooled, line_mask
225262

0 commit comments

Comments (0)