
Commit e974cbe

refactor(aero_realtime): add speech span labels

Parent: 6c0ce0b
5 files changed
Lines changed: 99 additions & 148 deletions

src/lmms_engine/datasets/processor/aero_realtime_processor.py
Lines changed: 46 additions & 48 deletions
@@ -20,12 +20,9 @@
 (assistant text segments are placed at specific temporal positions during
 video playback).
 
-The processor builds ``text_stream_ids`` with the delay mechanism:
-- ``<|rt_start|>`` at the first audio position
-- ``<|rt_pad|>`` for audio silence positions before the delay boundary
-- ``<|rt_speak|>`` at the delay boundary
-- After ``<|rt_speak|>``: ``<|rt_pad|>`` for normal QA, or actual text tokens
-  at the appropriate audio positions for realtime data
+The processor builds ``text_stream_ids`` on the audio timeline. ``<|rt_pad|>``
+is silence context only; labels supervise ``<|rt_speak|>``, speech span
+boundaries (``<|rt_start|>`` / ``<|rt_end|>``), and speech text tokens.
 """
 
 from typing import Dict, List, Optional
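
To make the new labeling scheme concrete, here is a minimal standalone sketch (not part of the diff; all token ids are toy stand-ins for the real special-token ids) of how ``text_stream_ids`` and ``labels`` sit on a short audio timeline, assuming the boundary and span tokens are supervised as the docstring states:

    # Toy stand-ins for the real special-token ids (hypothetical values).
    RT_PAD, RT_SPEAK, RT_START, RT_END = 1, 2, 3, 4
    IGNORE_INDEX = -100  # conventional "no loss" label value

    # Audio frames every 0.5 s; delay_seconds = 2.0.
    audio_times = [0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5]
    text_stream = [RT_PAD] * len(audio_times)    # silence context everywhere
    labels = [IGNORE_INDEX] * len(audio_times)   # <|rt_pad|> is never supervised

    speak_idx = next(i for i, t in enumerate(audio_times) if t >= 2.0)
    text_stream[speak_idx] = RT_SPEAK
    labels[speak_idx] = RT_SPEAK                 # delay boundary is supervised

    # One speech span with a single text token (id 99) after the boundary.
    for i, tok in zip(range(speak_idx + 1, speak_idx + 4), [RT_START, 99, RT_END]):
        text_stream[i] = tok
        labels[i] = tok                          # span tokens are supervised

    print(text_stream)  # [1, 1, 1, 1, 2, 3, 99, 4]
    print(labels)       # [-100, -100, -100, -100, 2, 3, 99, 4]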
@@ -50,9 +47,9 @@ class AeroRealtimeDataProcessor(Qwen3_VLDataProcessor):
     """Data processor for AeroRealtime training.
 
     Builds ``input_ids``, ``text_stream_ids``, and ``labels`` for the
-    dual-stream training design. Handles:
-    - Normal video QA: audio timeline filled with ``<|rt_pad|>`` after delay
-    - Realtime training: text tokens placed at temporal positions on audio tokens
+    realtime audio-stream training design. Handles:
+    - Normal video QA: audio timeline filled with ``<|rt_pad|>`` context
+    - Realtime training: boundary and text labels on audio tokens
     - Image-only: standard scatter (no text_stream_ids)
     - Audio extraction from video for audio-vision fusion
     """
@@ -135,6 +132,10 @@ def rt_pad_id(self):
     def rt_speak_id(self):
         return self.tokenizer.convert_tokens_to_ids(self.processor.rt_speak_token)
 
+    @property
+    def rt_end_id(self):
+        return self.tokenizer.convert_tokens_to_ids(self.processor.rt_end_token)
+
     # ------------------------------------------------------------------
     # Main process entry point
     # ------------------------------------------------------------------
@@ -334,10 +335,9 @@ def _build_normal_qa_ids_and_labels(
         """Build input_ids, text_stream_ids, and labels from HF messages.
 
         For normal video QA the text_stream_ids only differ from input_ids
-        on audio pad positions:
-        - all ``<|audio_pad|>`` slots -> ``<|rt_pad|>``
-        - first ``<|audio_pad|>`` -> ``<|rt_start|>``
-        - delayed ``<|audio_pad|>`` -> ``<|rt_speak|>``
+        on audio pad positions, where all ``<|audio_pad|>`` slots become
+        ``<|rt_pad|>`` context. Normal QA keeps standard assistant labels;
+        realtime span labels are built by ``_build_realtime_ids_and_labels``.
 
         Video placeholders and envelope boundary tokens keep their original
         ids; vision features replace video placeholder embeddings in the model.
@@ -363,20 +363,15 @@
         text_stream_id = list(input_id)  # start as a copy of input_ids
 
         if has_video and has_audio:
-            # video + audio: only audio pads carry realtime stream tokens
+            # video + audio: only audio pads carry realtime stream context
             self.processor._fill_text_stream_video_audio(
                 stream=text_stream_id,
                 video_grid_thw=video_grid_thw,
                 video_metadata=video_metadata,
                 temporal_patch_size=getattr(self.processor.video_processor, "temporal_patch_size", 2),
-                vision_start_id=self.tokenizer.convert_tokens_to_ids(self.processor.vision_start_token),
-                vision_end_id=self.tokenizer.convert_tokens_to_ids(self.processor.vision_end_token),
                 audio_start_id=self.tokenizer.convert_tokens_to_ids(self.processor.audio_start_token),
                 audio_end_id=self.tokenizer.convert_tokens_to_ids(self.processor.audio_end_token),
-                audio_pad_id=self.audio_token_id,
-                rt_start_id=self.rt_start_id,
                 rt_pad_id=self.rt_pad_id,
-                rt_speak_id=self.rt_speak_id,
             )
         elif has_audio:
             # audio-only: single envelope per audio sample
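
With the extra ids gone from the call, the fill helper only has to mark silence. A rough sketch of what such a helper could do (hypothetical body; the real ``_fill_text_stream_video_audio`` lives on the processor and is not shown in this commit):

    def fill_text_stream_video_audio(stream, audio_start_id, audio_end_id, rt_pad_id, **_):
        # Hypothetical sketch: overwrite every token strictly between an
        # <|audio_start|>/<|audio_end|> pair with <|rt_pad|> silence context.
        inside = False
        for i, tok in enumerate(stream):
            if tok == audio_start_id:
                inside = True
            elif tok == audio_end_id:
                inside = False
            elif inside:
                stream[i] = rt_pad_id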
@@ -454,14 +449,9 @@
                 video_grid_thw=video_grid_thw,
                 video_metadata=video_metadata,
                 temporal_patch_size=getattr(self.processor.video_processor, "temporal_patch_size", 2),
-                vision_start_id=self.tokenizer.convert_tokens_to_ids(self.processor.vision_start_token),
-                vision_end_id=self.tokenizer.convert_tokens_to_ids(self.processor.vision_end_token),
                 audio_start_id=self.tokenizer.convert_tokens_to_ids(self.processor.audio_start_token),
                 audio_end_id=self.tokenizer.convert_tokens_to_ids(self.processor.audio_end_token),
-                audio_pad_id=self.audio_token_id,
-                rt_start_id=self.rt_start_id,
                 rt_pad_id=self.rt_pad_id,
-                rt_speak_id=self.rt_speak_id,
             )
 
         audio_positions = [idx for idx, token_id in enumerate(input_id) if token_id == self.audio_token_id]
@@ -474,29 +464,36 @@
            raise ValueError(f"Audio position/time mismatch: {len(audio_positions)} != {len(audio_times)}")
 
        delay = getattr(self.processor, "delay_seconds", 2.0)
-        for pos, t_sec in zip(audio_positions, audio_times):
-            if t_sec >= delay:
-                target[pos] = self.rt_pad_id
+        speak_audio_idx = self._first_index_at_or_after(audio_times, delay)
+        if speak_audio_idx < len(audio_positions):
+            speak_pos = audio_positions[speak_audio_idx]
+            text_stream_id[speak_pos] = self.rt_speak_id
 
        assistant_segments = sorted(
            [seg for seg in (realtime_segments or []) if seg.get("role") == "assistant" and seg.get("text")],
            key=lambda item: float(item["time"]),
        )
-        event_times = sorted(float(seg["time"]) for seg in (realtime_segments or []))
+        occupied_audio_indices = {speak_audio_idx} if speak_audio_idx < len(audio_positions) else set()
        for segment in assistant_segments:
            start_time = float(segment["time"])
-            end_time = self._next_time_after(event_times, start_time)
            start_audio_idx = self._first_index_at_or_after(audio_times, start_time)
-            end_audio_idx = (
-                self._first_index_at_or_after(audio_times, end_time) if end_time is not None else len(audio_positions)
+            if speak_audio_idx < len(audio_positions):
+                start_audio_idx = max(start_audio_idx, speak_audio_idx + 1)
+            available_indices = self._next_available_indices(
+                start=start_audio_idx,
+                count=len(audio_positions),
+                limit=len(audio_positions),
+                occupied=occupied_audio_indices,
            )
-            if start_audio_idx < end_audio_idx and text_stream_id[audio_positions[start_audio_idx]] == self.rt_speak_id:
-                start_audio_idx += 1
-            token_ids = self._encode_realtime_text(segment["text"])
-            for offset, token_id in enumerate(token_ids[: max(0, end_audio_idx - start_audio_idx)]):
-                pos = audio_positions[start_audio_idx + offset]
+            if len(available_indices) < 2:
+                continue
+            text_token_budget = len(available_indices) - 2
+            token_ids = [self.rt_start_id] + self._encode_realtime_text(segment["text"])[:text_token_budget] + [self.rt_end_id]
+            for audio_idx, token_id in zip(available_indices, token_ids):
+                pos = audio_positions[audio_idx]
                text_stream_id[pos] = token_id
                target[pos] = token_id
+                occupied_audio_indices.add(audio_idx)
 
        input_tensor = torch.tensor(input_id, dtype=torch.long)
        text_stream_tensor = torch.tensor(text_stream_id, dtype=torch.long)
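
The budget arithmetic in the new loop is worth tracing once. A standalone sketch (toy ids; ``place_span`` is an invented helper mirroring the lines above) of how one segment's tokens land on free audio slots:

    RT_START, RT_END = 3, 4  # toy ids, not the real token indices

    def place_span(available_indices, text_ids):
        # A span needs at least <|rt_start|> and <|rt_end|>; text fills the rest.
        if len(available_indices) < 2:
            return {}
        budget = len(available_indices) - 2
        token_ids = [RT_START] + text_ids[:budget] + [RT_END]
        return dict(zip(available_indices, token_ids))

    # Five free audio slots, four text tokens: one text token is truncated so
    # the span still closes with <|rt_end|> inside its budget.
    print(place_span([5, 6, 7, 8, 9], [10, 11, 12, 13]))
    # -> {5: 3, 6: 10, 7: 11, 8: 12, 9: 4}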
@@ -579,11 +576,14 @@ def _first_index_at_or_after(values: List[float], target: float) -> int:
         return len(values)
 
     @staticmethod
-    def _next_time_after(values: List[float], target: float) -> Optional[float]:
-        for value in values:
-            if value > target:
-                return value
-        return None
+    def _next_available_indices(start: int, count: int, limit: int, occupied: set) -> List[int]:
+        indices = []
+        idx = start
+        while idx < limit and len(indices) < count:
+            if idx not in occupied:
+                indices.append(idx)
+            idx += 1
+        return indices
 
     def get_qwen_template_labels(
         self,
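
Its behavior on a small example (a static call with invented values), scanning from index 2 for up to 3 free slots under limit 8, with index 3 already occupied by the ``<|rt_speak|>`` marker:

    assert AeroRealtimeDataProcessor._next_available_indices(
        start=2, count=3, limit=8, occupied={3}
    ) == [2, 4, 5]  # index 3 is skipped; the scan stops once 3 slots are found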
@@ -682,12 +682,10 @@ def _expand_encode_id_video_tokens(
 
         - Without audio: per-frame Qwen3VL legacy expansion (delegated to
           parent).
-        - With audio: per-chunk envelope expansion matching the model
-          processor's path 5b layout::
+        - With audio: per-chunk separated vision/audio envelopes::
 
-              <t.t seconds><|vision_start|><|audio_start|>
-              <|video_pad|>×spatial <|audio_pad|>×N_t
-              <|audio_end|><|vision_end|>
+              <t.t seconds><|vision_start|><|video_pad|>×spatial<|vision_end|>
+              <|audio_start|><|audio_pad|>×N_t<|audio_end|>
         """
         if audio_per_chunk_per_video is None:
             return super()._expand_encode_id_video_tokens(
@@ -740,11 +738,11 @@ def _expand_encode_id_video_tokens(
            n_audio_t = audio_per_chunk[t]
            expanded_encode_id.extend(timestamp_token_ids)
            expanded_encode_id.append(vision_start_id)
-            expanded_encode_id.append(audio_start_id)
            expanded_encode_id.extend([self.video_token_id] * spatial)
+            expanded_encode_id.append(vision_end_id)
+            expanded_encode_id.append(audio_start_id)
            expanded_encode_id.extend([self.audio_token_id] * n_audio_t)
            expanded_encode_id.append(audio_end_id)
-            expanded_encode_id.append(vision_end_id)
 
            prev = pos + 2  # skip past original <|vision_end|>

src/lmms_engine/models/aero_realtime/configuration_aero_realtime.py
Lines changed: 4 additions & 0 deletions
@@ -79,6 +79,8 @@ class AeroRealtimeConfig(PretrainedConfig):
         rt_speak_token_index (`int`, *optional*, defaults to `151674`):
             Token index for ``<|rt_speak|>`` — delay boundary marker after which
             the model may begin producing text.
+        rt_end_token_index (`int`, *optional*, defaults to `151675`):
+            Token index for ``<|rt_end|>`` — closes one realtime speech span.
         delay_seconds (`float`, *optional*, defaults to `2.0`):
             Delay in seconds before the model is allowed to speak. Converted to
             a number of vision tokens based on the video's temporal resolution.
@@ -122,6 +124,7 @@ def __init__(
         rt_start_token_index=151672,
         rt_pad_token_index=151673,
         rt_speak_token_index=151674,
+        rt_end_token_index=151675,
         delay_seconds=2.0,
         tie_word_embeddings=False,
         **kwargs,
@@ -138,6 +141,7 @@
         self.rt_start_token_index = rt_start_token_index
         self.rt_pad_token_index = rt_pad_token_index
         self.rt_speak_token_index = rt_speak_token_index
+        self.rt_end_token_index = rt_end_token_index
         self.delay_seconds = delay_seconds
 
         # Aliases expected by qwen3_vl_get_rope_index (shared RoPE helper)
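
Assuming the remaining constructor arguments also default (the usual pattern for a ``PretrainedConfig`` subclass), the new index is reachable like any other attribute:

    config = AeroRealtimeConfig()
    # With the defaults above, the four realtime special tokens occupy a
    # contiguous id range (151672-151675).
    assert config.rt_end_token_index == 151675 == config.rt_speak_token_index + 1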

src/lmms_engine/models/aero_realtime/modeling_aero_realtime.py
Lines changed: 7 additions & 7 deletions
@@ -516,8 +516,8 @@ def forward(
         """Forward pass for AeroRealtime.
 
         Audio and video are kept as **separate** token streams in the input
-        sequence (per-chunk envelope ``[VS][AS][video_pad×S][audio_pad×N]
-        [AE][VE]``) so time alignment is expressed entirely through token
+        sequence (per-chunk envelope ``[VS][video_pad×S][VE][AS]
+        [audio_pad×N][AE]``) so time alignment is expressed entirely through token
         order and RoPE. Vision features replace vision placeholders; audio
         features are added to the realtime text stream on audio placeholders.
 
@@ -536,9 +536,9 @@
             ``audio_token_index`` positions.
 
         **Video + Audio**: video placeholders receive pure vision features.
-        ``text_stream_ids`` carries realtime markers (``<|rt_start|>``,
-        ``<|rt_pad|>``, ``<|rt_speak|>``) only at audio positions, where audio
-        features are added to the realtime text embeddings.
+        ``text_stream_ids`` carries realtime markers (``<|rt_speak|>``,
+        ``<|rt_start|>``, ``<|rt_end|>``, and speech text) only at audio
+        positions, where audio features are added to the realtime text embeddings.
 
         Pipeline:
         1. Embed ``text_stream_ids`` (if provided) or ``input_ids``.
@@ -553,8 +553,8 @@
                 position masks for image/video/audio features.
             text_stream_ids: Parallel text-stream token ids.
                 Shape ``[batch_size, seq_len]``. At audio positions contains
-                ``<|rt_start|>``, ``<|rt_pad|>``, ``<|rt_speak|>``, or actual
-                text tokens; mirrors ``input_ids`` elsewhere.
+                ``<|rt_pad|>``, ``<|rt_speak|>``, speech boundary tokens, or
+                actual text tokens; mirrors ``input_ids`` elsewhere.
                 If not provided, falls back to ``input_ids``.
             pixel_values: Image pixel values (flat across batch).
             image_grid_thw: Grid info per image. ``[num_images, 3]``.
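
The embedding fusion the docstring describes can be sketched in a few lines (hypothetical names and shapes; a simplification of pipeline steps 1 and 2, not the model's actual forward code):

    import torch

    def fuse_streams(embed_tokens, input_ids, text_stream_ids, audio_feats, audio_mask):
        # Step 1: embed the realtime text stream if given, else fall back.
        ids = text_stream_ids if text_stream_ids is not None else input_ids
        hidden = embed_tokens(ids)  # [batch, seq, dim]
        # Step 2: audio features are *added* at audio placeholder positions,
        # so rt markers / speech text and audio share those slots.
        hidden[audio_mask] = hidden[audio_mask] + audio_feats
        return hidden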
