Skip to content

Commit fd705bd

Browse files
yiyixuxuyiyi@huggingface.coasomoza
authored
[Modular] refactor Wan: modular pipelines by task etc (#13063)
* initil * fix init_pipeline etc * style * copies * fix copies * upup more * fix test * add output type (#13091) --------- Co-authored-by: yiyi@huggingface.co <yiyi@ip-26-0-160-103.ec2.internal> Co-authored-by: Álvaro Somoza <asomoza@users.noreply.github.com>
1 parent 09dca38 commit fd705bd

19 files changed

+869
-727
lines changed

src/diffusers/__init__.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -417,6 +417,7 @@
417417
"Flux2AutoBlocks",
418418
"Flux2KleinAutoBlocks",
419419
"Flux2KleinBaseAutoBlocks",
420+
"Flux2KleinBaseModularPipeline",
420421
"Flux2KleinModularPipeline",
421422
"Flux2ModularPipeline",
422423
"FluxAutoBlocks",
@@ -433,8 +434,13 @@
433434
"QwenImageModularPipeline",
434435
"StableDiffusionXLAutoBlocks",
435436
"StableDiffusionXLModularPipeline",
436-
"Wan22AutoBlocks",
437-
"WanAutoBlocks",
437+
"Wan22Blocks",
438+
"Wan22Image2VideoBlocks",
439+
"Wan22Image2VideoModularPipeline",
440+
"Wan22ModularPipeline",
441+
"WanBlocks",
442+
"WanImage2VideoAutoBlocks",
443+
"WanImage2VideoModularPipeline",
438444
"WanModularPipeline",
439445
"ZImageAutoBlocks",
440446
"ZImageModularPipeline",
@@ -1156,6 +1162,7 @@
11561162
Flux2AutoBlocks,
11571163
Flux2KleinAutoBlocks,
11581164
Flux2KleinBaseAutoBlocks,
1165+
Flux2KleinBaseModularPipeline,
11591166
Flux2KleinModularPipeline,
11601167
Flux2ModularPipeline,
11611168
FluxAutoBlocks,
@@ -1172,8 +1179,13 @@
11721179
QwenImageModularPipeline,
11731180
StableDiffusionXLAutoBlocks,
11741181
StableDiffusionXLModularPipeline,
1175-
Wan22AutoBlocks,
1176-
WanAutoBlocks,
1182+
Wan22Blocks,
1183+
Wan22Image2VideoBlocks,
1184+
Wan22Image2VideoModularPipeline,
1185+
Wan22ModularPipeline,
1186+
WanBlocks,
1187+
WanImage2VideoAutoBlocks,
1188+
WanImage2VideoModularPipeline,
11771189
WanModularPipeline,
11781190
ZImageAutoBlocks,
11791191
ZImageModularPipeline,

src/diffusers/modular_pipelines/__init__.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,16 @@
4545
"InsertableDict",
4646
]
4747
_import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline"]
48-
_import_structure["wan"] = ["WanAutoBlocks", "Wan22AutoBlocks", "WanModularPipeline"]
48+
_import_structure["wan"] = [
49+
"WanBlocks",
50+
"Wan22Blocks",
51+
"WanImage2VideoAutoBlocks",
52+
"Wan22Image2VideoBlocks",
53+
"WanModularPipeline",
54+
"Wan22ModularPipeline",
55+
"WanImage2VideoModularPipeline",
56+
"Wan22Image2VideoModularPipeline",
57+
]
4958
_import_structure["flux"] = [
5059
"FluxAutoBlocks",
5160
"FluxModularPipeline",
@@ -58,6 +67,7 @@
5867
"Flux2KleinBaseAutoBlocks",
5968
"Flux2ModularPipeline",
6069
"Flux2KleinModularPipeline",
70+
"Flux2KleinBaseModularPipeline",
6171
]
6272
_import_structure["qwenimage"] = [
6373
"QwenImageAutoBlocks",
@@ -88,6 +98,7 @@
8898
Flux2AutoBlocks,
8999
Flux2KleinAutoBlocks,
90100
Flux2KleinBaseAutoBlocks,
101+
Flux2KleinBaseModularPipeline,
91102
Flux2KleinModularPipeline,
92103
Flux2ModularPipeline,
93104
)
@@ -112,7 +123,16 @@
112123
QwenImageModularPipeline,
113124
)
114125
from .stable_diffusion_xl import StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline
115-
from .wan import Wan22AutoBlocks, WanAutoBlocks, WanModularPipeline
126+
from .wan import (
127+
Wan22Blocks,
128+
Wan22Image2VideoBlocks,
129+
Wan22Image2VideoModularPipeline,
130+
Wan22ModularPipeline,
131+
WanBlocks,
132+
WanImage2VideoAutoBlocks,
133+
WanImage2VideoModularPipeline,
134+
WanModularPipeline,
135+
)
116136
from .z_image import ZImageAutoBlocks, ZImageModularPipeline
117137
else:
118138
import sys

src/diffusers/modular_pipelines/flux2/__init__.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,11 @@
5555
"Flux2VaeEncoderSequentialStep",
5656
]
5757
_import_structure["modular_blocks_flux2_klein"] = ["Flux2KleinAutoBlocks", "Flux2KleinBaseAutoBlocks"]
58-
_import_structure["modular_pipeline"] = ["Flux2ModularPipeline", "Flux2KleinModularPipeline"]
58+
_import_structure["modular_pipeline"] = [
59+
"Flux2ModularPipeline",
60+
"Flux2KleinModularPipeline",
61+
"Flux2KleinBaseModularPipeline",
62+
]
5963

6064
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
6165
try:
@@ -101,7 +105,7 @@
101105
Flux2KleinAutoBlocks,
102106
Flux2KleinBaseAutoBlocks,
103107
)
104-
from .modular_pipeline import Flux2KleinModularPipeline, Flux2ModularPipeline
108+
from .modular_pipeline import Flux2KleinBaseModularPipeline, Flux2KleinModularPipeline, Flux2ModularPipeline
105109
else:
106110
import sys
107111

src/diffusers/modular_pipelines/flux2/modular_pipeline.py

Lines changed: 18 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,6 @@
1313
# limitations under the License.
1414

1515

16-
from typing import Any, Dict, Optional
17-
1816
from ...loaders import Flux2LoraLoaderMixin
1917
from ...utils import logging
2018
from ..modular_pipeline import ModularPipeline
@@ -59,46 +57,35 @@ def num_channels_latents(self):
5957
return num_channels_latents
6058

6159

62-
class Flux2KleinModularPipeline(ModularPipeline, Flux2LoraLoaderMixin):
60+
class Flux2KleinModularPipeline(Flux2ModularPipeline):
6361
"""
64-
A ModularPipeline for Flux2-Klein.
62+
A ModularPipeline for Flux2-Klein (distilled model).
6563
6664
> [!WARNING] > This is an experimental feature and is likely to change in the future.
6765
"""
6866

69-
default_blocks_name = "Flux2KleinBaseAutoBlocks"
70-
71-
def get_default_blocks_name(self, config_dict: Optional[Dict[str, Any]]) -> Optional[str]:
72-
if config_dict is not None and "is_distilled" in config_dict and config_dict["is_distilled"]:
73-
return "Flux2KleinAutoBlocks"
74-
else:
75-
return "Flux2KleinBaseAutoBlocks"
67+
default_blocks_name = "Flux2KleinAutoBlocks"
7668

7769
@property
78-
def default_height(self):
79-
return self.default_sample_size * self.vae_scale_factor
70+
def requires_unconditional_embeds(self):
71+
if hasattr(self.config, "is_distilled") and self.config.is_distilled:
72+
return False
8073

81-
@property
82-
def default_width(self):
83-
return self.default_sample_size * self.vae_scale_factor
74+
requires_unconditional_embeds = False
75+
if hasattr(self, "guider") and self.guider is not None:
76+
requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1
8477

85-
@property
86-
def default_sample_size(self):
87-
return 128
78+
return requires_unconditional_embeds
8879

89-
@property
90-
def vae_scale_factor(self):
91-
vae_scale_factor = 8
92-
if getattr(self, "vae", None) is not None:
93-
vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
94-
return vae_scale_factor
9580

96-
@property
97-
def num_channels_latents(self):
98-
num_channels_latents = 32
99-
if getattr(self, "transformer", None):
100-
num_channels_latents = self.transformer.config.in_channels // 4
101-
return num_channels_latents
81+
class Flux2KleinBaseModularPipeline(Flux2ModularPipeline):
82+
"""
83+
A ModularPipeline for Flux2-Klein (base model).
84+
85+
> [!WARNING] > This is an experimental feature and is likely to change in the future.
86+
"""
87+
88+
default_blocks_name = "Flux2KleinBaseAutoBlocks"
10289

10390
@property
10491
def requires_unconditional_embeds(self):

src/diffusers/modular_pipelines/mellon_node_utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,12 @@ def _name_to_label(name: str) -> str:
156156
"display": "slider",
157157
"required_block_params": ["layers"],
158158
},
159+
"output_type": {
160+
"label": "Output Type",
161+
"type": "dropdown",
162+
"default": "np",
163+
"options": ["np", "pil", "pt"],
164+
},
159165
# ControlNet
160166
"controlnet_conditioning_scale": {
161167
"label": "Controlnet Conditioning Scale",

src/diffusers/modular_pipelines/modular_pipeline.py

Lines changed: 58 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -54,19 +54,61 @@
5454

5555

5656
# map regular pipeline to modular pipeline class name
57+
58+
59+
def _create_default_map_fn(pipeline_class_name: str):
60+
"""Create a mapping function that always returns the same pipeline class."""
61+
62+
def _map_fn(config_dict=None):
63+
return pipeline_class_name
64+
65+
return _map_fn
66+
67+
68+
def _flux2_klein_map_fn(config_dict=None):
69+
if config_dict is None:
70+
return "Flux2KleinModularPipeline"
71+
72+
if "is_distilled" in config_dict and config_dict["is_distilled"]:
73+
return "Flux2KleinModularPipeline"
74+
else:
75+
return "Flux2KleinBaseModularPipeline"
76+
77+
78+
def _wan_map_fn(config_dict=None):
79+
if config_dict is None:
80+
return "WanModularPipeline"
81+
82+
if "boundary_ratio" in config_dict and config_dict["boundary_ratio"] is not None:
83+
return "Wan22ModularPipeline"
84+
else:
85+
return "WanModularPipeline"
86+
87+
88+
def _wan_i2v_map_fn(config_dict=None):
89+
if config_dict is None:
90+
return "WanImage2VideoModularPipeline"
91+
92+
if "boundary_ratio" in config_dict and config_dict["boundary_ratio"] is not None:
93+
return "Wan22Image2VideoModularPipeline"
94+
else:
95+
return "WanImage2VideoModularPipeline"
96+
97+
5798
MODULAR_PIPELINE_MAPPING = OrderedDict(
5899
[
59-
("stable-diffusion-xl", "StableDiffusionXLModularPipeline"),
60-
("wan", "WanModularPipeline"),
61-
("flux", "FluxModularPipeline"),
62-
("flux-kontext", "FluxKontextModularPipeline"),
63-
("flux2", "Flux2ModularPipeline"),
64-
("flux2-klein", "Flux2KleinModularPipeline"),
65-
("qwenimage", "QwenImageModularPipeline"),
66-
("qwenimage-edit", "QwenImageEditModularPipeline"),
67-
("qwenimage-edit-plus", "QwenImageEditPlusModularPipeline"),
68-
("qwenimage-layered", "QwenImageLayeredModularPipeline"),
69-
("z-image", "ZImageModularPipeline"),
100+
("stable-diffusion-xl", _create_default_map_fn("StableDiffusionXLModularPipeline")),
101+
("wan", _wan_map_fn),
102+
("wan-i2v", _wan_i2v_map_fn),
103+
("flux", _create_default_map_fn("FluxModularPipeline")),
104+
("flux-kontext", _create_default_map_fn("FluxKontextModularPipeline")),
105+
("flux2", _create_default_map_fn("Flux2ModularPipeline")),
106+
("flux2-klein", _flux2_klein_map_fn),
107+
("qwenimage", _create_default_map_fn("QwenImageModularPipeline")),
108+
("qwenimage-edit", _create_default_map_fn("QwenImageEditModularPipeline")),
109+
("qwenimage-edit-plus", _create_default_map_fn("QwenImageEditPlusModularPipeline")),
110+
("qwenimage-layered", _create_default_map_fn("QwenImageLayeredModularPipeline")),
111+
("z-image", _create_default_map_fn("ZImageModularPipeline")),
70112
]
71113
)
72114

@@ -368,7 +410,8 @@ def init_pipeline(
368410
"""
369411
create a ModularPipeline, optionally accept pretrained_model_name_or_path to load from hub.
370412
"""
371-
pipeline_class_name = MODULAR_PIPELINE_MAPPING.get(self.model_name, ModularPipeline.__name__)
413+
map_fn = MODULAR_PIPELINE_MAPPING.get(self.model_name, _create_default_map_fn("ModularPipeline"))
414+
pipeline_class_name = map_fn()
372415
diffusers_module = importlib.import_module("diffusers")
373416
pipeline_class = getattr(diffusers_module, pipeline_class_name)
374417

@@ -1547,7 +1590,7 @@ def __init__(
15471590
if modular_config_dict is not None:
15481591
blocks_class_name = modular_config_dict.get("_blocks_class_name")
15491592
else:
1550-
blocks_class_name = self.get_default_blocks_name(config_dict)
1593+
blocks_class_name = self.default_blocks_name
15511594
if blocks_class_name is not None:
15521595
diffusers_module = importlib.import_module("diffusers")
15531596
blocks_class = getattr(diffusers_module, blocks_class_name)
@@ -1619,9 +1662,6 @@ def default_call_parameters(self) -> Dict[str, Any]:
16191662
params[input_param.name] = input_param.default
16201663
return params
16211664

1622-
def get_default_blocks_name(self, config_dict: Optional[Dict[str, Any]]) -> Optional[str]:
1623-
return self.default_blocks_name
1624-
16251665
@classmethod
16261666
def _load_pipeline_config(
16271667
cls,
@@ -1717,7 +1757,8 @@ def from_pretrained(
17171757
logger.debug(" try to determine the modular pipeline class from model_index.json")
17181758
standard_pipeline_class = _get_pipeline_class(cls, config=config_dict)
17191759
model_name = _get_model(standard_pipeline_class.__name__)
1720-
pipeline_class_name = MODULAR_PIPELINE_MAPPING.get(model_name, ModularPipeline.__name__)
1760+
map_fn = MODULAR_PIPELINE_MAPPING.get(model_name, _create_default_map_fn("ModularPipeline"))
1761+
pipeline_class_name = map_fn(config_dict)
17211762
diffusers_module = importlib.import_module("diffusers")
17221763
pipeline_class = getattr(diffusers_module, pipeline_class_name)
17231764
else:

src/diffusers/modular_pipelines/wan/__init__.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,16 @@
2121

2222
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
2323
else:
24-
_import_structure["decoders"] = ["WanImageVaeDecoderStep"]
25-
_import_structure["encoders"] = ["WanTextEncoderStep"]
26-
_import_structure["modular_blocks"] = [
27-
"ALL_BLOCKS",
28-
"Wan22AutoBlocks",
29-
"WanAutoBlocks",
30-
"WanAutoImageEncoderStep",
31-
"WanAutoVaeImageEncoderStep",
24+
_import_structure["modular_blocks_wan"] = ["WanBlocks"]
25+
_import_structure["modular_blocks_wan22"] = ["Wan22Blocks"]
26+
_import_structure["modular_blocks_wan22_i2v"] = ["Wan22Image2VideoBlocks"]
27+
_import_structure["modular_blocks_wan_i2v"] = ["WanImage2VideoAutoBlocks"]
28+
_import_structure["modular_pipeline"] = [
29+
"Wan22Image2VideoModularPipeline",
30+
"Wan22ModularPipeline",
31+
"WanImage2VideoModularPipeline",
32+
"WanModularPipeline",
3233
]
33-
_import_structure["modular_pipeline"] = ["WanModularPipeline"]
3434

3535
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
3636
try:
@@ -39,16 +39,16 @@
3939
except OptionalDependencyNotAvailable:
4040
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
4141
else:
42-
from .decoders import WanImageVaeDecoderStep
43-
from .encoders import WanTextEncoderStep
44-
from .modular_blocks import (
45-
ALL_BLOCKS,
46-
Wan22AutoBlocks,
47-
WanAutoBlocks,
48-
WanAutoImageEncoderStep,
49-
WanAutoVaeImageEncoderStep,
42+
from .modular_blocks_wan import WanBlocks
43+
from .modular_blocks_wan22 import Wan22Blocks
44+
from .modular_blocks_wan22_i2v import Wan22Image2VideoBlocks
45+
from .modular_blocks_wan_i2v import WanImage2VideoAutoBlocks
46+
from .modular_pipeline import (
47+
Wan22Image2VideoModularPipeline,
48+
Wan22ModularPipeline,
49+
WanImage2VideoModularPipeline,
50+
WanModularPipeline,
5051
)
51-
from .modular_pipeline import WanModularPipeline
5252
else:
5353
import sys
5454

0 commit comments

Comments
 (0)