@@ -227,13 +227,14 @@ def __init__(
227227 self .video_processor = HunyuanVideo15ImageProcessor (vae_scale_factor = self .vae_scale_factor_spatial )
228228 self .target_size = self .transformer .config .target_size if getattr (self , "transformer" , None ) else 640
229229 self .vision_states_dim = self .transformer .config .image_embed_dim if getattr (self , "transformer" , None ) else 1152
230+ self .num_channels_latents = self .vae .latent_channels if hasattr (self , "vae" ) else 32
230231 # fmt: off
231- self .system_message = "You are a helpful assistant. Describe the video by detailing the following aspects: \
232- 1. The main content and theme of the video. \
233- 2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects. \
234- 3. Actions, events, behaviors temporal relationships, physical movement changes of the objects. \
235- 4. background environment, light, style and atmosphere. \
236- 5. camera angles, movements, and transitions used in the video."
232+ self .system_message = "You are a helpful assistant. Describe the video by detailing the following aspects: \
233+ 1. The main content and theme of the video. \
234+ 2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects. \
235+ 3. Actions, events, behaviors temporal relationships, physical movement changes of the objects. \
236+ 4. background environment, light, style and atmosphere. \
237+ 5. camera angles, movements, and transitions used in the video."
237238 # fmt: on
238239 self .prompt_template_encode_start_idx = 108
239240 self .tokenizer_max_length = 1000
@@ -253,11 +254,11 @@ def _get_mllm_prompt_embeds(
253254 num_hidden_layers_to_skip : int = 2 ,
254255 # fmt: off
255256 system_message : str = "You are a helpful assistant. Describe the video by detailing the following aspects: \
256- 1. The main content and theme of the video. \
257- 2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects. \
258- 3. Actions, events, behaviors temporal relationships, physical movement changes of the objects. \
259- 4. background environment, light, style and atmosphere. \
260- 5. camera angles, movements, and transitions used in the video." ,
257+ 1. The main content and theme of the video. \
258+ 2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects. \
259+ 3. Actions, events, behaviors temporal relationships, physical movement changes of the objects. \
260+ 4. background environment, light, style and atmosphere. \
261+ 5. camera angles, movements, and transitions used in the video." ,
261262 # fmt: on
262263 crop_start : int = 108 ,
263264 ) -> Tuple [torch .Tensor , torch .Tensor ]:
@@ -286,12 +287,13 @@ def _get_mllm_prompt_embeds(
286287 attention_mask = prompt_attention_mask ,
287288 output_hidden_states = True ,
288289 ).hidden_states [- (num_hidden_layers_to_skip + 1 )]
289- prompt_embeds = prompt_embeds .to (dtype = dtype )
290290
291291 if crop_start is not None and crop_start > 0 :
292292 prompt_embeds = prompt_embeds [:, crop_start :]
293293 prompt_attention_mask = prompt_attention_mask [:, crop_start :]
294294
295+ prompt_embeds = prompt_embeds .to (dtype = dtype )
296+
295297 return prompt_embeds , prompt_attention_mask
296298
297299
@@ -578,7 +580,7 @@ def __call__(
578580 negative_prompt : Union [str , List [str ]] = None ,
579581 height : Optional [int ] = None ,
580582 width : Optional [int ] = None ,
581- num_frames : int = 129 ,
583+ num_frames : int = 121 ,
582584 num_inference_steps : int = 50 ,
583585 sigmas : List [float ] = None ,
584586 num_videos_per_prompt : Optional [int ] = 1 ,
@@ -752,10 +754,9 @@ def __call__(
752754 timesteps , num_inference_steps = retrieve_timesteps (self .scheduler , num_inference_steps , device , sigmas = sigmas )
753755
754756 # 5. Prepare latent variables
755- num_channels_latents = self .transformer .config .in_channels
756757 latents = self .prepare_latents (
757758 batch_size * num_videos_per_prompt ,
758- num_channels_latents ,
759+ self . num_channels_latents ,
759760 height ,
760761 width ,
761762 num_frames ,
@@ -877,7 +878,7 @@ def __call__(
877878
878879 if not output_type == "latent" :
879880 latents = latents .to (self .vae .dtype ) / self .vae .config .scaling_factor
880- video = self .vae .decode (latents , return_dict = False , generator = generator )[0 ]
881+ video = self .vae .decode (latents , return_dict = False )[0 ]
881882 video = self .video_processor .postprocess_video (video , output_type = output_type )
882883 else :
883884 video = latents
0 commit comments