@@ -146,7 +146,6 @@ def process(
         audios: Optional[List[np.ndarray]] = None,
         sampling_rate: Optional[int] = None,
         videos=None,
-        video_metadata=None,
         realtime_segments: Optional[List[Dict]] = None,
         system_message: str = "You are a helpful assistant",
         add_system_prompt=True,
@@ -162,14 +161,13 @@ def process(
             audios: List of audio waveforms (mono, float32, at sampling_rate).
             sampling_rate: Audio sampling rate.
             videos: List of video frames (numpy arrays, TCHW format).
-            video_metadata: Video metadata for timestamp computation.
-                If not provided, computed from video processor.
             realtime_segments: List of ``{"start_sec": float, "text": str}``
                 dicts extracted from assistant ``realtime_text`` content items.
                 If None, this is treated as normal video QA.
             system_message: System prompt text.
             add_system_prompt: Whether to add a system prompt.
-            **kwargs: Additional kwargs (e.g. ``fps``, ``do_sample_frames``).
+            **kwargs: Forwarded to the model processor (e.g. ``fps``,
+                ``do_sample_frames``, ``video_metadata``).
 
         Returns:
             Dict with ``input_ids``, ``text_stream_ids``, ``labels``, and
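Reviewer note: for orientation, a minimal call sketch of the updated signature. The variable names (`collator`, `messages`, `frames`, `waveform`) are illustrative and not from this PR; `video_metadata` is now simply forwarded through `**kwargs` when needed.

```python
# Illustrative usage only; assumes `collator` is an instance of this data
# collator and that `messages`, `frames`, `waveform` were prepared upstream.
batch = collator.process(
    messages,                       # HF-style chat messages
    audios=[waveform],              # mono float32 waveform at `sampling_rate`
    sampling_rate=16000,
    videos=[frames],                # one video as a TCHW numpy array
    system_message="You are a helpful assistant",
    fps=2.0,                        # forwarded to the model processor via **kwargs
)
# batch["input_ids"], batch["labels"], and (with audio) batch["text_stream_ids"]
```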
@@ -200,9 +198,6 @@ def process(
         _video_metadata = None
         if videos is not None:
             videos_kwargs = output_kwargs.get("videos_kwargs", {})
-            videos_kwargs["return_metadata"] = True
-            if video_metadata is not None:
-                videos_kwargs["video_metadata"] = video_metadata
             video_inputs = self.processor.video_processor(videos=videos, return_tensors="pt", **videos_kwargs)
             video_grid_thw = video_inputs["video_grid_thw"]
             _video_metadata = video_inputs.pop("video_metadata")
@@ -239,6 +234,35 @@ def process(
             num_video_tokens = None
 
         has_video = video_grid_thw is not None
+        has_audio = bool(audio_inputs)
+
+        # Per-video audio token splits across video temporal chunks.
+        # Required for envelope construction when both video and audio are
+        # present (the inner ``<|audio_pad|>`` count of each per-chunk envelope).
+        audio_per_chunk_per_video = None
+        if has_video and has_audio:
+            mel_lengths = audio_inputs["audio_attention_mask"].sum(-1)
+            num_audio_tokens_list = [self.processor._get_num_audio_tokens(int(m.item())) for m in mel_lengths]
+            temporal_patch_size = getattr(self.processor.video_processor, "temporal_patch_size", 2)
+            audio_per_chunk_per_video = []
+            for v_idx in range(len(video_grid_thw)):
+                metadata = _video_metadata[v_idx]
+                fps = metadata.fps if metadata.fps is not None else 24.0
+                grid_t = int(video_grid_thw[v_idx][0])
+                second_per_grid = temporal_patch_size / fps
+                # Audio sample paired with this video by positional index
+                a_idx = v_idx if v_idx < len(num_audio_tokens_list) else 0
+                n_audio = num_audio_tokens_list[a_idx]
+                audio_duration = self.processor._get_audio_duration_seconds(audio_inputs["audio_attention_mask"][a_idx])
+                audio_rate = (n_audio / audio_duration) if audio_duration > 0 else 0.0
+                audio_per_chunk_per_video.append(
+                    self.processor._split_audio_across_chunks(
+                        n_audio=n_audio,
+                        grid_t=grid_t,
+                        second_per_grid=second_per_grid,
+                        audio_rate=audio_rate,
+                    )
+                )
 
         # ==============================================================
         # 5. Build input_ids, text_stream_ids, labels
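`_split_audio_across_chunks` itself is not part of this diff. As a rough mental model (an assumption, not the actual implementation), each temporal chunk spans `second_per_grid` seconds and should receive about `second_per_grid * audio_rate` audio tokens, with the leftover folded into the last chunk so the counts sum to `n_audio`:

```python
from typing import List

def split_audio_across_chunks_sketch(
    n_audio: int, grid_t: int, second_per_grid: float, audio_rate: float
) -> List[int]:
    """Hypothetical stand-in for processor._split_audio_across_chunks."""
    per_chunk, remaining = [], n_audio
    for t in range(grid_t):
        if t == grid_t - 1:
            take = remaining  # last chunk absorbs any leftover audio tokens
        else:
            take = min(remaining, round(second_per_grid * audio_rate))
        per_chunk.append(int(take))
        remaining -= take
    return per_chunk

# e.g. 100 audio tokens over 4 chunks of 2 s at 12.5 tokens/s -> [25, 25, 25, 25]
```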
@@ -250,6 +274,8 @@ def process(
             num_video_tokens=num_video_tokens,
             video_grid_thw=video_grid_thw,
             video_metadata=_video_metadata,
+            audio_per_chunk_per_video=audio_per_chunk_per_video,
+            audio_attention_mask=audio_inputs.get("audio_attention_mask") if has_audio else None,
             system_message=system_message,
             add_system_prompt=add_system_prompt,
         )
@@ -259,17 +285,7 @@ def process(
             raise RuntimeError("Not implemented yet")
 
         # ==============================================================
-        # 6. Compute video_timestep and audio_timestep
-        # ==============================================================
-        if video_grid_thw is not None and _video_metadata is not None:
-            inputs["video_timestep"] = self.processor._compute_video_timestep(video_grid_thw, _video_metadata)
-
-        if audio_inputs:
-            audio_mask = audio_inputs["audio_attention_mask"]
-            inputs["audio_timestep"] = self.processor._compute_audio_timestep(audio_mask)
-
-        # ==============================================================
-        # 7. Pack vision/audio tensors into output
+        # 6. Pack vision/audio tensors into output
         # ==============================================================
         if images is not None:
             inputs["pixel_values"] = image_inputs["pixel_values"]
@@ -297,26 +313,33 @@ def _build_normal_qa_ids_and_labels(
         num_video_tokens: Optional[List[int]],
         video_grid_thw=None,
         video_metadata=None,
+        audio_per_chunk_per_video: Optional[List[List[int]]] = None,
+        audio_attention_mask: Optional[torch.Tensor] = None,
         realtime_segments: Optional[List[Dict]] = None,
         system_message: str = "You are a helpful assistant",
         add_system_prompt: bool = True,
     ) -> dict:
         """Build input_ids, text_stream_ids, and labels from HF messages.
 
-        For normal video QA: text_stream_ids has rt_start/rt_pad/rt_speak
-        with all rt_pad after rt_speak (model learns to stay silent).
+        For normal video QA the text_stream_ids only differ from input_ids
+        in the multimodal pad regions:
+            - all ``<|video_pad|>`` and ``<|audio_pad|>`` slots → ``<|rt_pad|>``
+            - first chunk's first ``<|video_pad|>`` → ``<|rt_start|>``
+            - speak chunk's first ``<|audio_pad|>`` → ``<|rt_speak|>``
 
-        For realtime training: text_stream_ids has actual text tokens placed
-        at the right temporal positions after rt_speak.
+        Envelope boundary tokens (timestamps, vision_start/end,
+        audio_start/end) keep their original ids in text_stream_ids so the
+        LM sees the same special tokens it would in input_ids.
         """
         results = self.get_qwen_template_labels(
             hf_messages,
             num_image_tokens,
             num_video_tokens,
             video_metadata,
             video_grid_thw,
-            system_message,
-            add_system_prompt,
+            audio_per_chunk_per_video=audio_per_chunk_per_video,
+            system_message=system_message,
+            add_system_prompt=add_system_prompt,
         )
         input_id = results["input_ids"].tolist()
         target = results["labels"].tolist()
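To make the docstring above concrete, a toy example with token strings instead of ids (sizes are made up: 2 spatial video slots and 3 audio slots in a single chunk that is also the speak chunk):

```python
# Toy illustration only; the real code operates on token ids, not strings.
ids = (
    ["<3.0 seconds>", "<|vision_start|>", "<|audio_start|>"]
    + ["<|video_pad|>"] * 2
    + ["<|audio_pad|>"] * 3
    + ["<|audio_end|>", "<|vision_end|>"]
)
stream = list(ids)
vp = [i for i, t in enumerate(ids) if t == "<|video_pad|>"]
ap = [i for i, t in enumerate(ids) if t == "<|audio_pad|>"]
for i in vp + ap:                 # every pad slot becomes <|rt_pad|> ...
    stream[i] = "<|rt_pad|>"
stream[vp[0]] = "<|rt_start|>"    # ... except the first video pad of chunk 0
stream[ap[0]] = "<|rt_speak|>"    # ... and the first audio pad of the speak chunk
# Timestamp and start/end tokens keep their original values in `stream`.
```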
@@ -325,59 +348,53 @@ def _build_normal_qa_ids_and_labels(
         # ==============================================================
         # Build text_stream_ids
         has_video = video_grid_thw is not None
+        has_audio = audio_attention_mask is not None
         text_stream_id = list(input_id)  # start as a copy of input_ids
 
-        if has_video:
-            vision_start_id = self.tokenizer.convert_tokens_to_ids(self.processor.vision_start_token)
-            vision_end_id = self.tokenizer.convert_tokens_to_ids(self.processor.vision_end_token)
-            temporal_patch_size = getattr(self.processor.video_processor, "temporal_patch_size", 2)
-
-            # Pre-compute per-frame timestamps for all videos
-            all_frame_timestamps = []
-            for v_idx in range(len(video_grid_thw)):
-                metadata = video_metadata[v_idx]
-                fps = metadata.fps if metadata.fps is not None else 24.0
-                timestamps = self.processor._calculate_timestamps(metadata.frames_indices, fps, temporal_patch_size)
-                all_frame_timestamps.extend(timestamps)
-
-            input_id_t = torch.tensor(input_id)
-            vs_positions = (input_id_t == vision_start_id).nonzero(as_tuple=True)[0].tolist()
-            ve_positions = (input_id_t == vision_end_id).nonzero(as_tuple=True)[0].tolist()
-
-            assert len(all_frame_timestamps) == len(vs_positions), "The timestamps and frame number should be equal"
-
-            # Find the first frame whose timestamp >= delay_seconds
-            speak_frame = len(all_frame_timestamps) - 1  # fallback to last frame
-            for idx, ts in enumerate(all_frame_timestamps):
-                if ts >= self.processor.delay_seconds:
-                    speak_frame = idx
-                    break
-
-            # Fill text_stream_id for each frame's [VS][VP*N][VE] region
-            for idx, (vs, ve) in enumerate(zip(vs_positions, ve_positions)):
-                # VS and VE → rt_pad
-                text_stream_id[vs] = self.rt_pad_id
-                text_stream_id[ve] = self.rt_pad_id
-                # VP region (vs+1 to ve-1) → rt_pad
-                for k in range(vs + 1, ve):
-                    text_stream_id[k] = self.rt_pad_id
-                # First frame: place rt_start at first VP position
-                if idx == 0:
-                    text_stream_id[vs + 1] = self.rt_start_id
-                # Delay frame: place rt_speak at first VP position
-                if idx == speak_frame:
-                    text_stream_id[vs + 1] = self.rt_speak_id
+        if has_video and has_audio:
+            # video + audio: per-chunk envelope filler
+            self.processor._fill_text_stream_video_audio(
+                stream=text_stream_id,
+                video_grid_thw=video_grid_thw,
+                video_metadata=video_metadata,
+                temporal_patch_size=getattr(self.processor.video_processor, "temporal_patch_size", 2),
+                vision_start_id=self.tokenizer.convert_tokens_to_ids(self.processor.vision_start_token),
+                vision_end_id=self.tokenizer.convert_tokens_to_ids(self.processor.vision_end_token),
+                audio_start_id=self.tokenizer.convert_tokens_to_ids(self.processor.audio_start_token),
+                audio_end_id=self.tokenizer.convert_tokens_to_ids(self.processor.audio_end_token),
+                video_pad_id=self.video_token_id,
+                audio_pad_id=self.audio_token_id,
+                rt_start_id=self.rt_start_id,
+                rt_pad_id=self.rt_pad_id,
+                rt_speak_id=self.rt_speak_id,
+            )
+        elif has_audio:
+            # audio-only: single envelope per audio sample
+            n_samples = audio_attention_mask.shape[0]
+            for s_idx in range(n_samples):
+                self.processor._fill_text_stream_audio_only(
+                    stream=text_stream_id,
+                    sample_idx=s_idx,
+                    audio_attention_mask=audio_attention_mask,
+                    audio_start_id=self.tokenizer.convert_tokens_to_ids(self.processor.audio_start_token),
+                    audio_end_id=self.tokenizer.convert_tokens_to_ids(self.processor.audio_end_token),
+                    audio_pad_id=self.audio_token_id,
+                    rt_start_id=self.rt_start_id,
+                    rt_pad_id=self.rt_pad_id,
+                    rt_speak_id=self.rt_speak_id,
+                )
+        # video-only (no audio): no text_stream_ids (matches processor)
 
         input_id = torch.tensor(input_id, dtype=torch.long)
         target = torch.tensor(target, dtype=torch.long)
-        text_stream_id = torch.tensor(text_stream_id, dtype=torch.long)
 
         result = dict(
             input_ids=input_id,
             labels=target,
         )
-        if has_video:
-            result["text_stream_ids"] = text_stream_id
+        # text_stream_ids only when audio is present (= streaming mode)
+        if has_audio:
+            result["text_stream_ids"] = torch.tensor(text_stream_id, dtype=torch.long)
 
         return result
 
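A small invariant check that follows from the rules above may be useful in tests; `out` is the dict returned by `_build_normal_qa_ids_and_labels`, and the id arguments are the collator's pad/rt token ids (a sketch, not part of this PR):

```python
import torch

def check_stream_alignment(out, video_pad_id, audio_pad_id, rt_ids):
    """Sketch of a test-time invariant: text_stream_ids may only differ from
    input_ids inside video/audio pad regions, and only with rt_* ids."""
    ids, stream = out["input_ids"], out["text_stream_ids"]
    assert ids.shape == stream.shape
    diff = ids != stream
    pad_mask = (ids == video_pad_id) | (ids == audio_pad_id)
    assert bool(torch.all(~diff | pad_mask)), "stream differs outside pad regions"
    assert set(stream[diff].tolist()) <= set(rt_ids), "non-rt token in stream"
```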
@@ -388,6 +405,7 @@ def get_qwen_template_labels(
         num_video_tokens: List[int],
         video_metadata: List[dict],
         video_grid_thw=None,
+        audio_per_chunk_per_video: Optional[List[List[int]]] = None,
         system_message: str = "You are a helpful assistant",
         add_system_prompt: bool = True,
         add_generation_prompt: bool = False,
@@ -426,6 +444,7 @@ def get_qwen_template_labels(
                     video_start_from,
                     curr_timestamp,
                     video_grid_thw,
+                    audio_per_chunk_per_video=audio_per_chunk_per_video,
                 )
                 video_start_from += used_video
 
@@ -449,6 +468,8 @@ def get_qwen_template_labels(
                 target[idx] = -100
             if encode_id == self.video_token_id:
                 target[idx] = -100
+            if encode_id == self.audio_token_id:
+                target[idx] = -100
 
         input_id = torch.tensor(input_id, dtype=torch.long)
         target = torch.tensor(target, dtype=torch.long)
@@ -458,6 +479,77 @@ def get_qwen_template_labels(
             labels=target,
         )
 
+    def _expand_encode_id_video_tokens(
+        self,
+        encode_id: List[int],
+        video_token_num: List[int],
+        start_from: int = 0,
+        curr_timestamp: List[float] = None,
+        video_grid_thw=None,
+        audio_per_chunk_per_video: Optional[List[List[int]]] = None,
+    ):
+        """Expand ``<|video_pad|>`` placeholders.
+
+        - Without audio: per-frame Qwen3VL legacy expansion (delegated to
+          parent).
+        - With audio: per-chunk envelope expansion matching the model
+          processor's path 5b layout::
+
+            <t.t seconds><|vision_start|><|audio_start|>
+            <|video_pad|>×spatial <|audio_pad|>×N_t
+            <|audio_end|><|vision_end|>
+        """
+        if audio_per_chunk_per_video is None:
+            return super()._expand_encode_id_video_tokens(
+                encode_id, video_token_num, start_from, curr_timestamp, video_grid_thw
+            )
+
+        merge_length = self.processor.video_processor.merge_size**2
+        vision_start_id = self.processor.vision_start_token_id
+        vision_end_id = self.processor.vision_end_token_id
+        audio_start_id = self.tokenizer.convert_tokens_to_ids(self.processor.audio_start_token)
+        audio_end_id = self.tokenizer.convert_tokens_to_ids(self.processor.audio_end_token)
+        temporal_patch_size = getattr(self.processor.video_processor, "temporal_patch_size", 2)
+
+        video_pos = [i for i, x in enumerate(encode_id) if x == self.video_token_id]
+        expanded_encode_id = []
+        prev = 0
+        for idx, pos in enumerate(video_pos):
+            v_global = idx + start_from
+            grid = video_grid_thw[v_global]
+            grid_t = int(grid[0])
+            spatial = int(grid[1:].prod() // merge_length)
+
+            # Per-chunk audio counts were computed upstream; curr_timestamp
+            # already holds per-frame timestamps in seconds, so use them
+            # directly as the chunk start times.
+            audio_per_chunk = audio_per_chunk_per_video[v_global]
+            assert len(audio_per_chunk) == grid_t, f"audio_per_chunk len {len(audio_per_chunk)} != grid_t {grid_t}"
+
+            # Strip surrounding <|vision_start|> / <|vision_end|> from the
+            # template (positions pos-1 and pos+1) -- we will emit our own.
+            expanded_encode_id.extend(encode_id[prev : pos - 1])
+
+            for t in range(grid_t):
+                # Per-frame timestamp (seconds) from the video metadata
+                t_sec = curr_timestamp[t] if t < len(curr_timestamp) else (t * temporal_patch_size)
+                timestamp_token_ids = self.processor.tokenizer.encode(f"<{t_sec:.1f} seconds>")
+                n_audio_t = audio_per_chunk[t]
+                expanded_encode_id.extend(timestamp_token_ids)
+                expanded_encode_id.append(vision_start_id)
+                expanded_encode_id.append(audio_start_id)
+                expanded_encode_id.extend([self.video_token_id] * spatial)
+                expanded_encode_id.extend([self.audio_token_id] * n_audio_t)
+                expanded_encode_id.append(audio_end_id)
+                expanded_encode_id.append(vision_end_id)
+
+            prev = pos + 2  # skip past original <|vision_end|>
+
+            if idx == len(video_pos) - 1:
+                expanded_encode_id.extend(encode_id[prev:])
+
+        return expanded_encode_id, len(video_pos)
+
     # ------------------------------------------------------------------
     # Chat template
     # ------------------------------------------------------------------
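Worked size example for the expansion above, with made-up values: `grid_t = 2`, `spatial = 4`, `audio_per_chunk = [5, 4]`, and assuming the encoded "<t.t seconds>" string costs 3 tokens:

```python
# Per chunk: timestamp + <|vision_start|> + <|audio_start|>
#            + spatial x <|video_pad|> + N_t x <|audio_pad|>
#            + <|audio_end|> + <|vision_end|>
timestamp_len = 3                      # assumed token length of "<t.t seconds>"
spatial = 4
audio_per_chunk = [5, 4]
chunk_lens = [timestamp_len + 2 + spatial + n + 2 for n in audio_per_chunk]
print(chunk_lens)                      # [16, 15] -> 31 tokens replace the original
                                       # <|vision_start|><|video_pad|><|vision_end|> triple
```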
@@ -484,7 +576,7 @@ def chat_template(self):
484576 "{% for content in message['content'] %}"
485577 "{% if 'audio' in content or 'audio_url' in content %}"
486578 "{% set audio_count.value = audio_count.value + 1 %}"
487- "<|AUDIO |>"
579+ "<|audio_pad |>"
488580 "{% elif content['type'] == 'image' or 'image' in content or 'image_url' in content %}"
489581 "{% set image_count.value = image_count.value + 1 %}"
490582 "<|vision_start|><|image_pad|><|vision_end|>"