Commit c96046d
feat(aero_realtime): add realtime stream labels
1 parent: b68058d

4 files changed: 292 additions & 26 deletions

src/lmms_engine/datasets/iterable/aero_realtime_iterable_dataset.py
Lines changed: 31 additions & 4 deletions
```diff
@@ -68,12 +68,28 @@ def load_from_json(self, data, data_folder=None) -> Dict[str, torch.Tensor]:
         if isinstance(messages, str):
             messages = json.loads(messages)
 
+        is_realtime = bool(data.get("realtime", False))
+
         # First pass: collect media references and realtime segments
         for message in messages:
+            message_time = message.get("time")
+            if is_realtime and message_time is not None and message["role"] in ["user", "assistant"]:
+                text = self._extract_text_content(message.get("content", []))
+                if text:
+                    realtime_segments.append(
+                        {
+                            "time": float(message_time),
+                            "role": message["role"],
+                            "text": text,
+                        }
+                    )
+                continue
+
             for content in message["content"]:
-                if content["type"] == "image_url":
+                content_type = content.get("type")
+                if content_type == "image_url":
                     images_list.append(content["image_url"]["url"])
-                elif content["type"] == "video_url":
+                elif content_type == "video_url":
                     video_url = content["video_url"]["url"]
                     if data_folder is not None:
                         video_path = os.path.join(data_folder, video_url)
@@ -90,10 +106,11 @@ def load_from_json(self, data, data_folder=None) -> Dict[str, torch.Tensor]:
                     kwargs["video_metadata"] = video_metadata
                     kwargs["do_sample_frames"] = False
 
-                elif content["type"] == "realtime_text":
+                elif content_type == "realtime_text":
                     realtime_segments.append(
                         {
-                            "start_sec": content["start_sec"],
+                            "time": content["start_sec"],
+                            "role": "assistant",
                             "text": content["text"],
                         }
                     )
@@ -133,6 +150,16 @@ def load_from_json(self, data, data_folder=None) -> Dict[str, torch.Tensor]:
         )
         return inputs
 
+    @staticmethod
+    def _extract_text_content(content) -> str:
+        if isinstance(content, str):
+            return content
+        texts = []
+        for item in content:
+            if item and item.get("type") == "text" and item.get("text"):
+                texts.append(item["text"])
+        return "\n".join(texts)
+
     def _load_video_with_metadata(
         self,
         video_path: str,
```
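The new branch routes timed messages into `realtime_segments` (joining multiple text items with newlines via `_extract_text_content`) and skips them in the media scan. For orientation, a sketch of the record shape this appears to consume — the `realtime`, `time`, `role`, and `content` fields follow the diff above; the surrounding layout (e.g. a top-level `messages` key) is assumed for illustration:

```python
# Hypothetical record (illustrative values; only the field names used in
# the diff above are taken from the code).
sample = {
    "realtime": True,
    "messages": [
        # Untimed message: scanned for media references as before.
        {
            "role": "user",
            "content": [{"type": "video_url", "video_url": {"url": "clips/match.mp4"}}],
        },
        # Timed messages: collected into realtime_segments, then skipped
        # by the media loop (note the `continue`).
        {"role": "user", "time": 3.0, "content": [{"type": "text", "text": "Who has the ball?"}]},
        {"role": "assistant", "time": 4.5, "content": [{"type": "text", "text": "The striker, just outside the box."}]},
    ],
}
```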

src/lmms_engine/datasets/processor/aero_realtime_processor.py
Lines changed: 214 additions & 8 deletions
```diff
@@ -249,17 +249,20 @@ def process(
             metadata = _video_metadata[v_idx]
             fps = metadata.fps if metadata.fps is not None else 24.0
             grid_t = int(video_grid_thw[v_idx][0])
-            second_per_grid = temporal_patch_size / fps
+            curr_timestamp = self.processor._calculate_timestamps(
+                metadata.frames_indices,
+                fps,
+                temporal_patch_size,
+            )
             # Audio sample paired with this video by positional index
             a_idx = v_idx if v_idx < len(num_audio_tokens_list) else 0
             n_audio = num_audio_tokens_list[a_idx]
             audio_duration = self.processor._get_audio_duration_seconds(audio_inputs["audio_attention_mask"][a_idx])
             audio_rate = (n_audio / audio_duration) if audio_duration > 0 else 0.0
             audio_per_chunk_per_video.append(
-                self.processor._split_audio_across_chunks(
+                self.processor._split_audio_across_chunk_times(
                     n_audio=n_audio,
-                    grid_t=grid_t,
-                    second_per_grid=second_per_grid,
+                    chunk_start_times=curr_timestamp[:grid_t],
                     audio_rate=audio_rate,
                 )
             )
```
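Previously every temporal chunk was assumed to span a uniform `second_per_grid = temporal_patch_size / fps`; audio is now split along timestamps computed from the actually sampled frame indices, so unevenly sampled videos apportion audio correctly. The renamed `_split_audio_across_chunk_times` helper is not part of this diff; a minimal sketch of the contract it appears to satisfy, assuming the final chunk absorbs any remainder:

```python
from typing import List

def split_audio_across_chunk_times_sketch(
    n_audio: int, chunk_start_times: List[float], audio_rate: float
) -> List[int]:
    """Sketch only: allot n_audio tokens to chunks spanning [t_i, t_{i+1})
    at audio_rate tokens/second; counts always sum to exactly n_audio."""
    counts, consumed = [], 0
    for i, start in enumerate(chunk_start_times):
        if i + 1 < len(chunk_start_times):
            span = chunk_start_times[i + 1] - start
            n = max(0, min(n_audio - consumed, round(span * audio_rate)))
        else:
            n = n_audio - consumed  # remainder goes to the final chunk
        counts.append(n)
        consumed += n
    return counts
```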
```diff
@@ -280,9 +283,18 @@ def process(
                 add_system_prompt=add_system_prompt,
             )
         else:
-            # TODO:
-            # Build realtime qa ids and labels
-            raise RuntimeError("Not implemented yet")
+            inputs = self._build_realtime_ids_and_labels(
+                hf_messages=hf_messages,
+                num_image_tokens=num_image_tokens,
+                num_video_tokens=num_video_tokens,
+                video_grid_thw=video_grid_thw,
+                video_metadata=_video_metadata,
+                audio_per_chunk_per_video=audio_per_chunk_per_video,
+                audio_attention_mask=audio_inputs.get("audio_attention_mask") if has_audio else None,
+                realtime_segments=realtime_segments,
+                system_message=system_message,
+                add_system_prompt=add_system_prompt,
+            )
 
         # ==============================================================
         # 6. Pack vision/audio tensors into output
```
```diff
@@ -398,6 +410,184 @@ def _build_normal_qa_ids_and_labels(
 
         return result
 
+    def _build_realtime_ids_and_labels(
+        self,
+        hf_messages,
+        num_image_tokens: Optional[List[int]],
+        num_video_tokens: Optional[List[int]],
+        video_grid_thw=None,
+        video_metadata=None,
+        audio_per_chunk_per_video: Optional[List[List[int]]] = None,
+        audio_attention_mask: Optional[torch.Tensor] = None,
+        realtime_segments: Optional[List[Dict]] = None,
+        system_message: str = "You are a helpful assistant",
+        add_system_prompt: bool = True,
+    ) -> dict:
+        if video_grid_thw is None or audio_per_chunk_per_video is None or audio_attention_mask is None:
+            raise ValueError("Realtime training requires both video and audio inputs.")
+
+        base_messages, timed_user_segments = self._build_realtime_base_messages(
+            hf_messages=hf_messages,
+            realtime_segments=realtime_segments or [],
+            video_grid_thw=video_grid_thw,
+            video_metadata=video_metadata,
+            audio_per_chunk_per_video=audio_per_chunk_per_video,
+            system_message=system_message,
+            add_system_prompt=add_system_prompt,
+        )
+
+        results = self.get_qwen_template_labels(
+            base_messages,
+            num_image_tokens,
+            num_video_tokens,
+            video_metadata,
+            video_grid_thw,
+            audio_per_chunk_per_video=audio_per_chunk_per_video,
+            timed_user_segments=timed_user_segments,
+            system_message=system_message,
+            add_system_prompt=False,
+        )
+        input_id = results["input_ids"].tolist()
+        text_stream_id = list(input_id)
+        target = [-100] * len(input_id)
+
+        self.processor._fill_text_stream_video_audio(
+            stream=text_stream_id,
+            video_grid_thw=video_grid_thw,
+            video_metadata=video_metadata,
+            temporal_patch_size=getattr(self.processor.video_processor, "temporal_patch_size", 2),
+            vision_start_id=self.tokenizer.convert_tokens_to_ids(self.processor.vision_start_token),
+            vision_end_id=self.tokenizer.convert_tokens_to_ids(self.processor.vision_end_token),
+            audio_start_id=self.tokenizer.convert_tokens_to_ids(self.processor.audio_start_token),
+            audio_end_id=self.tokenizer.convert_tokens_to_ids(self.processor.audio_end_token),
+            video_pad_id=self.video_token_id,
+            audio_pad_id=self.audio_token_id,
+            rt_start_id=self.rt_start_id,
+            rt_pad_id=self.rt_pad_id,
+            rt_speak_id=self.rt_speak_id,
+        )
+
+        audio_positions = [idx for idx, token_id in enumerate(input_id) if token_id == self.audio_token_id]
+        audio_times = self._get_audio_position_times(
+            video_grid_thw=video_grid_thw,
+            video_metadata=video_metadata,
+            audio_per_chunk_per_video=audio_per_chunk_per_video,
+        )
+        if len(audio_positions) != len(audio_times):
+            raise ValueError(f"Audio position/time mismatch: {len(audio_positions)} != {len(audio_times)}")
+
+        delay = getattr(self.processor, "delay_seconds", 2.0)
+        for pos, t_sec in zip(audio_positions, audio_times):
+            if t_sec >= delay:
+                target[pos] = self.rt_pad_id
+
+        assistant_segments = sorted(
+            [seg for seg in (realtime_segments or []) if seg.get("role") == "assistant" and seg.get("text")],
+            key=lambda item: float(item["time"]),
+        )
+        event_times = sorted(float(seg["time"]) for seg in (realtime_segments or []))
+        for segment in assistant_segments:
+            start_time = float(segment["time"])
+            end_time = self._next_time_after(event_times, start_time)
+            start_audio_idx = self._first_index_at_or_after(audio_times, start_time)
+            end_audio_idx = (
+                self._first_index_at_or_after(audio_times, end_time) if end_time is not None else len(audio_positions)
+            )
+            if start_audio_idx < end_audio_idx and text_stream_id[audio_positions[start_audio_idx]] == self.rt_speak_id:
+                start_audio_idx += 1
+            token_ids = self._encode_realtime_text(segment["text"])
+            for offset, token_id in enumerate(token_ids[: max(0, end_audio_idx - start_audio_idx)]):
+                pos = audio_positions[start_audio_idx + offset]
+                text_stream_id[pos] = token_id
+                target[pos] = token_id
+
+        input_tensor = torch.tensor(input_id, dtype=torch.long)
+        text_stream_tensor = torch.tensor(text_stream_id, dtype=torch.long)
+        target_tensor = torch.tensor(target, dtype=torch.long)
+
+        return dict(
+            input_ids=input_tensor,
+            labels=target_tensor,
+            text_stream_ids=text_stream_tensor,
+        )
+
```
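The `<|audio_pad|>` positions act as a clock for the label stream: `text_stream_ids` starts as a copy of `input_ids` (then `_fill_text_stream_video_audio`, not shown in this diff, stamps in the stream-control tokens), and `labels` stay at `-100` except at audio positions past `delay_seconds` (default 2.0), which are supervised to `rt_pad_id` ("stay silent"), and at positions covered by an assistant segment, which are supervised to that segment's text tokens; a leading `rt_speak_id` slot is skipped. A toy trace of the rule, with made-up token ids:

```python
RT_PAD, IGNORE = 9000, -100                     # made-up ids
audio_times = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]    # one <|audio_pad|> per second
target = [IGNORE] * 6

# 1) Past the delay (2.0 s), silence is supervised explicitly:
for i, t in enumerate(audio_times):
    if t >= 2.0:
        target[i] = RT_PAD
# target == [-100, -100, 9000, 9000, 9000, 9000]

# 2) An assistant segment at t=3.0 with text tokens [17, 23] overwrites
#    the slots from its start until the next event (here: stream end):
for offset, tok in enumerate([17, 23]):
    target[3 + offset] = tok
# target == [-100, -100, 9000, 17, 23, 9000]
```

The hunk continues with the helper that collapses the conversation into a media-only base turn: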
```diff
+    def _build_realtime_base_messages(
+        self,
+        hf_messages,
+        realtime_segments: List[Dict],
+        video_grid_thw,
+        video_metadata,
+        audio_per_chunk_per_video: List[List[int]],
+        system_message: str,
+        add_system_prompt: bool,
+    ):
+        messages = []
+        first_content = []
+        timed_user_segments = sorted(
+            [seg for seg in realtime_segments if seg.get("role") == "user" and seg.get("text")],
+            key=lambda item: float(item["time"]),
+        )
+
+        if add_system_prompt and (not hf_messages or hf_messages[0]["role"] != "system"):
+            messages.append({"role": "system", "content": [{"type": "text", "text": system_message}]})
+
+        for message in hf_messages:
+            if message["role"] == "system":
+                messages.append(message)
+                continue
+            if message.get("time") is not None:
+                continue
+            for content in message["content"]:
+                if content.get("type") in ["image", "video", "audio"]:
+                    first_content.append(content)
+
+        content = []
+        content.extend(first_content)
+
+        messages.append({"role": "user", "content": content})
+        return messages, timed_user_segments
+
```
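Timed turns thus vanish from the chat template: only system prompts and a single synthetic user turn holding the media survive, and the timed user texts are returned separately for injection at the matching video chunk. Roughly (content shapes assumed for illustration):

```python
# Illustrative input/output; the "video" content shape is an assumption.
hf_messages = [
    {"role": "user", "content": [{"type": "video", "video": "clips/match.mp4"}]},
    {"role": "user", "time": 3.0, "content": [{"type": "text", "text": "Who has the ball?"}]},
]
realtime_segments = [{"time": 3.0, "role": "user", "text": "Who has the ball?"}]

# _build_realtime_base_messages(...) would return approximately:
base_messages = [
    {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant"}]},
    {"role": "user", "content": [{"type": "video", "video": "clips/match.mp4"}]},
]
timed_user_segments = [{"time": 3.0, "role": "user", "text": "Who has the ball?"}]
```

Next come the helpers that assign a wall-clock time to every audio token position: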
```diff
+    def _get_chunk_start_times(self, video_grid_thw, video_metadata, audio_per_chunk_per_video: List[List[int]]):
+        times = []
+        for v_idx in range(len(video_grid_thw)):
+            metadata = video_metadata[v_idx]
+            fps = metadata.fps if metadata.fps is not None else 24.0
+            curr_timestamp = self.processor._calculate_timestamps(
+                metadata.frames_indices,
+                fps,
+                self.processor.video_processor.temporal_patch_size,
+            )
+            for t in range(len(audio_per_chunk_per_video[v_idx])):
+                times.append(curr_timestamp[t] if t < len(curr_timestamp) else curr_timestamp[-1])
+        return times
+
+    def _get_audio_position_times(self, video_grid_thw, video_metadata, audio_per_chunk_per_video: List[List[int]]):
+        times = []
+        chunk_times = self._get_chunk_start_times(video_grid_thw, video_metadata, audio_per_chunk_per_video)
+        chunk_idx = 0
+        for audio_per_chunk in audio_per_chunk_per_video:
+            for n_audio in audio_per_chunk:
+                times.extend([chunk_times[chunk_idx]] * n_audio)
+                chunk_idx += 1
+        return times
+
```
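Every `<|audio_pad|>` token inherits the start time of the chunk it belongs to. A worked example of the expansion:

```python
# Assumed shapes: one video, three chunks starting at 0.0/2.0/4.0 s,
# carrying 3, 2, and 3 audio tokens respectively.
chunk_times = [0.0, 2.0, 4.0]
audio_per_chunk_per_video = [[3, 2, 3]]

times, chunk_idx = [], 0
for audio_per_chunk in audio_per_chunk_per_video:
    for n_audio in audio_per_chunk:
        times.extend([chunk_times[chunk_idx]] * n_audio)
        chunk_idx += 1

assert times == [0.0, 0.0, 0.0, 2.0, 2.0, 4.0, 4.0, 4.0]
```

The remaining additions are small utilities: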
```diff
+    def _encode_realtime_text(self, text: str) -> List[int]:
+        return self.tokenizer.encode(text, add_special_tokens=False)
+
+    @staticmethod
+    def _first_index_at_or_after(values: List[float], target: float) -> int:
+        for idx, value in enumerate(values):
+            if value >= target:
+                return idx
+        return len(values)
+
+    @staticmethod
+    def _next_time_after(values: List[float], target: float) -> Optional[float]:
+        for value in values:
+            if value > target:
+                return value
+        return None
+
     def get_qwen_template_labels(
         self,
         hf_messages,
```
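Both search helpers are linear scans; since `event_times` and `audio_times` are sorted, `bisect` would give the same answers in O(log n), which may be worth considering for long streams. Equivalent formulations, assuming sorted input:

```python
import bisect
from typing import List, Optional

def first_index_at_or_after(values: List[float], target: float) -> int:
    # first idx with values[idx] >= target; len(values) if none qualify
    return bisect.bisect_left(values, target)

def next_time_after(values: List[float], target: float) -> Optional[float]:
    idx = bisect.bisect_right(values, target)   # first value strictly greater
    return values[idx] if idx < len(values) else None
```

The rest of the diff threads the new `timed_user_segments` argument through to the expansion step: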
```diff
@@ -406,6 +596,7 @@ def get_qwen_template_labels(
         video_metadata: List[dict],
         video_grid_thw=None,
         audio_per_chunk_per_video: Optional[List[List[int]]] = None,
+        timed_user_segments: Optional[List[Dict]] = None,
         system_message: str = "You are a helpful assistant",
         add_system_prompt: bool = True,
         add_generation_prompt: bool = False,
@@ -445,6 +636,7 @@ def get_qwen_template_labels(
                 curr_timestamp,
                 video_grid_thw,
                 audio_per_chunk_per_video=audio_per_chunk_per_video,
+                timed_user_segments=timed_user_segments,
             )
             video_start_from += used_video
```
```diff
@@ -487,6 +679,7 @@ def _expand_encode_id_video_tokens(
         curr_timestamp: List[float] = None,
         video_grid_thw=None,
         audio_per_chunk_per_video: Optional[List[List[int]]] = None,
+        timed_user_segments: Optional[List[Dict]] = None,
     ):
         """Expand ``<|video_pad|>`` placeholders.
@@ -510,6 +703,7 @@ def _expand_encode_id_video_tokens(
         audio_start_id = self.tokenizer.convert_tokens_to_ids(self.processor.audio_start_token)
         audio_end_id = self.tokenizer.convert_tokens_to_ids(self.processor.audio_end_token)
         temporal_patch_size = getattr(self.processor.video_processor, "temporal_patch_size", 2)
+        timed_user_segments = timed_user_segments or []
 
         video_pos = [i for i, x in enumerate(encode_id) if x == self.video_token_id]
         expanded_encode_id = []
```
```diff
@@ -525,14 +719,26 @@ def _expand_encode_id_video_tokens(
             # them directly for the chunk start times.
             audio_per_chunk = audio_per_chunk_per_video[v_global]
             assert len(audio_per_chunk) == grid_t, f"audio_per_chunk len {len(audio_per_chunk)} != grid_t {grid_t}"
+            chunk_times = [
+                curr_timestamp[t] if t < len(curr_timestamp) else (t * temporal_patch_size) for t in range(grid_t)
+            ]
+            user_by_chunk = [[] for _ in range(grid_t)]
+            for segment in timed_user_segments:
+                chunk_idx = self._first_index_at_or_after(chunk_times, float(segment["time"]))
+                if chunk_idx >= grid_t:
+                    chunk_idx = grid_t - 1
+                user_by_chunk[chunk_idx].append(segment["text"])
 
             # Strip surrounding <|vision_start|> / <|vision_end|> from the
             # template (positions pos-1 and pos+1) -- we will emit our own.
             expanded_encode_id.extend(encode_id[prev : pos - 1])
 
             for t in range(grid_t):
+                for user_text in user_by_chunk[t]:
+                    expanded_encode_id.extend(self._encode_realtime_text(user_text))
+
                 # Per-frame timestamp (seconds) from the video metadata
-                t_sec = curr_timestamp[t] if t < len(curr_timestamp) else (t * temporal_patch_size)
+                t_sec = chunk_times[t]
                 timestamp_token_ids = self.processor.tokenizer.encode(f"<{t_sec:.1f} seconds>")
                 n_audio_t = audio_per_chunk[t]
                 expanded_encode_id.extend(timestamp_token_ids)
```
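A timed user line is injected at the head of the first chunk whose start time is at or after the line's timestamp (clamped to the final chunk), ahead of that chunk's `<... seconds>` marker. A small demo of that bucketing rule (values illustrative):

```python
import bisect

# Mirrors the bucketing in the hunk above.
chunk_times = [0.0, 2.0, 4.0]                  # grid_t == 3
timed_user_segments = [{"time": 3.2, "text": "Who has the ball?"}]

user_by_chunk = [[] for _ in chunk_times]
for seg in timed_user_segments:
    idx = min(bisect.bisect_left(chunk_times, float(seg["time"])), len(chunk_times) - 1)
    user_by_chunk[idx].append(seg["text"])

assert user_by_chunk == [[], [], ["Who has the ball?"]]
```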
