Skip to content

Commit 95bc8cf

Browse files
committed
perf(aero_realtime): drop python-loop in audio rmpad gather
1 parent 4b7ae4e commit 95bc8cf

1 file changed

Lines changed: 32 additions & 13 deletions

File tree

src/lmms_engine/models/aero_realtime/aero_realtime_liger.py

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def _get_audio_features_rmpad(self, input_features: torch.Tensor, audio_attentio
5454
audio_hidden_states = audio_outputs.last_hidden_state
5555

5656
df = self.config.downsample_factor
57+
H = self.config.audio_hidden_size
5758
audio_output_lengths = None
5859
if audio_attention_mask is not None:
5960
audio_output_lengths = audio_attention_mask.sum(-1) // df
@@ -62,27 +63,45 @@ def _get_audio_features_rmpad(self, input_features: torch.Tensor, audio_attentio
6263
if audio_attention_mask is None:
6364
raise ValueError("Packed audio hidden states require audio_attention_mask.")
6465

65-
chunks = []
66-
offset = 0
67-
for length in audio_attention_mask.sum(-1).tolist():
68-
usable_len = (length // df) * df
69-
if usable_len > 0:
70-
chunk = audio_hidden_states[offset : offset + usable_len]
71-
chunks.append(chunk.reshape(-1, self.config.audio_hidden_size * df))
72-
offset += length
73-
74-
if chunks:
75-
audio_hidden_states = torch.cat(chunks, dim=0)
66+
B, max_T = audio_attention_mask.shape
67+
total_valid = audio_hidden_states.shape[0]
68+
69+
# Fast path: every row of audio_attention_mask is fully valid AND
70+
# max_T is divisible by df. This is the dominant case under chunked
71+
# streaming training (each chunk produces exactly df encoder frames),
72+
# and lets us skip the per-segment python loop + cat in the slow path.
73+
if max_T % df == 0 and total_valid == B * max_T:
74+
audio_hidden_states = audio_hidden_states.reshape(-1, H * df)
7675
else:
77-
audio_hidden_states = audio_hidden_states.new_empty((0, self.config.audio_hidden_size * df))
76+
# Slow path: ragged segments. Build a fully-GPU gather index that
77+
# selects only the `usable_len = (length // df) * df` rows from
78+
# each segment, then reshape in one shot. Avoids the python loop
79+
# + per-chunk slice + cat over potentially thousands of segments.
80+
lengths = audio_attention_mask.sum(-1)
81+
usable_lens = (lengths // df) * df
82+
total_usable = int(usable_lens.sum().item())
83+
if total_usable == 0:
84+
audio_hidden_states = audio_hidden_states.new_empty((0, H * df))
85+
else:
86+
in_starts = torch.cumsum(lengths, dim=0) - lengths
87+
out_starts = torch.cumsum(usable_lens, dim=0) - usable_lens
88+
flat = torch.arange(total_usable, device=lengths.device)
89+
seg = torch.searchsorted(
90+
out_starts[1:].contiguous() if B > 1 else out_starts.new_zeros(0),
91+
flat,
92+
right=True,
93+
)
94+
gather_idx = flat + (in_starts - out_starts)[seg]
95+
audio_hidden_states = audio_hidden_states.index_select(0, gather_idx)
96+
audio_hidden_states = audio_hidden_states.reshape(-1, H * df)
7897
else:
7998
seq_len = audio_hidden_states.shape[1]
8099
usable_len = (seq_len // df) * df
81100
audio_hidden_states = audio_hidden_states[:, :usable_len, :]
82101
audio_hidden_states = audio_hidden_states.reshape(
83102
audio_hidden_states.shape[0],
84103
-1,
85-
self.config.audio_hidden_size * df,
104+
H * df,
86105
)
87106

88107
return self.multi_modal_projector(audio_hidden_states), audio_output_lengths

0 commit comments

Comments
 (0)