
Commit 70d605b

refactor(aero_realtime): per-chunk envelope + RoPE time alignment
Replace embedding-level audio/video timestep fusion with per-chunk envelope token interleaving. Each video chunk emits <t.t seconds><|vision_start|><|audio_start|><|video_pad|>×spatial <|audio_pad|>×N_t<|audio_end|><|vision_end|>; time alignment now comes from token order, so video and audio features are added independently at their respective placeholder positions.
1 parent 3fbd6b9 commit 70d605b

6 files changed

Lines changed: 703 additions & 612 deletions
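For orientation, here is a minimal sketch of the per-chunk envelope layout described above, using assumed example values (two chunks starting at 0.0 s and 2.0 s, 4 spatial video tokens and 5 audio tokens per chunk); the build_envelope helper is illustrative only and is not part of this commit:

def build_envelope(t_sec: float, n_spatial: int, n_audio: int) -> str:
    # Textual form of one chunk's envelope: timestamp, then the chunk's
    # video pads, then its share of audio pads, wrapped in start/end tokens.
    return (
        f"<{t_sec:.1f} seconds>"
        + "<|vision_start|>" + "<|audio_start|>"
        + "<|video_pad|>" * n_spatial
        + "<|audio_pad|>" * n_audio
        + "<|audio_end|>" + "<|vision_end|>"
    )

# Two chunks, assumed 2.0 s apart (temporal_patch_size / fps), with assumed
# spatial and per-chunk audio token counts.
stream = "".join(build_envelope(t, n_spatial=4, n_audio=5) for t in (0.0, 2.0))
print(stream)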


src/lmms_engine/datasets/processor/aero_realtime_processor.py

Lines changed: 160 additions & 68 deletions
@@ -146,7 +146,6 @@ def process(
         audios: Optional[List[np.ndarray]] = None,
         sampling_rate: Optional[int] = None,
         videos=None,
-        video_metadata=None,
         realtime_segments: Optional[List[Dict]] = None,
         system_message: str = "You are a helpful assistant",
         add_system_prompt=True,
@@ -162,14 +161,13 @@ def process(
             audios: List of audio waveforms (mono, float32, at sampling_rate).
             sampling_rate: Audio sampling rate.
             videos: List of video frames (numpy arrays, TCHW format).
-            video_metadata: Video metadata for timestamp computation.
-                If not provided, computed from video processor.
             realtime_segments: List of ``{"start_sec": float, "text": str}``
                 dicts extracted from assistant ``realtime_text`` content items.
                 If None, this is treated as normal video QA.
             system_message: System prompt text.
             add_system_prompt: Whether to add a system prompt.
-            **kwargs: Additional kwargs (e.g. ``fps``, ``do_sample_frames``).
+            **kwargs: Forwarded to the model processor (e.g. ``fps``,
+                ``do_sample_frames``, ``video_metadata``).

         Returns:
             Dict with ``input_ids``, ``text_stream_ids``, ``labels``, and
@@ -200,9 +198,6 @@ def process(
         _video_metadata = None
         if videos is not None:
             videos_kwargs = output_kwargs.get("videos_kwargs", {})
-            videos_kwargs["return_metadata"] = True
-            if video_metadata is not None:
-                videos_kwargs["video_metadata"] = video_metadata
             video_inputs = self.processor.video_processor(videos=videos, return_tensors="pt", **videos_kwargs)
             video_grid_thw = video_inputs["video_grid_thw"]
             _video_metadata = video_inputs.pop("video_metadata")
@@ -239,6 +234,35 @@ def process(
         num_video_tokens = None

         has_video = video_grid_thw is not None
+        has_audio = bool(audio_inputs)
+
+        # Per-video audio token splits across video temporal chunks.
+        # Required for envelope construction when both video and audio are
+        # present (the inner ``<|audio_pad|>`` count of each per-chunk envelope).
+        audio_per_chunk_per_video = None
+        if has_video and has_audio:
+            mel_lengths = audio_inputs["audio_attention_mask"].sum(-1)
+            num_audio_tokens_list = [self.processor._get_num_audio_tokens(int(m.item())) for m in mel_lengths]
+            temporal_patch_size = getattr(self.processor.video_processor, "temporal_patch_size", 2)
+            audio_per_chunk_per_video = []
+            for v_idx in range(len(video_grid_thw)):
+                metadata = _video_metadata[v_idx]
+                fps = metadata.fps if metadata.fps is not None else 24.0
+                grid_t = int(video_grid_thw[v_idx][0])
+                second_per_grid = temporal_patch_size / fps
+                # Audio sample paired with this video by positional index
+                a_idx = v_idx if v_idx < len(num_audio_tokens_list) else 0
+                n_audio = num_audio_tokens_list[a_idx]
+                audio_duration = self.processor._get_audio_duration_seconds(audio_inputs["audio_attention_mask"][a_idx])
+                audio_rate = (n_audio / audio_duration) if audio_duration > 0 else 0.0
+                audio_per_chunk_per_video.append(
+                    self.processor._split_audio_across_chunks(
+                        n_audio=n_audio,
+                        grid_t=grid_t,
+                        second_per_grid=second_per_grid,
+                        audio_rate=audio_rate,
+                    )
+                )

         # ==============================================================
         # 5. Build input_ids, text_stream_ids, labels
@@ -250,6 +274,8 @@ def process(
                 num_video_tokens=num_video_tokens,
                 video_grid_thw=video_grid_thw,
                 video_metadata=_video_metadata,
+                audio_per_chunk_per_video=audio_per_chunk_per_video,
+                audio_attention_mask=audio_inputs.get("audio_attention_mask") if has_audio else None,
                 system_message=system_message,
                 add_system_prompt=add_system_prompt,
             )
@@ -259,17 +285,7 @@ def process(
             raise RuntimeError("Not implemented yet")

         # ==============================================================
-        # 6. Compute video_timestep and audio_timestep
-        # ==============================================================
-        if video_grid_thw is not None and _video_metadata is not None:
-            inputs["video_timestep"] = self.processor._compute_video_timestep(video_grid_thw, _video_metadata)
-
-        if audio_inputs:
-            audio_mask = audio_inputs["audio_attention_mask"]
-            inputs["audio_timestep"] = self.processor._compute_audio_timestep(audio_mask)
-
-        # ==============================================================
-        # 7. Pack vision/audio tensors into output
+        # 6. Pack vision/audio tensors into output
         # ==============================================================
         if images is not None:
             inputs["pixel_values"] = image_inputs["pixel_values"]
@@ -297,26 +313,33 @@ def _build_normal_qa_ids_and_labels(
         num_video_tokens: Optional[List[int]],
         video_grid_thw=None,
         video_metadata=None,
+        audio_per_chunk_per_video: Optional[List[List[int]]] = None,
+        audio_attention_mask: Optional[torch.Tensor] = None,
         realtime_segments: Optional[List[Dict]] = None,
         system_message: str = "You are a helpful assistant",
         add_system_prompt: bool = True,
     ) -> dict:
         """Build input_ids, text_stream_ids, and labels from HF messages.

-        For normal video QA: text_stream_ids has rt_start/rt_pad/rt_speak
-        with all rt_pad after rt_speak (model learns to stay silent).
+        For normal video QA the text_stream_ids only differ from input_ids
+        in the multimodal pad regions:
+        - all ``<|video_pad|>`` and ``<|audio_pad|>`` slots → ``<|rt_pad|>``
+        - first chunk's first ``<|video_pad|>`` → ``<|rt_start|>``
+        - speak chunk's first ``<|audio_pad|>`` → ``<|rt_speak|>``

-        For realtime training: text_stream_ids has actual text tokens placed
-        at the right temporal positions after rt_speak.
+        Envelope boundary tokens (timestamps, vision_start/end,
+        audio_start/end) keep their original ids in text_stream_ids so the
+        LM sees the same special tokens it would in input_ids.
         """
         results = self.get_qwen_template_labels(
             hf_messages,
             num_image_tokens,
             num_video_tokens,
             video_metadata,
             video_grid_thw,
-            system_message,
-            add_system_prompt,
+            audio_per_chunk_per_video=audio_per_chunk_per_video,
+            system_message=system_message,
+            add_system_prompt=add_system_prompt,
         )
         input_id = results["input_ids"].tolist()
         target = results["labels"].tolist()
@@ -325,59 +348,53 @@ def _build_normal_qa_ids_and_labels(
         # Build text_stream_ids
         # ==============================================================
         has_video = video_grid_thw is not None
+        has_audio = audio_attention_mask is not None
         text_stream_id = list(input_id)  # start as a copy of input_ids

-        if has_video:
-            vision_start_id = self.tokenizer.convert_tokens_to_ids(self.processor.vision_start_token)
-            vision_end_id = self.tokenizer.convert_tokens_to_ids(self.processor.vision_end_token)
-            temporal_patch_size = getattr(self.processor.video_processor, "temporal_patch_size", 2)
-
-            # Pre-compute per-frame timestamps for all videos
-            all_frame_timestamps = []
-            for v_idx in range(len(video_grid_thw)):
-                metadata = video_metadata[v_idx]
-                fps = metadata.fps if metadata.fps is not None else 24.0
-                timestamps = self.processor._calculate_timestamps(metadata.frames_indices, fps, temporal_patch_size)
-                all_frame_timestamps.extend(timestamps)
-
-            input_id_t = torch.tensor(input_id)
-            vs_positions = (input_id_t == vision_start_id).nonzero(as_tuple=True)[0].tolist()
-            ve_positions = (input_id_t == vision_end_id).nonzero(as_tuple=True)[0].tolist()
-
-            assert len(all_frame_timestamps) == len(vs_positions), "The timestamps and frame number should be equal"
-
-            # Find the first frame whose timestamp >= delay_seconds
-            speak_frame = len(all_frame_timestamps) - 1  # fallback to last frame
-            for idx, ts in enumerate(all_frame_timestamps):
-                if ts >= self.processor.delay_seconds:
-                    speak_frame = idx
-                    break
-
-            # Fill text_stream_id for each frame's [VS][VP*N][VE] region
-            for idx, (vs, ve) in enumerate(zip(vs_positions, ve_positions)):
-                # VS and VE → rt_pad
-                text_stream_id[vs] = self.rt_pad_id
-                text_stream_id[ve] = self.rt_pad_id
-                # VP region (vs+1 to ve-1) → rt_pad
-                for k in range(vs + 1, ve):
-                    text_stream_id[k] = self.rt_pad_id
-                # First frame: place rt_start at first VP position
-                if idx == 0:
-                    text_stream_id[vs + 1] = self.rt_start_id
-                # Delay frame: place rt_speak at first VP position
-                if idx == speak_frame:
-                    text_stream_id[vs + 1] = self.rt_speak_id
+        if has_video and has_audio:
+            # video + audio: per-chunk envelope filler
+            self.processor._fill_text_stream_video_audio(
+                stream=text_stream_id,
+                video_grid_thw=video_grid_thw,
+                video_metadata=video_metadata,
+                temporal_patch_size=getattr(self.processor.video_processor, "temporal_patch_size", 2),
+                vision_start_id=self.tokenizer.convert_tokens_to_ids(self.processor.vision_start_token),
+                vision_end_id=self.tokenizer.convert_tokens_to_ids(self.processor.vision_end_token),
+                audio_start_id=self.tokenizer.convert_tokens_to_ids(self.processor.audio_start_token),
+                audio_end_id=self.tokenizer.convert_tokens_to_ids(self.processor.audio_end_token),
+                video_pad_id=self.video_token_id,
+                audio_pad_id=self.audio_token_id,
+                rt_start_id=self.rt_start_id,
+                rt_pad_id=self.rt_pad_id,
+                rt_speak_id=self.rt_speak_id,
+            )
+        elif has_audio:
+            # audio-only: single envelope per audio sample
+            n_samples = audio_attention_mask.shape[0]
+            for s_idx in range(n_samples):
+                self.processor._fill_text_stream_audio_only(
+                    stream=text_stream_id,
+                    sample_idx=s_idx,
+                    audio_attention_mask=audio_attention_mask,
+                    audio_start_id=self.tokenizer.convert_tokens_to_ids(self.processor.audio_start_token),
+                    audio_end_id=self.tokenizer.convert_tokens_to_ids(self.processor.audio_end_token),
+                    audio_pad_id=self.audio_token_id,
+                    rt_start_id=self.rt_start_id,
+                    rt_pad_id=self.rt_pad_id,
+                    rt_speak_id=self.rt_speak_id,
+                )
+        # video-only (no audio): no text_stream_ids (matches processor)

         input_id = torch.tensor(input_id, dtype=torch.long)
         target = torch.tensor(target, dtype=torch.long)
-        text_stream_id = torch.tensor(text_stream_id, dtype=torch.long)

         result = dict(
             input_ids=input_id,
             labels=target,
         )
-        if has_video:
-            result["text_stream_ids"] = text_stream_id
+        # text_stream_ids only when audio is present (= streaming mode)
+        if has_audio:
+            result["text_stream_ids"] = torch.tensor(text_stream_id, dtype=torch.long)

         return result

@@ -388,6 +405,7 @@ def get_qwen_template_labels(
         num_video_tokens: List[int],
         video_metadata: List[dict],
         video_grid_thw=None,
+        audio_per_chunk_per_video: Optional[List[List[int]]] = None,
         system_message: str = "You are a helpful assistant",
         add_system_prompt: bool = True,
         add_generation_prompt: bool = False,
@@ -426,6 +444,7 @@ def get_qwen_template_labels(
                     video_start_from,
                     curr_timestamp,
                     video_grid_thw,
+                    audio_per_chunk_per_video=audio_per_chunk_per_video,
                 )
                 video_start_from += used_video

@@ -449,6 +468,8 @@ def get_qwen_template_labels(
                 target[idx] = -100
             if encode_id == self.video_token_id:
                 target[idx] = -100
+            if encode_id == self.audio_token_id:
+                target[idx] = -100

         input_id = torch.tensor(input_id, dtype=torch.long)
         target = torch.tensor(target, dtype=torch.long)
@@ -458,6 +479,77 @@ def get_qwen_template_labels(
             labels=target,
         )

+    def _expand_encode_id_video_tokens(
+        self,
+        encode_id: List[int],
+        video_token_num: List[int],
+        start_from: int = 0,
+        curr_timestamp: List[float] = None,
+        video_grid_thw=None,
+        audio_per_chunk_per_video: Optional[List[List[int]]] = None,
+    ):
+        """Expand ``<|video_pad|>`` placeholders.
+
+        - Without audio: per-frame Qwen3VL legacy expansion (delegated to
+          parent).
+        - With audio: per-chunk envelope expansion matching the model
+          processor's path 5b layout::
+
+              <t.t seconds><|vision_start|><|audio_start|>
+              <|video_pad|>×spatial <|audio_pad|>×N_t
+              <|audio_end|><|vision_end|>
+        """
+        if audio_per_chunk_per_video is None:
+            return super()._expand_encode_id_video_tokens(
+                encode_id, video_token_num, start_from, curr_timestamp, video_grid_thw
+            )
+
+        merge_length = self.processor.video_processor.merge_size**2
+        vision_start_id = self.processor.vision_start_token_id
+        vision_end_id = self.processor.vision_end_token_id
+        audio_start_id = self.tokenizer.convert_tokens_to_ids(self.processor.audio_start_token)
+        audio_end_id = self.tokenizer.convert_tokens_to_ids(self.processor.audio_end_token)
+        temporal_patch_size = getattr(self.processor.video_processor, "temporal_patch_size", 2)
+
+        video_pos = [i for i, x in enumerate(encode_id) if x == self.video_token_id]
+        expanded_encode_id = []
+        prev = 0
+        for idx, pos in enumerate(video_pos):
+            v_global = idx + start_from
+            grid = video_grid_thw[v_global]
+            grid_t = int(grid[0])
+            spatial = int(grid[1:].prod() // merge_length)
+
+            # Figure out per-chunk audio counts; fps from grid (we only have
+            # curr_timestamp which is per-frame timestamps in seconds). Use
+            # them directly for the chunk start times.
+            audio_per_chunk = audio_per_chunk_per_video[v_global]
+            assert len(audio_per_chunk) == grid_t, f"audio_per_chunk len {len(audio_per_chunk)} != grid_t {grid_t}"
+
+            # Strip surrounding <|vision_start|> / <|vision_end|> from the
+            # template (positions pos-1 and pos+1) -- we will emit our own.
+            expanded_encode_id.extend(encode_id[prev : pos - 1])
+
+            for t in range(grid_t):
+                # Per-frame timestamp (seconds) from the video metadata
+                t_sec = curr_timestamp[t] if t < len(curr_timestamp) else (t * temporal_patch_size)
+                timestamp_token_ids = self.processor.tokenizer.encode(f"<{t_sec:.1f} seconds>")
+                n_audio_t = audio_per_chunk[t]
+                expanded_encode_id.extend(timestamp_token_ids)
+                expanded_encode_id.append(vision_start_id)
+                expanded_encode_id.append(audio_start_id)
+                expanded_encode_id.extend([self.video_token_id] * spatial)
+                expanded_encode_id.extend([self.audio_token_id] * n_audio_t)
+                expanded_encode_id.append(audio_end_id)
+                expanded_encode_id.append(vision_end_id)

+            prev = pos + 2  # skip past original <|vision_end|>
+
+            if idx == len(video_pos) - 1:
+                expanded_encode_id.extend(encode_id[prev:])
+
+        return expanded_encode_id, len(video_pos)
+
     # ------------------------------------------------------------------
     # Chat template
     # ------------------------------------------------------------------
@@ -484,7 +576,7 @@ def chat_template(self):
            "{% for content in message['content'] %}"
            "{% if 'audio' in content or 'audio_url' in content %}"
            "{% set audio_count.value = audio_count.value + 1 %}"
-           "<|AUDIO|>"
+           "<|audio_pad|>"
            "{% elif content['type'] == 'image' or 'image' in content or 'image_url' in content %}"
            "{% set image_count.value = image_count.value + 1 %}"
            "<|vision_start|><|image_pad|><|vision_end|>"
