@@ -249,17 +249,20 @@ def process(
249249 metadata = _video_metadata [v_idx ]
250250 fps = metadata .fps if metadata .fps is not None else 24.0
251251 grid_t = int (video_grid_thw [v_idx ][0 ])
252- second_per_grid = temporal_patch_size / fps
252+ curr_timestamp = self .processor ._calculate_timestamps (
253+ metadata .frames_indices ,
254+ fps ,
255+ temporal_patch_size ,
256+ )
253257 # Audio sample paired with this video by positional index
254258 a_idx = v_idx if v_idx < len (num_audio_tokens_list ) else 0
255259 n_audio = num_audio_tokens_list [a_idx ]
256260 audio_duration = self .processor ._get_audio_duration_seconds (audio_inputs ["audio_attention_mask" ][a_idx ])
257261 audio_rate = (n_audio / audio_duration ) if audio_duration > 0 else 0.0
258262 audio_per_chunk_per_video .append (
259- self .processor ._split_audio_across_chunks (
263+ self .processor ._split_audio_across_chunk_times (
260264 n_audio = n_audio ,
261- grid_t = grid_t ,
262- second_per_grid = second_per_grid ,
265+ chunk_start_times = curr_timestamp [:grid_t ],
263266 audio_rate = audio_rate ,
264267 )
265268 )
@@ -280,9 +283,18 @@ def process(
280283 add_system_prompt = add_system_prompt ,
281284 )
282285 else :
283- # TODO:
284- # Build realtime qa ids and labels
285- raise RuntimeError ("Not implemented yet" )
286+ inputs = self ._build_realtime_ids_and_labels (
287+ hf_messages = hf_messages ,
288+ num_image_tokens = num_image_tokens ,
289+ num_video_tokens = num_video_tokens ,
290+ video_grid_thw = video_grid_thw ,
291+ video_metadata = _video_metadata ,
292+ audio_per_chunk_per_video = audio_per_chunk_per_video ,
293+ audio_attention_mask = audio_inputs .get ("audio_attention_mask" ) if has_audio else None ,
294+ realtime_segments = realtime_segments ,
295+ system_message = system_message ,
296+ add_system_prompt = add_system_prompt ,
297+ )
286298
287299 # ==============================================================
288300 # 6. Pack vision/audio tensors into output
@@ -398,6 +410,184 @@ def _build_normal_qa_ids_and_labels(
398410
399411 return result
400412
def _build_realtime_ids_and_labels(
    self,
    hf_messages,
    num_image_tokens: Optional[List[int]],
    num_video_tokens: Optional[List[int]],
    video_grid_thw=None,
    video_metadata=None,
    audio_per_chunk_per_video: Optional[List[List[int]]] = None,
    audio_attention_mask: Optional[torch.Tensor] = None,
    realtime_segments: Optional[List[Dict]] = None,
    system_message: str = "You are a helpful assistant",
    add_system_prompt: bool = True,
) -> dict:
    """Build ids/labels for realtime (streaming) training.

    Produces three parallel sequences:
      * ``input_ids`` — the rendered chat template with video/audio placeholders.
      * ``text_stream_ids`` — a copy of ``input_ids`` where audio-pad positions
        are overwritten first by the processor's realtime stream fill, then by
        assistant text tokens aligned to their spoken time windows.
      * ``labels`` — -100 everywhere except audio positions, which get either
        ``rt_pad_id`` (silence after the warm-up delay) or the assistant token
        written at that position.

    Raises ValueError when video or audio inputs are missing, or when the
    number of audio placeholder positions disagrees with the per-chunk timing.
    """
    # Realtime alignment needs both modalities: audio tokens are mapped onto
    # video-chunk start times, so neither input may be absent.
    if video_grid_thw is None or audio_per_chunk_per_video is None or audio_attention_mask is None:
        raise ValueError("Realtime training requires both video and audio inputs.")

    # Collapse timed user turns out of the template; they are re-injected
    # per-chunk later (inside _expand_encode_id_video_tokens via
    # timed_user_segments).
    base_messages, timed_user_segments = self._build_realtime_base_messages(
        hf_messages=hf_messages,
        realtime_segments=realtime_segments or [],
        video_grid_thw=video_grid_thw,
        video_metadata=video_metadata,
        audio_per_chunk_per_video=audio_per_chunk_per_video,
        system_message=system_message,
        add_system_prompt=add_system_prompt,
    )

    # add_system_prompt=False: the system turn (if any) is already part of
    # base_messages, so the template renderer must not add another.
    results = self.get_qwen_template_labels(
        base_messages,
        num_image_tokens,
        num_video_tokens,
        video_metadata,
        video_grid_thw,
        audio_per_chunk_per_video=audio_per_chunk_per_video,
        timed_user_segments=timed_user_segments,
        system_message=system_message,
        add_system_prompt=False,
    )
    input_id = results["input_ids"].tolist()
    # text_stream_id starts as a copy; audio positions get overwritten below.
    text_stream_id = list(input_id)
    # -100 == ignore_index for cross-entropy; only audio positions get labels.
    target = [-100] * len(input_id)

    # Let the processor stamp realtime control tokens (rt_start/rt_pad/rt_speak)
    # into the text stream in place. NOTE(review): exact placement semantics
    # live in _fill_text_stream_video_audio, not visible here.
    self.processor._fill_text_stream_video_audio(
        stream=text_stream_id,
        video_grid_thw=video_grid_thw,
        video_metadata=video_metadata,
        temporal_patch_size=getattr(self.processor.video_processor, "temporal_patch_size", 2),
        vision_start_id=self.tokenizer.convert_tokens_to_ids(self.processor.vision_start_token),
        vision_end_id=self.tokenizer.convert_tokens_to_ids(self.processor.vision_end_token),
        audio_start_id=self.tokenizer.convert_tokens_to_ids(self.processor.audio_start_token),
        audio_end_id=self.tokenizer.convert_tokens_to_ids(self.processor.audio_end_token),
        video_pad_id=self.video_token_id,
        audio_pad_id=self.audio_token_id,
        rt_start_id=self.rt_start_id,
        rt_pad_id=self.rt_pad_id,
        rt_speak_id=self.rt_speak_id,
    )

    # Each audio placeholder position is paired 1:1 with a wall-clock time
    # (the start time of the video chunk it belongs to).
    audio_positions = [idx for idx, token_id in enumerate(input_id) if token_id == self.audio_token_id]
    audio_times = self._get_audio_position_times(
        video_grid_thw=video_grid_thw,
        video_metadata=video_metadata,
        audio_per_chunk_per_video=audio_per_chunk_per_video,
    )
    if len(audio_positions) != len(audio_times):
        raise ValueError(f"Audio position/time mismatch: {len(audio_positions)} != {len(audio_times)}")

    # Before `delay` seconds the model gets no supervision (-100); after it,
    # silence is explicitly supervised as rt_pad (may be overwritten by
    # assistant text below).
    delay = getattr(self.processor, "delay_seconds", 2.0)
    for pos, t_sec in zip(audio_positions, audio_times):
        if t_sec >= delay:
            target[pos] = self.rt_pad_id

    # Assistant utterances, in chronological order; each one occupies the
    # audio positions from its own start until the next realtime event.
    assistant_segments = sorted(
        [seg for seg in (realtime_segments or []) if seg.get("role") == "assistant" and seg.get("text")],
        key=lambda item: float(item["time"]),
    )
    event_times = sorted(float(seg["time"]) for seg in (realtime_segments or []))
    for segment in assistant_segments:
        start_time = float(segment["time"])
        end_time = self._next_time_after(event_times, start_time)
        start_audio_idx = self._first_index_at_or_after(audio_times, start_time)
        end_audio_idx = (
            self._first_index_at_or_after(audio_times, end_time) if end_time is not None else len(audio_positions)
        )
        # Skip over a leading rt_speak marker placed by the stream fill so the
        # text does not clobber it.
        if start_audio_idx < end_audio_idx and text_stream_id[audio_positions[start_audio_idx]] == self.rt_speak_id:
            start_audio_idx += 1
        token_ids = self._encode_realtime_text(segment["text"])
        # Truncate to the window; tokens past end_audio_idx are dropped.
        for offset, token_id in enumerate(token_ids[: max(0, end_audio_idx - start_audio_idx)]):
            pos = audio_positions[start_audio_idx + offset]
            text_stream_id[pos] = token_id
            target[pos] = token_id

    input_tensor = torch.tensor(input_id, dtype=torch.long)
    text_stream_tensor = torch.tensor(text_stream_id, dtype=torch.long)
    target_tensor = torch.tensor(target, dtype=torch.long)

    return dict(
        input_ids=input_tensor,
        labels=target_tensor,
        text_stream_ids=text_stream_tensor,
    )
513+
514+ def _build_realtime_base_messages (
515+ self ,
516+ hf_messages ,
517+ realtime_segments : List [Dict ],
518+ video_grid_thw ,
519+ video_metadata ,
520+ audio_per_chunk_per_video : List [List [int ]],
521+ system_message : str ,
522+ add_system_prompt : bool ,
523+ ):
524+ messages = []
525+ first_content = []
526+ timed_user_segments = sorted (
527+ [seg for seg in realtime_segments if seg .get ("role" ) == "user" and seg .get ("text" )],
528+ key = lambda item : float (item ["time" ]),
529+ )
530+
531+ if add_system_prompt and (not hf_messages or hf_messages [0 ]["role" ] != "system" ):
532+ messages .append ({"role" : "system" , "content" : [{"type" : "text" , "text" : system_message }]})
533+
534+ for message in hf_messages :
535+ if message ["role" ] == "system" :
536+ messages .append (message )
537+ continue
538+ if message .get ("time" ) is not None :
539+ continue
540+ for content in message ["content" ]:
541+ if content .get ("type" ) in ["image" , "video" , "audio" ]:
542+ first_content .append (content )
543+
544+ content = []
545+ content .extend (first_content )
546+
547+ messages .append ({"role" : "user" , "content" : content })
548+ return messages , timed_user_segments
549+
550+ def _get_chunk_start_times (self , video_grid_thw , video_metadata , audio_per_chunk_per_video : List [List [int ]]):
551+ times = []
552+ for v_idx in range (len (video_grid_thw )):
553+ metadata = video_metadata [v_idx ]
554+ fps = metadata .fps if metadata .fps is not None else 24.0
555+ curr_timestamp = self .processor ._calculate_timestamps (
556+ metadata .frames_indices ,
557+ fps ,
558+ self .processor .video_processor .temporal_patch_size ,
559+ )
560+ for t in range (len (audio_per_chunk_per_video [v_idx ])):
561+ times .append (curr_timestamp [t ] if t < len (curr_timestamp ) else curr_timestamp [- 1 ])
562+ return times
563+
564+ def _get_audio_position_times (self , video_grid_thw , video_metadata , audio_per_chunk_per_video : List [List [int ]]):
565+ times = []
566+ chunk_times = self ._get_chunk_start_times (video_grid_thw , video_metadata , audio_per_chunk_per_video )
567+ chunk_idx = 0
568+ for audio_per_chunk in audio_per_chunk_per_video :
569+ for n_audio in audio_per_chunk :
570+ times .extend ([chunk_times [chunk_idx ]] * n_audio )
571+ chunk_idx += 1
572+ return times
573+
574+ def _encode_realtime_text (self , text : str ) -> List [int ]:
575+ return self .tokenizer .encode (text , add_special_tokens = False )
576+
577+ @staticmethod
578+ def _first_index_at_or_after (values : List [float ], target : float ) -> int :
579+ for idx , value in enumerate (values ):
580+ if value >= target :
581+ return idx
582+ return len (values )
583+
584+ @staticmethod
585+ def _next_time_after (values : List [float ], target : float ) -> Optional [float ]:
586+ for value in values :
587+ if value > target :
588+ return value
589+ return None
590+
401591 def get_qwen_template_labels (
402592 self ,
403593 hf_messages ,
@@ -406,6 +596,7 @@ def get_qwen_template_labels(
406596 video_metadata : List [dict ],
407597 video_grid_thw = None ,
408598 audio_per_chunk_per_video : Optional [List [List [int ]]] = None ,
599+ timed_user_segments : Optional [List [Dict ]] = None ,
409600 system_message : str = "You are a helpful assistant" ,
410601 add_system_prompt : bool = True ,
411602 add_generation_prompt : bool = False ,
@@ -445,6 +636,7 @@ def get_qwen_template_labels(
445636 curr_timestamp ,
446637 video_grid_thw ,
447638 audio_per_chunk_per_video = audio_per_chunk_per_video ,
639+ timed_user_segments = timed_user_segments ,
448640 )
449641 video_start_from += used_video
450642
@@ -487,6 +679,7 @@ def _expand_encode_id_video_tokens(
487679 curr_timestamp : List [float ] = None ,
488680 video_grid_thw = None ,
489681 audio_per_chunk_per_video : Optional [List [List [int ]]] = None ,
682+ timed_user_segments : Optional [List [Dict ]] = None ,
490683 ):
491684 """Expand ``<|video_pad|>`` placeholders.
492685
@@ -510,6 +703,7 @@ def _expand_encode_id_video_tokens(
510703 audio_start_id = self .tokenizer .convert_tokens_to_ids (self .processor .audio_start_token )
511704 audio_end_id = self .tokenizer .convert_tokens_to_ids (self .processor .audio_end_token )
512705 temporal_patch_size = getattr (self .processor .video_processor , "temporal_patch_size" , 2 )
706+ timed_user_segments = timed_user_segments or []
513707
514708 video_pos = [i for i , x in enumerate (encode_id ) if x == self .video_token_id ]
515709 expanded_encode_id = []
@@ -525,14 +719,26 @@ def _expand_encode_id_video_tokens(
525719 # them directly for the chunk start times.
526720 audio_per_chunk = audio_per_chunk_per_video [v_global ]
527721 assert len (audio_per_chunk ) == grid_t , f"audio_per_chunk len { len (audio_per_chunk )} != grid_t { grid_t } "
722+ chunk_times = [
723+ curr_timestamp [t ] if t < len (curr_timestamp ) else (t * temporal_patch_size ) for t in range (grid_t )
724+ ]
725+ user_by_chunk = [[] for _ in range (grid_t )]
726+ for segment in timed_user_segments :
727+ chunk_idx = self ._first_index_at_or_after (chunk_times , float (segment ["time" ]))
728+ if chunk_idx >= grid_t :
729+ chunk_idx = grid_t - 1
730+ user_by_chunk [chunk_idx ].append (segment ["text" ])
528731
529732 # Strip surrounding <|vision_start|> / <|vision_end|> from the
530733 # template (positions pos-1 and pos+1) -- we will emit our own.
531734 expanded_encode_id .extend (encode_id [prev : pos - 1 ])
532735
533736 for t in range (grid_t ):
737+ for user_text in user_by_chunk [t ]:
738+ expanded_encode_id .extend (self ._encode_realtime_text (user_text ))
739+
534740 # Per-frame timestamp (seconds) from the video metadata
535- t_sec = curr_timestamp [t ] if t < len ( curr_timestamp ) else ( t * temporal_patch_size )
741+ t_sec = chunk_times [t ]
536742 timestamp_token_ids = self .processor .tokenizer .encode (f"<{ t_sec :.1f} seconds>" )
537743 n_audio_t = audio_per_chunk [t ]
538744 expanded_encode_id .extend (timestamp_token_ids )
0 commit comments