@@ -17,7 +17,6 @@
 import numpy as np
 import torch
 
-from ...configuration_utils import FrozenDict
 from ...models import HunyuanVideo15Transformer3DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
 from ...utils import logging
@@ -33,23 +32,53 @@
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
 def retrieve_timesteps(
     scheduler,
-    num_inference_steps=None,
-    device=None,
-    timesteps=None,
-    sigmas=None,
+    num_inference_steps: int | None = None,
+    device: str | torch.device | None = None,
+    timesteps: list[int] | None = None,
+    sigmas: list[float] | None = None,
     **kwargs,
 ):
+    r"""
+    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+    Args:
+        scheduler (`SchedulerMixin`):
+            The scheduler to get timesteps from.
+        num_inference_steps (`int`):
+            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+            must be `None`.
+        device (`str` or `torch.device`, *optional*):
+            The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
+        timesteps (`list[int]`, *optional*):
+            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+            `num_inference_steps` and `sigmas` must be `None`.
+        sigmas (`list[float]`, *optional*):
+            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+            `num_inference_steps` and `timesteps` must be `None`.
+
+    Returns:
+        `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+        second element is the number of inference steps.
+    """
     if timesteps is not None and sigmas is not None:
-        raise ValueError("Only one of `timesteps` or `sigmas` can be passed.")
+        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
     if timesteps is not None:
+        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        if not accepts_timesteps:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" timestep schedules. Please check whether you are using the correct scheduler."
+            )
         scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
         timesteps = scheduler.timesteps
         num_inference_steps = len(timesteps)
     elif sigmas is not None:
         accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
         if not accept_sigmas:
             raise ValueError(
-                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom sigmas."
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
         scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
         timesteps = scheduler.timesteps
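
Aside: a minimal usage sketch of the `retrieve_timesteps` helper above (not part of the diff; it assumes `diffusers` is installed, that `retrieve_timesteps` is in scope, and uses an illustrative sigma schedule):

```python
import numpy as np
from diffusers import FlowMatchEulerDiscreteScheduler

scheduler = FlowMatchEulerDiscreteScheduler()

# Default spacing: the scheduler derives its own schedule.
timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=30, device="cpu")

# Custom sigmas override the spacing strategy; the step count is recomputed
# from the resulting schedule, so the two return values stay consistent.
sigmas = np.linspace(1.0, 1 / 30, 30).tolist()
timesteps, num_inference_steps = retrieve_timesteps(scheduler, device="cpu", sigmas=sigmas)
assert num_inference_steps == len(timesteps)
```
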
@@ -65,13 +94,7 @@ class HunyuanVideo15TextInputStep(ModularPipelineBlocks):
 
     @property
     def description(self) -> str:
-        return "Input processing step that determines batch_size and dtype"
-
-    @property
-    def expected_components(self) -> list[ComponentSpec]:
-        return [
-            ComponentSpec("transformer", HunyuanVideo15Transformer3DModel),
-        ]
+        return "Input processing step that determines batch_size"
 
     @property
     def inputs(self) -> list[InputParam]:
@@ -85,14 +108,12 @@ def inputs(self) -> list[InputParam]:
     def intermediate_outputs(self) -> list[OutputParam]:
         return [
             OutputParam("batch_size", type_hint=int),
-            OutputParam("dtype", type_hint=torch.dtype),
         ]
 
     @torch.no_grad()
     def __call__(self, components: HunyuanVideo15ModularPipeline, state: PipelineState) -> PipelineState:
         block_state = self.get_block_state(state)
         block_state.batch_size = getattr(block_state, "batch_size", None) or block_state.prompt_embeds.shape[0]
-        block_state.dtype = components.transformer.dtype
         self.set_block_state(state, block_state)
         return components, state
 
@@ -122,7 +143,6 @@ def intermediate_outputs(self) -> list[OutputParam]:
             OutputParam("num_inference_steps", type_hint=int),
         ]
 
-    # Copied from pipeline_hunyuan_video1_5.py line 702-704
     @torch.no_grad()
     def __call__(self, components: HunyuanVideo15ModularPipeline, state: PipelineState) -> PipelineState:
         block_state = self.get_block_state(state)
@@ -147,6 +167,10 @@ class HunyuanVideo15PrepareLatentsStep(ModularPipelineBlocks):
     def description(self) -> str:
         return "Prepare latents, conditioning latents, mask, and image_embeds for T2V"
 
+    @property
+    def expected_components(self) -> list[ComponentSpec]:
+        return [ComponentSpec("transformer", HunyuanVideo15Transformer3DModel)]
+
     @property
     def inputs(self) -> list[InputParam]:
         return [
@@ -157,24 +181,22 @@ def inputs(self) -> list[InputParam]:
             InputParam.template("num_images_per_prompt", name="num_videos_per_prompt"),
             InputParam.template("generator"),
             InputParam.template("batch_size", required=True, default=None),
-            InputParam.template("dtype", default=None),
         ]
 
     @property
     def intermediate_outputs(self) -> list[OutputParam]:
         return [
-            OutputParam.template("latents"),
+            OutputParam("latents", type_hint=torch.Tensor, description="Pure noise latents"),
             OutputParam("cond_latents_concat", type_hint=torch.Tensor),
             OutputParam("mask_concat", type_hint=torch.Tensor),
             OutputParam("image_embeds", type_hint=torch.Tensor),
         ]
 
-    # Copied from pipeline_hunyuan_video1_5.py lines 652-655, 477-524, 706-725 with self->components
     @torch.no_grad()
     def __call__(self, components: HunyuanVideo15ModularPipeline, state: PipelineState) -> PipelineState:
         block_state = self.get_block_state(state)
         device = components._execution_device
-        dtype = block_state.dtype
+        dtype = components.transformer.dtype
 
         height = block_state.height
         width = block_state.width
@@ -186,7 +208,6 @@ def __call__(self, components: HunyuanVideo15ModularPipeline, state: PipelineSta
         batch_size = block_state.batch_size * block_state.num_videos_per_prompt
         num_frames = block_state.num_frames
 
-        # Copied from HunyuanVideo15Pipeline.prepare_latents with self->components
         latents = block_state.latents
         if latents is not None:
             latents = latents.to(device=device, dtype=dtype)
@@ -207,12 +228,10 @@ def __call__(self, components: HunyuanVideo15ModularPipeline, state: PipelineSta
 
         block_state.latents = latents
 
-        # Copied from HunyuanVideo15Pipeline.prepare_cond_latents_and_mask with self->components
         b, c, f, h, w = latents.shape
         block_state.cond_latents_concat = torch.zeros(b, c, f, h, w, dtype=dtype, device=device)
         block_state.mask_concat = torch.zeros(b, 1, f, h, w, dtype=dtype, device=device)
 
-        # T2V: zero image_embeds
         block_state.image_embeds = torch.zeros(
             block_state.batch_size,
             components.vision_num_semantic_tokens,
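
The T2V branch conditions on nothing: `cond_latents_concat` and `mask_concat` are all-zero tensors with the same temporal/spatial layout as the noise latents. A standalone shape sketch (the latent channel count and VAE scale factors below are assumed illustrative values, not read from a real config):

```python
import torch

b, c = 1, 16                      # batch and latent channels (assumed)
num_frames, height, width = 121, 480, 832
f = (num_frames - 1) // 4 + 1     # 31 latent frames, assuming vae_scale_factor_temporal == 4
h, w = height // 16, width // 16  # 30 x 52 latent grid, assuming vae_scale_factor_spatial == 16

latents = torch.randn(b, c, f, h, w)              # pure noise
cond_latents_concat = torch.zeros(b, c, f, h, w)  # T2V: no image condition
mask_concat = torch.zeros(b, 1, f, h, w)          # T2V: no frame is conditioned

# Presumably the denoiser input concatenates these along the channel dim:
model_input = torch.cat([latents, cond_latents_concat, mask_concat], dim=1)
print(model_input.shape)  # torch.Size([1, 33, 31, 30, 52])
```
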
@@ -225,125 +244,62 @@ def __call__(self, components: HunyuanVideo15ModularPipeline, state: PipelineSta
         return components, state
 
 
-# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
-def retrieve_latents(encoder_output, generator=None, sample_mode="sample"):
-    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
-        return encoder_output.latent_dist.sample(generator)
-    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
-        return encoder_output.latent_dist.mode()
-    elif hasattr(encoder_output, "latents"):
-        return encoder_output.latents
-    raise AttributeError("Could not access latents of provided encoder_output")
-
 class HunyuanVideo15Image2VideoPrepareLatentsStep(ModularPipelineBlocks):
     model_name = "hunyuan-video-1.5"
 
     @property
     def description(self) -> str:
-        return "Prepare latents, conditioning latents, mask, and image_embeds for I2V"
+        return (
+            "Prepare I2V conditioning from image_latents and image_embeds. "
+            "Expects pure noise `latents` from HunyuanVideo15PrepareLatentsStep. "
+            "Builds cond_latents_concat and mask_concat for the denoiser."
+        )
 
     @property
     def expected_components(self) -> list[ComponentSpec]:
-        from transformers import SiglipImageProcessor, SiglipVisionModel
-
-        from ...models import AutoencoderKLHunyuanVideo15
-        from ...pipelines.hunyuan_video1_5.image_processor import HunyuanVideo15ImageProcessor
-
-        return [
-            ComponentSpec("vae", AutoencoderKLHunyuanVideo15),
-            ComponentSpec(
-                "video_processor",
-                HunyuanVideo15ImageProcessor,
-                config=FrozenDict({"vae_scale_factor": 16}),
-                default_creation_method="from_config",
-            ),
-            ComponentSpec("image_encoder", SiglipVisionModel),
-            ComponentSpec("feature_extractor", SiglipImageProcessor),
-        ]
+        return [ComponentSpec("transformer", HunyuanVideo15Transformer3DModel)]
 
     @property
     def inputs(self) -> list[InputParam]:
         return [
-            InputParam.template("image"),
-            InputParam("num_frames", type_hint=int, default=121),
-            InputParam.template("latents"),
+            InputParam("image_latents", type_hint=torch.Tensor, required=True),
+            InputParam("image_embeds", type_hint=torch.Tensor, required=True),
+            InputParam.template("latents", required=True),
             InputParam.template("num_images_per_prompt", name="num_videos_per_prompt"),
-            InputParam.template("generator"),
             InputParam.template("batch_size", required=True, default=None),
-            InputParam.template("dtype", default=None),
         ]
 
     @property
     def intermediate_outputs(self) -> list[OutputParam]:
         return [
-            OutputParam.template("latents"),
             OutputParam("cond_latents_concat", type_hint=torch.Tensor),
             OutputParam("mask_concat", type_hint=torch.Tensor),
             OutputParam("image_embeds", type_hint=torch.Tensor),
         ]
 
-    # Copied from pipeline_hunyuan_video1_5_image2video.py lines 756-839 with self->components
     @torch.no_grad()
     def __call__(self, components: HunyuanVideo15ModularPipeline, state: PipelineState) -> PipelineState:
         block_state = self.get_block_state(state)
         device = components._execution_device
-        dtype = block_state.dtype
+        dtype = components.transformer.dtype
 
-        image = block_state.image
         batch_size = block_state.batch_size * block_state.num_videos_per_prompt
-        num_frames = block_state.num_frames
 
-        # Resize/crop image to target resolution (line 756-759)
-        height, width = components.video_processor.calculate_default_height_width(
-            height=image.size[1], width=image.size[0], target_size=components.target_size
-        )
-        image = components.video_processor.resize(image, height=height, width=width, resize_mode="crop")
-
-        # Encode image with Siglip (lines 776-781)
-        image_encoder_dtype = next(components.image_encoder.parameters()).dtype
-        image_inputs = components.feature_extractor.preprocess(
-            images=image, do_resize=True, return_tensors="pt", do_convert_rgb=True
-        )
-        image_inputs = image_inputs.to(device=device, dtype=image_encoder_dtype)
-        image_embeds = components.image_encoder(**image_inputs).last_hidden_state
-        image_embeds = image_embeds.repeat(batch_size, 1, 1)
-        block_state.image_embeds = image_embeds.to(device=device, dtype=dtype)
+        b, c, f, h, w = block_state.latents.shape
 
-        # Prepare latents (lines 818-829)
-        latents = block_state.latents
-        if latents is not None:
-            latents = latents.to(device=device, dtype=dtype)
-        else:
-            shape = (
-                batch_size,
-                components.num_channels_latents,
-                (num_frames - 1) // components.vae_scale_factor_temporal + 1,
-                int(height) // components.vae_scale_factor_spatial,
-                int(width) // components.vae_scale_factor_spatial,
-            )
-            latents = randn_tensor(shape, generator=block_state.generator, device=device, dtype=dtype)
-        block_state.latents = latents
-
-        # Prepare cond latents and mask (lines 594-632, 831-839)
-        b, c, f, h, w = latents.shape
-
-        # Copied from _get_image_latents (lines 375-388) with self->components
-        vae_dtype = components.vae.dtype
-        image_tensor = components.video_processor.preprocess(
-            image, height=h * components.vae_scale_factor_spatial, width=w * components.vae_scale_factor_spatial
-        ).to(device, dtype=vae_dtype)
-        image_tensor = image_tensor.unsqueeze(2)
-        image_latents = retrieve_latents(components.vae.encode(image_tensor), sample_mode="argmax")
-        image_latents = image_latents * components.vae.config.scaling_factor
-
-        latent_condition = image_latents.repeat(batch_size, 1, f, 1, 1)
+        latent_condition = block_state.image_latents.to(device=device, dtype=dtype)
+        latent_condition = latent_condition.repeat(batch_size, 1, f, 1, 1)
         latent_condition[:, :, 1:, :, :] = 0
-        block_state.cond_latents_concat = latent_condition.to(device=device, dtype=dtype)
+        block_state.cond_latents_concat = latent_condition
 
         latent_mask = torch.zeros(b, 1, f, h, w, dtype=dtype, device=device)
         latent_mask[:, :, 0, :, :] = 1.0
         block_state.mask_concat = latent_mask
 
+        image_embeds = block_state.image_embeds.to(device=device, dtype=dtype)
+        if image_embeds.shape[0] == 1 and batch_size > 1:
+            image_embeds = image_embeds.repeat(batch_size, 1, 1)
+        block_state.image_embeds = image_embeds
+
         self.set_block_state(state, block_state)
         return components, state
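
The I2V step implements first-frame conditioning: the single image latent is broadcast across all latent frames, every frame after the first is zeroed out, and the mask flags frame 0 as the only conditioned frame. A self-contained sketch with illustrative shapes:

```python
import torch

batch_size, c, f, h, w = 1, 16, 31, 30, 52  # illustrative latent dims
image_latents = torch.randn(1, c, 1, h, w)  # VAE latent of the single input image

latent_condition = image_latents.repeat(batch_size, 1, f, 1, 1)
latent_condition[:, :, 1:, :, :] = 0        # only frame 0 carries the image

latent_mask = torch.zeros(batch_size, 1, f, h, w)
latent_mask[:, :, 0, :, :] = 1.0            # mark frame 0 as conditioned

assert latent_condition[:, :, 1:].abs().sum() == 0
assert latent_mask.sum() == batch_size * h * w
```
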