@@ -54,6 +54,7 @@ def _get_audio_features_rmpad(self, input_features: torch.Tensor, audio_attentio
5454 audio_hidden_states = audio_outputs .last_hidden_state
5555
5656 df = self .config .downsample_factor
57+ H = self .config .audio_hidden_size
5758 audio_output_lengths = None
5859 if audio_attention_mask is not None :
5960 audio_output_lengths = audio_attention_mask .sum (- 1 ) // df
@@ -62,27 +63,45 @@ def _get_audio_features_rmpad(self, input_features: torch.Tensor, audio_attentio
6263 if audio_attention_mask is None :
6364 raise ValueError ("Packed audio hidden states require audio_attention_mask." )
6465
65- chunks = []
66- offset = 0
67- for length in audio_attention_mask .sum (- 1 ).tolist ():
68- usable_len = (length // df ) * df
69- if usable_len > 0 :
70- chunk = audio_hidden_states [offset : offset + usable_len ]
71- chunks .append (chunk .reshape (- 1 , self .config .audio_hidden_size * df ))
72- offset += length
73-
74- if chunks :
75- audio_hidden_states = torch .cat (chunks , dim = 0 )
66+ B , max_T = audio_attention_mask .shape
67+ total_valid = audio_hidden_states .shape [0 ]
68+
69+ # Fast path: the packed row count equals B * max_T (every row of
70+ # audio_attention_mask is fully valid) AND max_T is divisible by df. This is
71+ # the dominant case under chunked streaming training (each chunk produces
72+ # exactly df encoder frames), so a single reshape replaces the slow path's gather.
73+ if max_T % df == 0 and total_valid == B * max_T :
74+ audio_hidden_states = audio_hidden_states .reshape (- 1 , H * df )
7675 else :
77- audio_hidden_states = audio_hidden_states .new_empty ((0 , self .config .audio_hidden_size * df ))
76+ # Slow path: ragged segments. Build a gather index with tensor ops on the
77+ # mask's device (one host sync via `.item()`) selecting only the first
78+ # `usable_len = (length // df) * df` rows of each segment, then reshape once.
79+ # Avoids a Python loop + per-chunk slice + cat over potentially thousands of segments.
80+ lengths = audio_attention_mask .sum (- 1 )
81+ usable_lens = (lengths // df ) * df
82+ total_usable = int (usable_lens .sum ().item ())
83+ if total_usable == 0 :
84+ audio_hidden_states = audio_hidden_states .new_empty ((0 , H * df ))
85+ else :
86+ in_starts = torch .cumsum (lengths , dim = 0 ) - lengths
87+ out_starts = torch .cumsum (usable_lens , dim = 0 ) - usable_lens
88+ flat = torch .arange (total_usable , device = lengths .device )
89+ seg = torch .searchsorted (
90+ out_starts [1 :].contiguous () if B > 1 else out_starts .new_zeros (0 ),
91+ flat ,
92+ right = True ,
93+ )
94+ gather_idx = flat + (in_starts - out_starts )[seg ]
95+ audio_hidden_states = audio_hidden_states .index_select (0 , gather_idx )
96+ audio_hidden_states = audio_hidden_states .reshape (- 1 , H * df )
7897 else :
7998 seq_len = audio_hidden_states .shape [1 ]
8099 usable_len = (seq_len // df ) * df
81100 audio_hidden_states = audio_hidden_states [:, :usable_len , :]
82101 audio_hidden_states = audio_hidden_states .reshape (
83102 audio_hidden_states .shape [0 ],
84103 - 1 ,
85- self . config . audio_hidden_size * df ,
104+ H * df ,
86105 )
87106
88107 return self .multi_modal_projector (audio_hidden_states ), audio_output_lengths
0 commit comments