
Commit 6c0ce0b

refactor(aero_realtime): use audio realtime stream
1 parent 761d6b7

5 files changed: 93 additions & 81 deletions


examples/aero_realtime/example_config.yaml

Lines changed: 3 additions & 3 deletions
@@ -1,9 +1,9 @@
 # AeroRealtime Training Configuration
 # Trains the AeroRealtime model on LLaVA-Video-178K data (normal video QA mode).
 #
-# The dual-stream additive design is active: during video regions, the model
-# receives additive vision+text embeddings and learns to stay silent (rt_pad)
-# until spoken to (rt_speak boundary at delay_seconds).
+# The realtime text stream is conditioned on audio positions only. Video
+# placeholders receive pure vision features, while audio placeholders carry
+# rt_pad / rt_speak / realtime text tokens along the audio timeline.
 #
 # Audio is auto-extracted from video files by the dataset.
 #

src/lmms_engine/datasets/processor/aero_realtime_processor.py

Lines changed: 14 additions & 17 deletions
@@ -21,11 +21,11 @@
     video playback).
 
 The processor builds ``text_stream_ids`` with the delay mechanism:
-- ``<|rt_start|>`` at position 0 of the video region
-- ``<|rt_pad|>`` for silence positions before the delay boundary
+- ``<|rt_start|>`` at the first audio position
+- ``<|rt_pad|>`` for audio silence positions before the delay boundary
 - ``<|rt_speak|>`` at the delay boundary
 - After ``<|rt_speak|>``: ``<|rt_pad|>`` for normal QA, or actual text tokens
-  at the appropriate temporal positions for realtime data
+  at the appropriate audio positions for realtime data
 """
 
 from typing import Dict, List, Optional
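
For intuition, the delay logic reduces to marker placement on a one-dimensional audio timeline. A minimal sketch, assuming a fixed seconds_per_token hop per audio placeholder (the helper name and its defaults are illustrative, not repository code):

# Minimal sketch (not repo code) of the delay mechanism described above,
# assuming each audio placeholder covers a fixed hop of `seconds_per_token`.
def toy_text_stream(num_audio_tokens: int,
                    delay_seconds: float,
                    seconds_per_token: float = 0.5) -> list:
    stream = ["<|rt_pad|>"] * num_audio_tokens  # silence by default
    stream[0] = "<|rt_start|>"                  # first audio position
    # First audio position whose start time reaches the delay boundary.
    speak_idx = min(int(delay_seconds / seconds_per_token), num_audio_tokens - 1)
    if speak_idx > 0:
        stream[speak_idx] = "<|rt_speak|>"      # delay boundary
    return stream

# toy_text_stream(8, delay_seconds=2.0) ->
# ['<|rt_start|>', '<|rt_pad|>', '<|rt_pad|>', '<|rt_pad|>',
#  '<|rt_speak|>', '<|rt_pad|>', '<|rt_pad|>', '<|rt_pad|>']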
@@ -50,9 +50,9 @@ class AeroRealtimeDataProcessor(Qwen3_VLDataProcessor):
     """Data processor for AeroRealtime training.
 
     Builds ``input_ids``, ``text_stream_ids``, and ``labels`` for the
-    dual-stream additive training design. Handles:
-    - Normal video QA: video region filled with ``<|rt_pad|>`` after delay
-    - Realtime training: text tokens placed at temporal positions in video region
+    dual-stream training design. Handles:
+    - Normal video QA: audio timeline filled with ``<|rt_pad|>`` after delay
+    - Realtime training: text tokens placed at temporal positions on the audio timeline
     - Image-only: standard scatter (no text_stream_ids)
     - Audio extraction from video for audio-vision fusion
     """
@@ -334,14 +334,13 @@ def _build_normal_qa_ids_and_labels(
         """Build input_ids, text_stream_ids, and labels from HF messages.
 
         For normal video QA the text_stream_ids only differ from input_ids
-        in the multimodal pad regions:
-        - all ``<|video_pad|>`` and ``<|audio_pad|>`` slots → ``<|rt_pad|>``
-        - first chunk's first ``<|video_pad|>`` → ``<|rt_start|>``
-        - speak chunk's first ``<|audio_pad|>`` → ``<|rt_speak|>``
-
-        Envelope boundary tokens (timestamps, vision_start/end,
-        audio_start/end) keep their original ids in text_stream_ids so the
-        LM sees the same special tokens it would in input_ids.
+        on audio pad positions:
+        - all ``<|audio_pad|>`` slots -> ``<|rt_pad|>``
+        - first ``<|audio_pad|>`` -> ``<|rt_start|>``
+        - delayed ``<|audio_pad|>`` -> ``<|rt_speak|>``
+
+        Video placeholders and envelope boundary tokens keep their original
+        ids; vision features replace video placeholder embeddings in the model.
         """
         results = self.get_qwen_template_labels(
             hf_messages,

@@ -364,7 +363,7 @@
         text_stream_id = list(input_id)  # start as a copy of input_ids
 
         if has_video and has_audio:
-            # video + audio: per-chunk envelope filler
+            # video + audio: only audio pads carry realtime stream tokens
             self.processor._fill_text_stream_video_audio(
                 stream=text_stream_id,
                 video_grid_thw=video_grid_thw,

@@ -374,7 +373,6 @@
                 vision_end_id=self.tokenizer.convert_tokens_to_ids(self.processor.vision_end_token),
                 audio_start_id=self.tokenizer.convert_tokens_to_ids(self.processor.audio_start_token),
                 audio_end_id=self.tokenizer.convert_tokens_to_ids(self.processor.audio_end_token),
-                video_pad_id=self.video_token_id,
                 audio_pad_id=self.audio_token_id,
                 rt_start_id=self.rt_start_id,
                 rt_pad_id=self.rt_pad_id,

@@ -460,7 +458,6 @@ def _build_realtime_ids_and_labels(
                 vision_end_id=self.tokenizer.convert_tokens_to_ids(self.processor.vision_end_token),
                 audio_start_id=self.tokenizer.convert_tokens_to_ids(self.processor.audio_start_token),
                 audio_end_id=self.tokenizer.convert_tokens_to_ids(self.processor.audio_end_token),
-                video_pad_id=self.video_token_id,
                 audio_pad_id=self.audio_token_id,
                 rt_start_id=self.rt_start_id,
                 rt_pad_id=self.rt_pad_id,

src/lmms_engine/models/aero_realtime/aero_realtime_liger.py

Lines changed: 9 additions & 10 deletions
@@ -70,9 +70,9 @@ def aero_realtime_lce_forward(
 ):
     """RMPad-aware forward for AeroRealtime with LigerCE loss.
 
-    Same pipeline as the original forward (embed → vision → audio → add at
-    video/audio token positions independently — time alignment comes from
-    the per-chunk envelope token order, not from feature-level fusion).
+    Same pipeline as the original forward (embed → scatter vision → add audio
+    on audio token positions — realtime conditioning lives on the audio
+    timeline).
     Adds:
     - Proper mrope position_ids via ``qwen3_vl_get_rope_index``
     - Unpadding of inputs_embeds/position_ids/labels before the language model

@@ -126,9 +126,9 @@
     else:
         audio_features_flat = audio_features.reshape(-1, audio_features.shape[-1])
 
-    # ---- 5. Add vision/audio features (independent paths) ----
+    # ---- 5. Scatter video features and add audio features ----
 
-    # 5a. Add video features at video_token_index positions
+    # 5a. Scatter video features at video_token_index positions
     if video_features is not None:
         video_mask = original_input_ids == self.config.video_token_index
         n_video_tokens = video_mask.sum().item()

@@ -137,12 +137,11 @@
             raise ValueError(
                 f"Video token count ({n_video_tokens}) does not match " f"video feature count ({n_video_features})."
             )
-        video_mask_flat = video_mask.reshape(-1)
-        inputs_embeds_flat = inputs_embeds.reshape(-1, inputs_embeds.shape[-1])
-        inputs_embeds_flat[video_mask_flat] = inputs_embeds_flat[video_mask_flat] + video_features.to(
-            inputs_embeds.dtype
+        video_mask_expanded = video_mask.unsqueeze(-1).expand_as(inputs_embeds)
+        inputs_embeds = inputs_embeds.masked_scatter(
+            video_mask_expanded,
+            video_features.to(inputs_embeds.dtype),
         )
-        inputs_embeds = inputs_embeds_flat.reshape(inputs_embeds.shape)
 
     # 5b. Add audio features at audio_token_index positions
     if audio_features_flat is not None:
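
The behavioral change in this hunk is easy to miss: the removed code *added* video features to the placeholder embeddings, while masked_scatter *overwrites* them. A standalone sketch with toy shapes and a hypothetical placeholder id (not repo code):

import torch

B, T, H = 1, 6, 4                   # toy batch, sequence, hidden sizes
VIDEO_TOKEN_INDEX = 99              # hypothetical placeholder id

input_ids = torch.tensor([[1, 99, 99, 99, 2, 3]])
inputs_embeds = torch.randn(B, T, H)
video_features = torch.randn(3, H)  # one row per video placeholder

video_mask = input_ids == VIDEO_TOKEN_INDEX

# Old behavior: placeholder embedding + video feature (additive fusion).
added = inputs_embeds.clone()
added[video_mask] += video_features

# New behavior: the video feature replaces the placeholder embedding;
# masked_scatter fills True positions from `video_features` in row order.
mask_expanded = video_mask.unsqueeze(-1).expand_as(inputs_embeds)
scattered = inputs_embeds.masked_scatter(mask_expanded, video_features)

assert torch.equal(scattered[video_mask], video_features)

This is also why text_stream_ids can leave video placeholders untouched: whatever id sits at a video slot, its embedding is discarded by the scatter.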

src/lmms_engine/models/aero_realtime/modeling_aero_realtime.py

Lines changed: 21 additions & 24 deletions
@@ -518,8 +518,8 @@ def forward(
         Audio and video are kept as **separate** token streams in the input
         sequence (per-chunk envelope ``[VS][AS][video_pad×S][audio_pad×N]
         [AE][VE]``) so time alignment is expressed entirely through token
-        order and RoPE. Each modality's features are simply added at the
-        positions of their corresponding placeholder tokens.
+        order and RoPE. Vision features replace vision placeholders; audio
+        features are added to the realtime text stream on audio placeholders.
 
         Modality combinations:

@@ -528,22 +528,22 @@
             positions. ``text_stream_ids`` is not used.
 
         **Video mode** (``pixel_values_videos`` + ``video_grid_thw``):
-            Video features are **added** to embeddings at
+            Video features are scattered (replace) at
             ``video_token_index`` positions.
 
         **Audio mode** (``input_features``):
             Audio features are **added** to embeddings at
             ``audio_token_index`` positions.
 
-        **Video + Audio**: both add paths run independently on their own
-            token positions. ``text_stream_ids`` carries the realtime markers
-            (``<|rt_start|>``, ``<|rt_pad|>``, ``<|rt_speak|>``) at the
-            envelope positions and is used for the input embedding lookup.
+        **Video + Audio**: video placeholders receive pure vision features.
+            ``text_stream_ids`` carries realtime markers (``<|rt_start|>``,
+            ``<|rt_pad|>``, ``<|rt_speak|>``) only at audio positions, where audio
+            features are added to the realtime text embeddings.
 
         Pipeline:
         1. Embed ``text_stream_ids`` (if provided) or ``input_ids``.
         2. Image features → scatter at ``image_token_index``.
-        3. Video features → add at ``video_token_index``.
+        3. Video features → scatter at ``video_token_index``.
         4. Audio features → add at ``audio_token_index``.
         5. Forward through the language model.
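
To make the five-step pipeline concrete, here is a compact fusion sketch under the same conventions (toy stand-ins, image step omitted; a sketch of the described behavior, not the repository implementation):

import torch
import torch.nn as nn

def fuse_streams(embed: nn.Embedding,
                 input_ids: torch.Tensor,        # [B, T] placeholder layout
                 text_stream_ids: torch.Tensor,  # [B, T] rt markers at audio slots
                 video_features: torch.Tensor,   # [n_video_pads, H]
                 audio_features: torch.Tensor,   # [n_audio_pads, H]
                 video_token_index: int,
                 audio_token_index: int) -> torch.Tensor:
    # Step 1: embed the realtime text stream, not input_ids.
    inputs_embeds = embed(text_stream_ids)

    # Step 3: video features *replace* video placeholder embeddings.
    video_mask = input_ids == video_token_index
    expanded = video_mask.unsqueeze(-1).expand_as(inputs_embeds)
    inputs_embeds = inputs_embeds.masked_scatter(
        expanded, video_features.to(inputs_embeds.dtype)
    )

    # Step 4: audio features are *added* on top of the rt_* embeddings,
    # so the realtime text conditioning survives at audio positions.
    audio_mask = input_ids == audio_token_index
    inputs_embeds[audio_mask] = (
        inputs_embeds[audio_mask] + audio_features.to(inputs_embeds.dtype)
    )
    return inputs_embeds

The asymmetry is the point of the commit: replacement at video slots keeps vision features pure, while addition at audio slots superimposes audio evidence on the rt_pad/rt_speak token embeddings.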
@@ -552,9 +552,9 @@ def forward(
                 Shape ``[batch_size, seq_len]``. Used to determine the
                 position masks for image/video/audio features.
             text_stream_ids: Parallel text-stream token ids.
-                Shape ``[batch_size, seq_len]``. At video/audio positions
-                contains ``<|rt_start|>``, ``<|rt_pad|>``, ``<|rt_speak|>``,
-                or actual text tokens; mirrors ``input_ids`` elsewhere.
+                Shape ``[batch_size, seq_len]``. At audio positions contains
+                ``<|rt_start|>``, ``<|rt_pad|>``, ``<|rt_speak|>``, or actual
+                text tokens; mirrors ``input_ids`` elsewhere.
                 If not provided, falls back to ``input_ids``.
             pixel_values: Image pixel values (flat across batch).
             image_grid_thw: Grid info per image. ``[num_images, 3]``.

@@ -569,9 +569,8 @@ def forward(
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         # Determine which token ids to use for embedding
-        # text_stream_ids provides the actual text tokens (including rt_start,
-        # rt_pad, rt_speak) at video/audio positions. input_ids is only used
-        # for determining the position mask.
+        # text_stream_ids provides the realtime text tokens at audio positions.
+        # input_ids is used for determining modality placeholder masks.
         embed_ids = text_stream_ids if text_stream_ids is not None else input_ids
 
         # ----------------------------------------------------------------

@@ -602,7 +601,7 @@ def forward(
         )
 
         # ----------------------------------------------------------------
-        # 3. Video features — extract (additive fusion happens below)
+        # 3. Video features — extract (scatter happens below)
         # ----------------------------------------------------------------
         video_features = None
         if pixel_values_videos is not None:

@@ -625,12 +624,11 @@ def forward(
             audio_features_flat = audio_features.reshape(-1, audio_features.shape[-1])
 
         # ----------------------------------------------------------------
-        # 5. Add video / audio features to text-stream embeddings
-        #    (independent paths — time alignment is handled by token order
-        #    within each per-chunk envelope, not by feature-level fusion)
+        # 5. Scatter video features and add audio features to text-stream embeddings.
+        #    Realtime text conditioning lives only on the audio timeline.
         # ----------------------------------------------------------------
 
-        # 5a. Video features -> add at video_token_index positions
+        # 5a. Video features -> scatter at video_token_index positions
         if video_features is not None:
             video_mask = input_ids == self.config.video_token_index
             n_video_tokens = video_mask.sum().item()

@@ -640,12 +638,11 @@ def forward(
                 raise ValueError(
                     f"Video token count ({n_video_tokens}) does not match " f"video feature count ({n_video_features})."
                 )
-            video_mask_flat = video_mask.reshape(-1)
-            inputs_embeds_flat = inputs_embeds.reshape(-1, inputs_embeds.shape[-1])
-            inputs_embeds_flat[video_mask_flat] = inputs_embeds_flat[video_mask_flat] + video_features.to(
-                inputs_embeds.dtype
+            video_mask_expanded = video_mask.unsqueeze(-1).expand_as(inputs_embeds)
+            inputs_embeds = inputs_embeds.masked_scatter(
+                video_mask_expanded,
+                video_features.to(inputs_embeds.dtype),
             )
-            inputs_embeds = inputs_embeds_flat.reshape(inputs_embeds.shape)
 
         # 5b. Audio features -> add at audio_token_index positions
         if audio_features_flat is not None:

src/lmms_engine/models/aero_realtime/processing_aero_realtime.py

Lines changed: 46 additions & 27 deletions
@@ -74,8 +74,8 @@ class AeroRealtimeProcessor(ProcessorMixin):
     - Text tokenization with placeholder expansion for images, videos, and
       audio tokens.
     - Construction of ``text_stream_ids`` carrying the realtime markers
-      (``<|rt_start|>``, ``<|rt_pad|>``, ``<|rt_speak|>``) when audio
-      is present (streaming mode).
+      (``<|rt_start|>``, ``<|rt_pad|>``, ``<|rt_speak|>``) on audio
+      positions when audio is present (streaming mode).
 
     Args:
         image_processor: Image processor instance (e.g. ``Qwen2VLImageProcessor``).

@@ -630,18 +630,17 @@ def _build_text_stream_ids(
     ) -> Union[list, torch.Tensor]:
         """Build ``text_stream_ids`` for the realtime dual-stream design.
 
-        ``text_stream_ids`` mirrors ``input_ids`` everywhere except inside
-        the multimodal regions, where it carries the realtime-text-stream
+        ``text_stream_ids`` mirrors ``input_ids`` everywhere except audio
+        placeholder positions, where it carries the realtime-text-stream
         markers (``<|rt_start|>``, ``<|rt_pad|>``, ``<|rt_speak|>``).
 
         Streaming mode is gated on the presence of audio. Two layouts:
 
         - **video + audio (interleave)**: input contains per-chunk envelopes
-          ``[VS][AS][video_pad×S][audio_pad×N][AE][VE]``. All envelope and
-          pad positions become ``<|rt_pad|>`` (model stays silent over
-          vision); the very first ``video_pad`` of the first chunk becomes
-          ``<|rt_start|>``; the first audio_pad of the first chunk whose
-          start time ``>= delay_seconds`` becomes ``<|rt_speak|>``.
+          ``[VS][AS][video_pad×S][audio_pad×N][AE][VE]``. Video placeholders
+          stay as ``<|video_pad|>``; only audio placeholders carry
+          ``<|rt_pad|>``, with the first audio placeholder as ``<|rt_start|>``
+          and the first delayed audio placeholder as ``<|rt_speak|>``.
         - **audio-only**: ``[AS][audio_pad×N][AE]``. First ``audio_pad``
           becomes ``<|rt_start|>``; the first audio_pad whose timestamp
          ``>= delay_seconds`` becomes ``<|rt_speak|>``.

@@ -656,7 +655,6 @@ def _build_text_stream_ids(
         rt_pad_id = self.tokenizer.convert_tokens_to_ids(self.rt_pad_token)
         rt_speak_id = self.tokenizer.convert_tokens_to_ids(self.rt_speak_token)
 
-        video_pad_id = self.tokenizer.convert_tokens_to_ids(self.video_token)
         audio_pad_id = self.tokenizer.convert_tokens_to_ids(self.audio_token)
         vision_start_id = self.tokenizer.convert_tokens_to_ids(self.vision_start_token)
         vision_end_id = self.tokenizer.convert_tokens_to_ids(self.vision_end_token)

@@ -681,7 +679,6 @@ def _build_text_stream_ids(
             vision_end_id=vision_end_id,
             audio_start_id=audio_start_id,
             audio_end_id=audio_end_id,
-            video_pad_id=video_pad_id,
             audio_pad_id=audio_pad_id,
             rt_start_id=rt_start_id,
             rt_pad_id=rt_pad_id,

@@ -721,19 +718,20 @@ def _fill_text_stream_video_audio(
         vision_end_id: int,
         audio_start_id: int,
         audio_end_id: int,
-        video_pad_id: int,
         audio_pad_id: int,
         rt_start_id: int,
         rt_pad_id: int,
         rt_speak_id: int,
     ) -> None:
         """In-place fill of text_stream for the interleaved video+audio mode.
 
-        Only ``<|video_pad|>`` and ``<|audio_pad|>`` positions (which receive
-        added vision / audio features in the model) are overwritten:
-        - all video_pad / audio_pad slots → ``<|rt_pad|>``
-        - first chunk's first video_pad → ``<|rt_start|>``
-        - speak chunk's first audio_pad → ``<|rt_speak|>``
+        Only ``<|audio_pad|>`` positions are overwritten:
+        - all audio_pad slots -> ``<|rt_pad|>``
+        - first audio_pad -> ``<|rt_start|>``
+        - first delayed audio_pad -> ``<|rt_speak|>``
+
+        ``<|video_pad|>`` positions keep their original ids because video
+        features replace those embeddings in the model.
 
         Envelope boundary tokens (``<t.t seconds>``, ``<|vision_start|>``,
         ``<|audio_start|>``, ``<|audio_end|>``, ``<|vision_end|>``) keep
@@ -788,22 +786,43 @@ def _fill_text_stream_video_audio(
             #   as_+spatial+1 .. ae-1: <|audio_pad|> × N_t
             #   ae: <|audio_end|>
             #   ve: <|vision_end|>
-            video_pad_start = as_ + 1
-            video_pad_end = as_ + spatial  # inclusive
             audio_pad_start = as_ + spatial + 1
             audio_pad_end = ae - 1  # inclusive
 
-            for k in range(video_pad_start, video_pad_end + 1):
-                stream[k] = rt_pad_id
             for k in range(audio_pad_start, audio_pad_end + 1):
                 stream[k] = rt_pad_id
 
-            # rt_start: first chunk's first video_pad
-            if c_idx == 0 and video_pad_start <= video_pad_end:
-                stream[video_pad_start] = rt_start_id
-            # rt_speak: speak chunk's first audio_pad
-            if c_idx == speak_chunk and audio_pad_start <= audio_pad_end:
-                stream[audio_pad_start] = rt_speak_id
+        audio_ranges = []
+        for c_idx, ((_, as_, ae, _), (_, _, _, spatial)) in enumerate(zip(envelopes, chunks)):
+            audio_pad_start = as_ + spatial + 1
+            audio_pad_end = ae - 1
+            if audio_pad_start <= audio_pad_end:
+                audio_ranges.append((c_idx, audio_pad_start, audio_pad_end))
+
+        if not audio_ranges:
+            return
+
+        first_audio_pos = audio_ranges[0][1]
+        speak_pos = None
+        for c_idx, audio_pad_start, _ in audio_ranges:
+            if c_idx >= speak_chunk:
+                speak_pos = audio_pad_start
+                break
+        if speak_pos is None:
+            speak_pos = audio_ranges[-1][1]
+
+        if speak_pos == first_audio_pos:
+            for _, audio_pad_start, audio_pad_end in audio_ranges:
+                if audio_pad_start <= first_audio_pos < audio_pad_end:
+                    speak_pos = first_audio_pos + 1
+                    break
+                if audio_pad_start > first_audio_pos:
+                    speak_pos = audio_pad_start
+                    break
+
+        stream[first_audio_pos] = rt_start_id
+        if speak_pos != first_audio_pos:
+            stream[speak_pos] = rt_speak_id
 
     def _fill_text_stream_audio_only(
         self,
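
A toy walkthrough of the marker placement above (standalone, not repo code): two envelopes shaped [VS][AS][video_pad×2][audio_pad×3][AE][VE], with speak_chunk = 1.

# Indices within one chunk: 0:VS 1:AS 2-3:video_pad 4-6:audio_pad 7:AE 8:VE
chunk = ["VS", "AS", "v", "v", "a", "a", "a", "AE", "VE"]
stream = chunk + chunk
spatial, speak_chunk = 2, 1            # 2 video pads per chunk; speak in chunk 1

audio_ranges = []                      # (c_idx, audio_pad_start, audio_pad_end)
for c_idx in range(2):
    base = c_idx * len(chunk)
    as_, ae = base + 1, base + 7       # <|audio_start|> / <|audio_end|> positions
    start, end = as_ + spatial + 1, ae - 1
    audio_ranges.append((c_idx, start, end))
    for k in range(start, end + 1):    # all audio pads -> rt_pad
        stream[k] = "<|rt_pad|>"

first_audio_pos = audio_ranges[0][1]   # position 4
speak_pos = next((s for c, s, _ in audio_ranges if c >= speak_chunk),
                 audio_ranges[-1][1])  # position 13: chunk 1's first audio pad

stream[first_audio_pos] = "<|rt_start|>"
if speak_pos != first_audio_pos:
    stream[speak_pos] = "<|rt_speak|>"

# Video pads ("v") and envelope tokens are untouched, matching the docstring.

The extra speak_pos == first_audio_pos branch in the real code handles the degenerate case where the delay boundary lands on the very first audio pad, nudging rt_speak one position forward so rt_start is never overwritten.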
