|
54 | 54 |
|
55 | 55 |
|
56 | 56 | # map regular pipeline to modular pipeline class name |
| 57 | + |
| 58 | + |
| 59 | +def _create_default_map_fn(pipeline_class_name: str): |
| 60 | + """Create a mapping function that always returns the same pipeline class.""" |
| 61 | + |
| 62 | + def _map_fn(config_dict=None): |
| 63 | + return pipeline_class_name |
| 64 | + |
| 65 | + return _map_fn |
| 66 | + |
| 67 | + |
| 68 | +def _flux2_klein_map_fn(config_dict=None): |
| 69 | + if config_dict is None: |
| 70 | + return "Flux2KleinModularPipeline" |
| 71 | + |
| 72 | + if "is_distilled" in config_dict and config_dict["is_distilled"]: |
| 73 | + return "Flux2KleinModularPipeline" |
| 74 | + else: |
| 75 | + return "Flux2KleinBaseModularPipeline" |
| 76 | + |
| 77 | + |
| 78 | +def _wan_map_fn(config_dict=None): |
| 79 | + if config_dict is None: |
| 80 | + return "WanModularPipeline" |
| 81 | + |
| 82 | + if "boundary_ratio" in config_dict and config_dict["boundary_ratio"] is not None: |
| 83 | + return "Wan22ModularPipeline" |
| 84 | + else: |
| 85 | + return "WanModularPipeline" |
| 86 | + |
| 87 | + |
| 88 | +def _wan_i2v_map_fn(config_dict=None): |
| 89 | + if config_dict is None: |
| 90 | + return "WanImage2VideoModularPipeline" |
| 91 | + |
| 92 | + if "boundary_ratio" in config_dict and config_dict["boundary_ratio"] is not None: |
| 93 | + return "Wan22Image2VideoModularPipeline" |
| 94 | + else: |
| 95 | + return "WanImage2VideoModularPipeline" |
| 96 | + |
| 97 | + |
57 | 98 | MODULAR_PIPELINE_MAPPING = OrderedDict( |
58 | 99 | [ |
59 | | - ("stable-diffusion-xl", "StableDiffusionXLModularPipeline"), |
60 | | - ("wan", "WanModularPipeline"), |
61 | | - ("flux", "FluxModularPipeline"), |
62 | | - ("flux-kontext", "FluxKontextModularPipeline"), |
63 | | - ("flux2", "Flux2ModularPipeline"), |
64 | | - ("flux2-klein", "Flux2KleinModularPipeline"), |
65 | | - ("qwenimage", "QwenImageModularPipeline"), |
66 | | - ("qwenimage-edit", "QwenImageEditModularPipeline"), |
67 | | - ("qwenimage-edit-plus", "QwenImageEditPlusModularPipeline"), |
68 | | - ("qwenimage-layered", "QwenImageLayeredModularPipeline"), |
69 | | - ("z-image", "ZImageModularPipeline"), |
| 100 | + ("stable-diffusion-xl", _create_default_map_fn("StableDiffusionXLModularPipeline")), |
| 101 | + ("wan", _wan_map_fn), |
| 102 | + ("wan-i2v", _wan_i2v_map_fn), |
| 103 | + ("flux", _create_default_map_fn("FluxModularPipeline")), |
| 104 | + ("flux-kontext", _create_default_map_fn("FluxKontextModularPipeline")), |
| 105 | + ("flux2", _create_default_map_fn("Flux2ModularPipeline")), |
| 106 | + ("flux2-klein", _flux2_klein_map_fn), |
| 107 | + ("qwenimage", _create_default_map_fn("QwenImageModularPipeline")), |
| 108 | + ("qwenimage-edit", _create_default_map_fn("QwenImageEditModularPipeline")), |
| 109 | + ("qwenimage-edit-plus", _create_default_map_fn("QwenImageEditPlusModularPipeline")), |
| 110 | + ("qwenimage-layered", _create_default_map_fn("QwenImageLayeredModularPipeline")), |
| 111 | + ("z-image", _create_default_map_fn("ZImageModularPipeline")), |
70 | 112 | ] |
71 | 113 | ) |
72 | 114 |
|
@@ -368,7 +410,8 @@ def init_pipeline( |
368 | 410 | """ |
369 | 411 | create a ModularPipeline, optionally accept pretrained_model_name_or_path to load from hub. |
370 | 412 | """ |
371 | | - pipeline_class_name = MODULAR_PIPELINE_MAPPING.get(self.model_name, ModularPipeline.__name__) |
| 413 | + map_fn = MODULAR_PIPELINE_MAPPING.get(self.model_name, _create_default_map_fn("ModularPipeline")) |
| 414 | + pipeline_class_name = map_fn() |
372 | 415 | diffusers_module = importlib.import_module("diffusers") |
373 | 416 | pipeline_class = getattr(diffusers_module, pipeline_class_name) |
374 | 417 |
|
@@ -1547,7 +1590,7 @@ def __init__( |
1547 | 1590 | if modular_config_dict is not None: |
1548 | 1591 | blocks_class_name = modular_config_dict.get("_blocks_class_name") |
1549 | 1592 | else: |
1550 | | - blocks_class_name = self.get_default_blocks_name(config_dict) |
| 1593 | + blocks_class_name = self.default_blocks_name |
1551 | 1594 | if blocks_class_name is not None: |
1552 | 1595 | diffusers_module = importlib.import_module("diffusers") |
1553 | 1596 | blocks_class = getattr(diffusers_module, blocks_class_name) |
@@ -1619,9 +1662,6 @@ def default_call_parameters(self) -> Dict[str, Any]: |
1619 | 1662 | params[input_param.name] = input_param.default |
1620 | 1663 | return params |
1621 | 1664 |
|
1622 | | - def get_default_blocks_name(self, config_dict: Optional[Dict[str, Any]]) -> Optional[str]: |
1623 | | - return self.default_blocks_name |
1624 | | - |
1625 | 1665 | @classmethod |
1626 | 1666 | def _load_pipeline_config( |
1627 | 1667 | cls, |
@@ -1717,7 +1757,8 @@ def from_pretrained( |
1717 | 1757 | logger.debug(" try to determine the modular pipeline class from model_index.json") |
1718 | 1758 | standard_pipeline_class = _get_pipeline_class(cls, config=config_dict) |
1719 | 1759 | model_name = _get_model(standard_pipeline_class.__name__) |
1720 | | - pipeline_class_name = MODULAR_PIPELINE_MAPPING.get(model_name, ModularPipeline.__name__) |
| 1760 | + map_fn = MODULAR_PIPELINE_MAPPING.get(model_name, _create_default_map_fn("ModularPipeline")) |
| 1761 | + pipeline_class_name = map_fn(config_dict) |
1721 | 1762 | diffusers_module = importlib.import_module("diffusers") |
1722 | 1763 | pipeline_class = getattr(diffusers_module, pipeline_class_name) |
1723 | 1764 | else: |
|
0 commit comments