@@ -127,6 +127,8 @@ def __call__(
127127 decoder_segment_ids = None ,
128128 encoder_images : None | jnp .ndarray = None ,
129129 encoder_image_masks : None | jnp .ndarray = None ,
130+ encoder_videos : None | jnp .ndarray = None ,
131+ encoder_video_masks : None | jnp .ndarray = None ,
130132 encoder_audios : None | jnp .ndarray = None ,
131133 enable_dropout = True ,
132134 model_mode = MODEL_MODE_TRAIN ,
@@ -153,17 +155,28 @@ def __call__(
153155 f" which is always { DECODING_ACTIVE_SEQUENCE_INDICATOR } ."
154156 )
155157
156- bidirectional_mask = None
158+ bidirectional_mask_image = None
159+ bidirectional_mask_video = None
157160 image_embeddings = None
161+ video_embeddings = None
158162 audio_embeddings = None
159163 deepstack_visual_embeds = None
160164
161165 if self .config .use_multimodal and encoder_images is not None :
162166 image_embeddings , deepstack_visual_embeds = self .vision_encoder (
163167 input_images = encoder_images , deterministic = not enable_dropout
164168 )
169+ bidirectional_mask_image = mm_processor .get_bidirectional_mask_vision (
170+ self .config , decoder_input_tokens , is_video = False
171+ )
165172
166- bidirectional_mask = mm_processor .get_bidirectional_mask_vision (self .config , decoder_input_tokens )
173+ if self .config .use_multimodal and encoder_videos is not None :
174+ video_embeddings , deepstack_visual_embeds = self .vision_encoder (
175+ input_images = encoder_videos , deterministic = not enable_dropout
176+ )
177+ bidirectional_mask_video = mm_processor .get_bidirectional_mask_vision (
178+ self .config , decoder_input_tokens , is_video = True
179+ )
167180
168181 if self .config .use_multimodal and encoder_audios is not None and self .audio_encoder is not None :
169182 audio_embeddings = self .audio_encoder (input_audio = encoder_audios , deterministic = not enable_dropout )
@@ -174,13 +187,16 @@ def __call__(
174187 audio_masks = mm_processor .get_bidirectional_mask_audio (self .config , decoder_input_tokens )
175188
176189 multimodal_input = None
177- if image_embeddings is not None or audio_embeddings is not None :
190+ if image_embeddings is not None or video_embeddings is not None or audio_embeddings is not None :
178191 multimodal_input = MultimodalInput (
179192 image_embeddings = image_embeddings ,
180193 image_masks = encoder_image_masks ,
194+ video_embeddings = video_embeddings ,
195+ video_masks = encoder_video_masks ,
181196 audio_embeddings = audio_embeddings ,
182197 audio_masks = audio_masks ,
183- bidirectional_mask = bidirectional_mask ,
198+ bidirectional_mask = bidirectional_mask_image ,
199+ bidirectional_mask_video = bidirectional_mask_video ,
184200 )
185201
186202 logits , hidden_state , kv_caches = self .decoder (
@@ -425,6 +441,8 @@ def __call__(
425441 cache = None ,
426442 encoder_images : jax .Array | None = None ,
427443 encoder_image_masks : jax .Array | None = None ,
444+ encoder_videos : jax .Array | None = None ,
445+ encoder_video_masks : jax .Array | None = None ,
428446 encoder_audios : jax .Array | None = None ,
429447 enable_dropout = True ,
430448 model_mode = MODEL_MODE_TRAIN ,
@@ -466,16 +484,28 @@ def __call__(
466484 f" which is always { DECODING_ACTIVE_SEQUENCE_INDICATOR } ."
467485 )
468486
469- bidirectional_mask = None
487+ bidirectional_mask_image = None
488+ bidirectional_mask_video = None
470489 image_embeddings = None
490+ video_embeddings = None
491+ audio_embeddings = None
471492 deepstack_visual_embeds = None
472493 if self .config .use_multimodal and encoder_images is not None :
473494 image_embeddings , deepstack_visual_embeds = self .vision_encoder (
474495 input_images = encoder_images , deterministic = not enable_dropout
475496 )
476- bidirectional_mask = mm_processor .get_bidirectional_mask_vision (self .config , decoder_input_tokens )
497+ bidirectional_mask_image = mm_processor .get_bidirectional_mask_vision (
498+ self .config , decoder_input_tokens , is_video = False
499+ )
500+
501+ if self .config .use_multimodal and encoder_videos is not None :
502+ video_embeddings , deepstack_visual_embeds = self .vision_encoder (
503+ input_images = encoder_videos , deterministic = not enable_dropout
504+ )
505+ bidirectional_mask_video = mm_processor .get_bidirectional_mask_vision (
506+ self .config , decoder_input_tokens , is_video = True
507+ )
477508
478- audio_embeddings = None
479509 if self .config .use_multimodal and encoder_audios is not None and self .audio_encoder is not None :
480510 audio_embeddings = self .audio_encoder (input_audio = encoder_audios , deterministic = not enable_dropout )
481511
@@ -485,13 +515,16 @@ def __call__(
485515 audio_masks = mm_processor .get_bidirectional_mask_audio (self .config , decoder_input_tokens )
486516
487517 multimodal_input = None
488- if image_embeddings is not None or audio_embeddings is not None :
518+ if image_embeddings is not None or video_embeddings is not None or audio_embeddings is not None :
489519 multimodal_input = MultimodalInput (
490520 image_embeddings = image_embeddings ,
491521 image_masks = encoder_image_masks ,
522+ video_embeddings = video_embeddings ,
523+ video_masks = encoder_video_masks ,
492524 audio_embeddings = audio_embeddings ,
493525 audio_masks = audio_masks ,
494- bidirectional_mask = bidirectional_mask ,
526+ bidirectional_mask = bidirectional_mask_image ,
527+ bidirectional_mask_video = bidirectional_mask_video ,
495528 )
496529
497530 mutable_collections = []
0 commit comments