@@ -152,74 +152,111 @@ def _pool_lines(
152152 ) -> tuple [torch .Tensor , torch .Tensor ]:
153153 """Mean-pool hidden states per line for each sample in the batch.
154154
155+ Vectorized: uses torch ops to find boundaries and scatter_add for pooling.
156+
155157 Returns:
156158 pooled: [batch, max_lines, hidden] — pooled line representations
157159 line_mask: [batch, max_lines] — True for real lines, False for padding
158160 """
159161 batch_size , seq_len , hidden = hidden_states .shape
160162 device = hidden_states .device
161163
162- # Find line boundaries per sample
163- all_pooled : list [list [torch .Tensor ]] = []
164- max_lines = 0
164+ # Masks for separator tokens [batch, seq_len]
165+ is_sep = input_ids == sep_token_id
166+ is_line_sep = input_ids == line_sep_id
167+
168+ # Find first SEP per sample (skip position 0 which is CLS)
169+ # Set position 0 to False to avoid matching CLS
170+ is_sep_no_cls = is_sep .clone ()
171+ is_sep_no_cls [:, 0 ] = False
172+
173+ # first_sep: first SEP after CLS (end of task)
174+ # Use argmax on the mask — returns first True position
175+ has_sep = is_sep_no_cls .any (dim = 1 )
176+ first_sep = is_sep_no_cls .float ().argmax (dim = 1 ) # [batch]
177+
178+ # Build a segment ID for each token: which line does it belong to?
179+ # Each LINE_SEP inside the lines region increments the segment counter (cumsum).
180+ # Tokens outside the lines region are not given a special segment ID;
181+ # they are excluded from pooling below via the valid_token mask.
182+
183+ # Create position indices [batch, seq_len]
184+ pos = torch .arange (seq_len , device = device ).unsqueeze (0 ).expand (batch_size , - 1 )
185+
186+ # Mask for tokens in the lines region (after first SEP, before padding/final SEP)
187+ # We need to find the last SEP per sample
188+ # Flip and argmax to find last SEP
189+ is_sep_flipped = is_sep .flip (dims = [1 ])
190+ last_sep_from_end = is_sep_flipped .float ().argmax (dim = 1 ) # [batch]
191+ final_sep = seq_len - 1 - last_sep_from_end # [batch]
192+
193+ # Lines region: first_sep+1 <= pos < final_sep
194+ in_lines = (pos > first_sep .unsqueeze (1 )) & (pos < final_sep .unsqueeze (1 ))
195+ in_lines = in_lines & has_sep .unsqueeze (1 )
196+
197+ # Compute segment IDs via cumsum of LINE_SEP tokens in the lines region
198+ # Each LINE_SEP increments the line counter
199+ line_sep_in_region = is_line_sep & in_lines
200+ segment_ids = line_sep_in_region .long ().cumsum (dim = 1 ) # [batch, seq_len]
201+
202+ # Exclude tokens outside lines region and LINE_SEP tokens themselves
203+ valid_token = in_lines & ~ is_line_sep & ~ is_sep
204+ # Also exclude the SEP tokens that bound the region
205+ valid_token = valid_token & (pos != first_sep .unsqueeze (1 ))
206+
207+ # Number of lines per sample
208+ n_lines_per_sample = segment_ids .max (dim = 1 ).values + 1 # [batch]
209+ n_lines_per_sample = n_lines_per_sample .clamp (min = 0 )
210+ # For samples with no SEP after position 0 (has_sep is False), set to 0
211+ n_lines_per_sample [~ has_sep ] = 0
212+ max_lines = int (n_lines_per_sample .max ().item ())
213+ if max_lines == 0 :
214+ max_lines = 1
165215
166- for b in range (batch_size ):
167- ids = input_ids [b ].tolist ()
216+ # Use scatter_add to sum hidden states per (batch, segment)
217+ # Flatten to [batch * max_lines] buckets
218+ flat_idx = (
219+ torch .arange (batch_size , device = device ).unsqueeze (1 ) * max_lines
220+ + segment_ids
221+ ) # [batch, seq_len]
168222
169- # Find first SEP (end of task) — lines start after it
170- first_sep = - 1
171- for i , t in enumerate (ids ):
172- if i > 0 and t == sep_token_id :
173- first_sep = i
174- break
175- if first_sep < 0 :
176- all_pooled .append ([])
177- continue
178-
179- # Collect line boundaries: segments between LINE_SEP tokens
180- # Lines region: first_sep+1 ... final_sep-1
181- # Find final SEP (end of lines) — last non-pad token
182- final_sep = seq_len - 1
183- for i in range (seq_len - 1 , first_sep , - 1 ):
184- if ids [i ] == sep_token_id :
185- final_sep = i
186- break
223+ # Zero out invalid positions
224+ flat_idx = flat_idx * valid_token .long () # invalid -> bucket 0 (will be masked)
187225
188- # Collect LINE_SEP positions within lines region
189- sep_positions = []
190- for i in range (first_sep + 1 , final_sep ):
191- if ids [i ] == line_sep_id :
192- sep_positions .append (i )
193-
194- # Build line segments
195- boundaries = [first_sep + 1 ] + sep_positions + [final_sep ]
196- line_vectors : list [torch .Tensor ] = []
197-
198- for i in range (len (boundaries ) - 1 ):
199- start = boundaries [i ]
200- end = boundaries [i + 1 ]
201- # Skip the LINE_SEP token itself
202- if i > 0 :
203- start += 1
204- if start >= end :
205- line_vectors .append (torch .zeros (hidden , device = device ))
206- continue
207- line_vectors .append (hidden_states [b , start :end ].mean (dim = 0 ))
226+ # Sum hidden states into buckets
227+ pooled_flat = torch .zeros (batch_size * max_lines , hidden , device = device )
228+ counts_flat = torch .zeros (batch_size * max_lines , device = device )
208229
209- all_pooled .append (line_vectors )
210- max_lines = max (max_lines , len (line_vectors ))
230+ # Expand flat_idx for hidden dim
231+ flat_idx_expanded = flat_idx .view (- 1 ).unsqueeze (1 ).expand (- 1 , hidden )
232+ valid_flat = valid_token .view (- 1 )
211233
212- if max_lines == 0 :
213- max_lines = 1
234+ hidden_flat = hidden_states .view (- 1 , hidden )
214235
215- # Pad to [batch, max_lines, hidden]
216- pooled = torch .zeros (batch_size , max_lines , hidden , device = device )
217- line_mask = torch .zeros (batch_size , max_lines , dtype = torch .bool , device = device )
236+ # Only scatter valid tokens
237+ valid_hidden = hidden_flat [valid_flat ]
238+ valid_idx = flat_idx_expanded [valid_flat ]
239+
240+ pooled_flat .scatter_add_ (0 , valid_idx , valid_hidden )
241+ counts_flat .scatter_add_ (
242+ 0 ,
243+ flat_idx .view (- 1 )[valid_flat ],
244+ torch .ones (valid_flat .sum (), device = device ),
245+ )
218246
219- for b , vectors in enumerate (all_pooled ):
220- for i , vec in enumerate (vectors ):
221- pooled [b , i ] = vec
222- line_mask [b , i ] = True
247+ # Mean pool: divide by counts
248+ counts_flat = counts_flat .clamp (min = 1 )
249+ pooled_flat = pooled_flat / counts_flat .unsqueeze (1 )
250+
251+ # Reshape to [batch, max_lines, hidden]
252+ pooled = pooled_flat .view (batch_size , max_lines , hidden )
253+
254+ # Line mask: True where we have actual lines
255+ line_mask = torch .zeros (batch_size , max_lines , dtype = torch .bool , device = device )
256+ for b in range (batch_size ):
257+ n = int (n_lines_per_sample [b ].item ())
258+ if n > 0 :
259+ line_mask [b , :n ] = True
223260
224261 return pooled , line_mask
225262
0 commit comments