Skip to content

Commit 846b5f9

Browse files
committed
revert dynamic steps to simplify
1 parent 921185c commit 846b5f9

2 files changed

Lines changed: 72 additions & 72 deletions

File tree

src/diffusers/modular_pipelines/wan/encoders.py

Lines changed: 67 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
from ...configuration_utils import FrozenDict
2323
from ...guiders import ClassifierFreeGuidance
24-
from ...utils import is_ftfy_available, logging
24+
from ...utils import is_ftfy_available, is_torchvision_available, logging
2525
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
2626
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
2727
from .modular_pipeline import WanModularPipeline
@@ -31,9 +31,13 @@
3131
import PIL
3232
import numpy as np
3333

34+
3435
if is_ftfy_available():
3536
import ftfy
3637

38+
if is_torchvision_available():
39+
from torchvision import transforms
40+
3741

3842
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
3943

@@ -307,83 +311,92 @@ def __call__(self, components: WanModularPipeline, state: PipelineState) -> Pipe
307311
return components, state
308312

309313

310-
class WanImageResizeDynamicStep(ModularPipelineBlocks):
314+
class WanImageResizeStep(ModularPipelineBlocks):
311315
model_name = "wan"
312316

313-
def __init__(self, input_name: str = "image", output_name: str = "resized_image"):
314-
"""Create a configurable step for resizing images to the target area (height * width) while maintaining the aspect ratio.
315-
316-
This block resizes an input image and exposes the resized result under configurable
317-
input and output names. Use this when you need to wire the resize step to different image fields (e.g.,
318-
"image", "last_image")
319-
320-
Args:
321-
input_name (str, optional): Name of the image field to read from the
322-
pipeline state. Defaults to "image".
323-
output_name (str, optional): Name of the resized image field to write
324-
back to the pipeline state. Defaults to "resized_image".
325-
"""
326-
if not isinstance(input_name, str) or not isinstance(output_name, str):
327-
raise ValueError(f"input_name and output_name must be strings but are {type(input_name)} and {type(output_name)}")
328-
self._image_input_name = input_name
329-
self._resized_image_output_name = output_name
330-
super().__init__()
331-
332317
@property
333318
def description(self) -> str:
334-
return f"Image Resize step that resize the {self._image_input_name} to the target area (height * width) while maintaining the aspect ratio."
319+
return "Image Resize step that resize the image to the target area (height * width) while maintaining the aspect ratio."
335320

336321
@property
337322
def inputs(self) -> List[InputParam]:
338323
return [
339-
InputParam(self._image_input_name, type_hint=PIL.Image.Image),
324+
InputParam("image", type_hint=PIL.Image.Image, required=True),
340325
InputParam("height", type_hint=int, default=480),
341326
InputParam("width", type_hint=int, default=832),
342327
]
343328

344329
@property
345330
def intermediate_outputs(self) -> List[OutputParam]:
346331
return [
347-
OutputParam(self._resized_image_output_name, type_hint=PIL.Image.Image, description="The resized image"),
332+
OutputParam("resized_image", type_hint=PIL.Image.Image),
348333
]
349334

350335
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
351336

352337
block_state = self.get_block_state(state)
353338
max_area = block_state.height * block_state.width
354339

355-
image = getattr(block_state, self._image_input_name)
356-
340+
image = block_state.image
357341
aspect_ratio = image.height / image.width
358342
mod_value = components.vae_scale_factor_spatial * components.patch_size_spatial
359-
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
360-
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
361-
resized_image = image.resize((width, height))
362-
setattr(block_state, self._resized_image_output_name, resized_image)
343+
block_state.height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
344+
block_state.width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
345+
block_state.resized_image = image.resize((block_state.width, block_state.height))
363346

364347
self.set_block_state(state, block_state)
365348
return components, state
366349

367350

368-
class WanImageEncoderDynamicStep(ModularPipelineBlocks):
351+
class WanImageCropResizeStep(ModularPipelineBlocks):
    """Resize and center-crop ``last_image`` to the size of ``resized_image``.

    Reads ``resized_image`` (the already-resized first frame) and ``last_image``
    from the block state, scales ``last_image`` uniformly so that it fully covers
    the first frame's width and height, then crops the centered region of exactly
    that size. The result is written to ``resized_last_image``.

    Only PIL image methods are used, so this step does not need torchvision.
    """

    model_name = "wan"

    @property
    def description(self) -> str:
        return "Image Resize step that resize the last_image to the same size of first frame image with center crop."

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("resized_image", type_hint=PIL.Image.Image, required=True, description="The resized first frame image"),
            InputParam("last_image", type_hint=PIL.Image.Image, required=True, description="The last frame image"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam("resized_last_image", type_hint=PIL.Image.Image),
        ]

    def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
        """Produce ``resized_last_image`` with the same (width, height) as ``resized_image``.

        Args:
            components: The pipeline components (not used by this step).
            state: The pipeline state holding ``resized_image`` and ``last_image``.

        Returns:
            The ``(components, state)`` pair, with ``resized_last_image`` set on the state.
        """
        block_state = self.get_block_state(state)

        target_width = block_state.resized_image.width
        target_height = block_state.resized_image.height
        image = block_state.last_image

        # Use the larger of the two ratios so the scaled image covers the target
        # in both dimensions; the overflow on the other axis is cropped away below.
        resize_ratio = max(target_width / image.width, target_height / image.height)

        # Clamp to the target size so rounding can never leave the scaled image
        # one pixel short of the crop window.
        scaled_width = max(round(image.width * resize_ratio), target_width)
        scaled_height = max(round(image.height * resize_ratio), target_height)
        scaled_image = image.resize((scaled_width, scaled_height))

        # Center crop: PIL's crop box is (left, upper, right, lower).
        left = (scaled_width - target_width) // 2
        top = (scaled_height - target_height) // 2
        block_state.resized_last_image = scaled_image.crop(
            (left, top, left + target_width, top + target_height)
        )

        self.set_block_state(state, block_state)
        return components, state
392+
393+
394+
class WanImageEncoderStep(ModularPipelineBlocks):
395+
model_name = "wan"
383396

384397
@property
385398
def description(self) -> str:
386-
return f"Image Encoder step that generate {self._image_embeds_output_name} to guide the video generation"
399+
return "Image Encoder step that generate image_embeds to guide the video generation"
387400

388401
@property
389402
def expected_components(self) -> List[ComponentSpec]:
@@ -395,53 +408,40 @@ def expected_components(self) -> List[ComponentSpec]:
395408
@property
396409
def inputs(self) -> List[InputParam]:
397410
return [
398-
InputParam(self._image_input_name, type_hint=PIL.Image.Image),
411+
InputParam("resized_image", type_hint=PIL.Image.Image, required=True),
399412
]
400413

401414
@property
402415
def intermediate_outputs(self) -> List[OutputParam]:
403416
return [
404-
OutputParam(self._image_embeds_output_name, type_hint=torch.Tensor, description="The image embeddings"),
417+
OutputParam("image_embeds", type_hint=torch.Tensor, description="The image embeddings"),
405418
]
406419

407420

408421
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
409422
block_state = self.get_block_state(state)
410423

411424
device = components._execution_device
412-
413-
image = getattr(block_state, self._image_input_name)
425+
426+
image = block_state.resized_image
414427

415428
image_embeds = encode_image(
416429
image_processor=components.image_processor,
417430
image_encoder=components.image_encoder,
418431
image=image,
419432
device=device,
420433
)
421-
setattr(block_state, self._image_embeds_output_name, image_embeds)
434+
block_state.image_embeds = image_embeds
422435
self.set_block_state(state, block_state)
423436
return components, state
424437

425438

426-
class WanVaeImageEncoderDynamicStep(ModularPipelineBlocks):
439+
class WanVaeImageEncoderStep(ModularPipelineBlocks):
427440
model_name = "wan"
428441

429-
def __init__(self, input_name: str = "resized_image", output_name: str = "first_frame_latents"):
430-
"""Create a configurable step for encoding images to generate image latents.
431-
432-
This block encodes an input image and exposes the generated latents under configurable
433-
input and output names. Use this when you need to wire the encoder step to different image fields (e.g.,
434-
"resized_image")
435-
"""
436-
if not isinstance(input_name, str) or not isinstance(output_name, str):
437-
raise ValueError(f"input_name and output_name must be strings but are {type(input_name)} and {type(output_name)}")
438-
self._image_input_name = input_name
439-
self._image_latents_output_name = output_name
440-
super().__init__()
441-
442442
@property
443443
def description(self) -> str:
444-
return f"Vae Image Encoder step that generate {self._image_latents_output_name} to guide the video generation"
444+
return "Vae Image Encoder step that generate first_frame_latents to guide the video generation"
445445

446446
@property
447447
def expected_components(self) -> List[ComponentSpec]:
@@ -453,7 +453,7 @@ def expected_components(self) -> List[ComponentSpec]:
453453
@property
454454
def inputs(self) -> List[InputParam]:
455455
return [
456-
InputParam(self._image_input_name, type_hint=PIL.Image.Image),
456+
InputParam("resized_image", type_hint=PIL.Image.Image, required=True),
457457
InputParam("height"),
458458
InputParam("width"),
459459
InputParam("num_frames"),
@@ -463,7 +463,7 @@ def inputs(self) -> List[InputParam]:
463463
@property
464464
def intermediate_outputs(self) -> List[OutputParam]:
465465
return [
466-
OutputParam(self._image_latents_output_name, type_hint=torch.Tensor, description="The latent condition"),
466+
OutputParam("first_frame_latents", type_hint=torch.Tensor, description="The latent condition"),
467467
]
468468

469469
@staticmethod
@@ -485,7 +485,7 @@ def __call__(self, components: WanModularPipeline, state: PipelineState) -> Pipe
485485
block_state = self.get_block_state(state)
486486
self.check_inputs(components, block_state)
487487

488-
image = getattr(block_state, self._image_input_name)
488+
image = block_state.resized_image
489489

490490
device = components._execution_device
491491
dtype = torch.float32
@@ -509,6 +509,6 @@ def __call__(self, components: WanModularPipeline, state: PipelineState) -> Pipe
509509
latent_channels=components.num_channels_latents,
510510
)
511511

512-
setattr(block_state, self._image_latents_output_name, latent_condition)
512+
block_state.first_frame_latents = latent_condition
513513
self.set_block_state(state, block_state)
514514
return components, state

src/diffusers/modular_pipelines/wan/modular_blocks.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
)
2525
from .decoders import WanImageVaeDecoderStep
2626
from .denoise import WanDenoiseStep, WanImage2VideoDenoiseStep
27-
from .encoders import WanTextEncoderStep, WanImageResizeDynamicStep, WanImageEncoderDynamicStep, WanVaeImageEncoderDynamicStep
27+
from .encoders import WanTextEncoderStep, WanImageResizeStep, WanImageEncoderStep, WanVaeImageEncoderStep
2828

2929

3030
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -57,7 +57,7 @@ def description(self):
5757
## image encoder
5858
class WanImage2VideoImageEncoderStep(SequentialPipelineBlocks):
5959
model_name = "wan"
60-
block_classes = [WanImageResizeDynamicStep(input_name="image", output_name="resized_image"), WanImageEncoderDynamicStep(input_name="resized_image", output_name="image_embeds")]
60+
block_classes = [WanImageResizeStep, WanImageEncoderStep]
6161
block_names = ["image_resize", "image_encoder"]
6262

6363
@property
@@ -69,7 +69,7 @@ def description(self):
6969
# vae encoder
7070
class WanImage2VideoVaeImageEncoderStep(SequentialPipelineBlocks):
7171
model_name = "wan"
72-
block_classes = [WanImageResizeDynamicStep(input_name="image", output_name="resized_image"), WanVaeImageEncoderDynamicStep(input_name="resized_image", output_name="first_frame_latents")]
72+
block_classes = [WanImageResizeStep, WanVaeImageEncoderStep]
7373
block_names = ["image_resize", "vae_image_encoder"]
7474

7575
@property
@@ -189,8 +189,8 @@ def description(self):
189189

190190
IMAGE2VIDEO_BLOCKS = InsertableDict(
191191
[
192-
("image_resize", WanImageResizeDynamicStep()),
193-
("image_encoder", WanImage2VideoImageEncoderStep()),
192+
("image_resize", WanImageResizeStep),
193+
("image_encoder", WanImage2VideoImageEncoderStep),
194194
("vae_image_encoder", WanImage2VideoVaeImageEncoderStep),
195195
("input", WanTextInputStep),
196196
("additional_inputs", WanInputsDynamicStep(image_latent_inputs=["first_frame_latents"])),

0 commit comments

Comments
 (0)