2121
2222from ...configuration_utils import FrozenDict
2323from ...guiders import ClassifierFreeGuidance
24- from ...utils import is_ftfy_available , logging
24+ from ...utils import is_ftfy_available , is_torchvision_available , logging
2525from ..modular_pipeline import ModularPipelineBlocks , PipelineState
2626from ..modular_pipeline_utils import ComponentSpec , ConfigSpec , InputParam , OutputParam
2727from .modular_pipeline import WanModularPipeline
3131import PIL
3232import numpy as np
3333
34+
3435if is_ftfy_available ():
3536 import ftfy
3637
38+ if is_torchvision_available ():
39+ from torchvision import transforms
40+
3741
3842logger = logging .get_logger (__name__ ) # pylint: disable=invalid-name
3943
@@ -307,83 +311,92 @@ def __call__(self, components: WanModularPipeline, state: PipelineState) -> Pipe
307311 return components , state
308312
309313
class WanImageResizeStep(ModularPipelineBlocks):
    """Resize the input image to roughly ``height * width`` pixels, keeping its aspect ratio.

    Reads ``image``, ``height`` and ``width`` from the block state. Writes the
    resized picture to ``resized_image`` and overwrites ``height`` / ``width``
    with the dimensions that were actually used.
    """

    model_name = "wan"

    @property
    def description(self) -> str:
        return "Image Resize step that resize the image to the target area (height * width) while maintaining the aspect ratio."

    @property
    def inputs(self) -> List[InputParam]:
        image_param = InputParam("image", type_hint=PIL.Image.Image, required=True)
        height_param = InputParam("height", type_hint=int, default=480)
        width_param = InputParam("width", type_hint=int, default=832)
        return [image_param, height_param, width_param]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [OutputParam("resized_image", type_hint=PIL.Image.Image)]

    def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        # The requested height/width only define the pixel budget; the final
        # shape follows the source image's own aspect ratio.
        pixel_budget = block_state.height * block_state.width
        source = block_state.image
        ratio = source.height / source.width

        # Both sides are floored to a multiple of this value.
        multiple = components.vae_scale_factor_spatial * components.patch_size_spatial
        new_height = round(np.sqrt(pixel_budget * ratio)) // multiple * multiple
        new_width = round(np.sqrt(pixel_budget / ratio)) // multiple * multiple

        # Publish the effective dimensions so later steps see the real size.
        block_state.height = new_height
        block_state.width = new_width
        block_state.resized_image = source.resize((new_width, new_height))

        self.set_block_state(state, block_state)
        return components, state
366349
367350
368- class WanImageEncoderDynamicStep (ModularPipelineBlocks ):
class WanImageCropResizeStep(ModularPipelineBlocks):
    """Bring ``last_image`` to the exact size of the already-resized first frame.

    Reads ``resized_image`` (the first frame, used only for its size) and
    ``last_image`` from the block state. The last image is scaled, keeping its
    aspect ratio, until it covers the first frame's size, then center-cropped
    to that size. The result is written to ``resized_last_image``.
    """

    model_name = "wan"

    @property
    def description(self) -> str:
        return "Image Resize step that resize the last_image to the same size of first frame image with center crop."

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "resized_image",
                type_hint=PIL.Image.Image,
                required=True,
                description="The resized first frame image",
            ),
            InputParam(
                "last_image",
                type_hint=PIL.Image.Image,
                required=True,
                description="The last frame image",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam("resized_last_image", type_hint=PIL.Image.Image),
        ]

    def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        target_height = block_state.resized_image.height
        target_width = block_state.resized_image.width
        image = block_state.last_image

        # Use the larger of the two ratios so the scaled image covers the
        # target on both axes; the excess on the other axis is cropped away.
        resize_ratio = max(target_width / image.width, target_height / image.height)
        scaled_width = round(image.width * resize_ratio)
        scaled_height = round(image.height * resize_ratio)

        # The previous code computed the scaled size but never resized, and
        # handed it to center_crop as [width, height]. Resize first, then
        # crop to the first frame's size. center_crop takes (height, width).
        scaled_image = image.resize((scaled_width, scaled_height))
        block_state.resized_last_image = transforms.functional.center_crop(
            scaled_image, [target_height, target_width]
        )

        self.set_block_state(state, block_state)
        return components, state
392+
393+
394+ class WanImageEncoderStep (ModularPipelineBlocks ):
395+ model_name = "wan"
383396
384397 @property
385398 def description (self ) -> str :
386- return f "Image Encoder step that generate { self . _image_embeds_output_name } to guide the video generation"
399+ return "Image Encoder step that generate image_embeds to guide the video generation"
387400
388401 @property
389402 def expected_components (self ) -> List [ComponentSpec ]:
@@ -395,53 +408,40 @@ def expected_components(self) -> List[ComponentSpec]:
395408 @property
396409 def inputs (self ) -> List [InputParam ]:
397410 return [
398- InputParam (self . _image_input_name , type_hint = PIL .Image .Image ),
411+ InputParam ("resized_image" , type_hint = PIL .Image .Image , required = True ),
399412 ]
400413
401414 @property
402415 def intermediate_outputs (self ) -> List [OutputParam ]:
403416 return [
404- OutputParam (self . _image_embeds_output_name , type_hint = torch .Tensor , description = "The image embeddings" ),
417+ OutputParam ("image_embeds" , type_hint = torch .Tensor , description = "The image embeddings" ),
405418 ]
406419
407420
408421 def __call__ (self , components : WanModularPipeline , state : PipelineState ) -> PipelineState :
409422 block_state = self .get_block_state (state )
410423
411424 device = components ._execution_device
412-
413- image = getattr ( block_state , self . _image_input_name )
425+
426+ image = block_state . resized_image
414427
415428 image_embeds = encode_image (
416429 image_processor = components .image_processor ,
417430 image_encoder = components .image_encoder ,
418431 image = image ,
419432 device = device ,
420433 )
421- setattr ( block_state , self . _image_embeds_output_name , image_embeds )
434+ block_state . image_embeds = image_embeds
422435 self .set_block_state (state , block_state )
423436 return components , state
424437
425438
426- class WanVaeImageEncoderDynamicStep (ModularPipelineBlocks ):
439+ class WanVaeImageEncoderStep (ModularPipelineBlocks ):
427440 model_name = "wan"
428441
429- def __init__ (self , input_name : str = "resized_image" , output_name : str = "first_frame_latents" ):
430- """Create a configurable step for encoding images to generate image latents.
431-
432- This block encodes an input image and exposes the generated latents under configurable
433- input and output names. Use this when you need to wire the encoder step to different image fields (e.g.,
434- "resized_image")
435- """
436- if not isinstance (input_name , str ) or not isinstance (output_name , str ):
437- raise ValueError (f"input_name and output_name must be strings but are { type (input_name )} and { type (output_name )} " )
438- self ._image_input_name = input_name
439- self ._image_latents_output_name = output_name
440- super ().__init__ ()
441-
442442 @property
443443 def description (self ) -> str :
444- return f "Vae Image Encoder step that generate { self . _image_latents_output_name } to guide the video generation"
444+ return "Vae Image Encoder step that generate first_frame_latents to guide the video generation"
445445
446446 @property
447447 def expected_components (self ) -> List [ComponentSpec ]:
@@ -453,7 +453,7 @@ def expected_components(self) -> List[ComponentSpec]:
453453 @property
454454 def inputs (self ) -> List [InputParam ]:
455455 return [
456- InputParam (self . _image_input_name , type_hint = PIL .Image .Image ),
456+ InputParam ("resized_image" , type_hint = PIL .Image .Image , required = True ),
457457 InputParam ("height" ),
458458 InputParam ("width" ),
459459 InputParam ("num_frames" ),
@@ -463,7 +463,7 @@ def inputs(self) -> List[InputParam]:
463463 @property
464464 def intermediate_outputs (self ) -> List [OutputParam ]:
465465 return [
466- OutputParam (self . _image_latents_output_name , type_hint = torch .Tensor , description = "The latent condition" ),
466+ OutputParam ("first_frame_latents" , type_hint = torch .Tensor , description = "The latent condition" ),
467467 ]
468468
469469 @staticmethod
@@ -485,7 +485,7 @@ def __call__(self, components: WanModularPipeline, state: PipelineState) -> Pipe
485485 block_state = self .get_block_state (state )
486486 self .check_inputs (components , block_state )
487487
488- image = getattr ( block_state , self . _image_input_name )
488+ image = block_state . resized_image
489489
490490 device = components ._execution_device
491491 dtype = torch .float32
@@ -509,6 +509,6 @@ def __call__(self, components: WanModularPipeline, state: PipelineState) -> Pipe
509509 latent_channels = components .num_channels_latents ,
510510 )
511511
512- setattr ( block_state , self . _image_latents_output_name , latent_condition )
512+ block_state . first_frame_latents = latent_condition
513513 self .set_block_state (state , block_state )
514514 return components , state
0 commit comments