Skip to content

Commit 3918824

Browse files
yiyixuxuyiyi@huggingface.co
andauthored
[modular] fallback to default_blocks_name when loading base block classes in ModularPipeline (#13193)
up Co-authored-by: yiyi@huggingface.co <yiyi@ip-26-0-161-123.ec2.internal>
1 parent 9b97932 commit 3918824

2 files changed

Lines changed: 32 additions & 1 deletion

File tree

src/diffusers/modular_pipelines/modular_pipeline.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1633,7 +1633,14 @@ def __init__(
16331633
blocks_class_name = self.default_blocks_name
16341634
if blocks_class_name is not None:
16351635
diffusers_module = importlib.import_module("diffusers")
1636-
blocks_class = getattr(diffusers_module, blocks_class_name)
1636+
blocks_class = getattr(diffusers_module, blocks_class_name, None)
1637+
# If the blocks_class is not found or is a base class (e.g. SequentialPipelineBlocks saved by from_blocks_dict) with empty block_classes
1638+
# fall back to default_blocks_name
1639+
if blocks_class is None or not blocks_class.block_classes:
1640+
blocks_class_name = self.default_blocks_name
1641+
blocks_class = getattr(diffusers_module, blocks_class_name)
1642+
1643+
if blocks_class is not None:
16371644
blocks = blocks_class()
16381645
else:
16391646
logger.warning(f"`blocks` is `None`, no default blocks class found for {self.__class__.__name__}")

tests/modular_pipelines/test_modular_pipelines_common.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -728,3 +728,27 @@ def test_load_components_skips_invalid_pretrained_path(self):
728728

729729
# Verify test_component was not loaded
730730
assert not hasattr(pipe, "test_component") or pipe.test_component is None
731+
732+
733+
class TestModularPipelineInitFallback:
734+
"""Test that ModularPipeline.__init__ falls back to default_blocks_name when
735+
_blocks_class_name is a base class (e.g. SequentialPipelineBlocks saved by from_blocks_dict)."""
736+
737+
def test_init_fallback_when_blocks_class_name_is_base_class(self, tmp_path):
738+
# 1. Load pipeline and get a workflow (returns a base SequentialPipelineBlocks)
739+
pipe = ModularPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-xl-pipe")
740+
t2i_blocks = pipe.blocks.get_workflow("text2image")
741+
assert t2i_blocks.__class__.__name__ == "SequentialPipelineBlocks"
742+
743+
# 2. Use init_pipeline to create a new pipeline from the workflow blocks
744+
t2i_pipe = t2i_blocks.init_pipeline("hf-internal-testing/tiny-stable-diffusion-xl-pipe")
745+
746+
# 3. Save and reload — the saved config will have _blocks_class_name="SequentialPipelineBlocks"
747+
save_dir = str(tmp_path / "pipeline")
748+
t2i_pipe.save_pretrained(save_dir)
749+
loaded_pipe = ModularPipeline.from_pretrained(save_dir)
750+
751+
# 4. Verify it fell back to default_blocks_name and has correct blocks
752+
assert loaded_pipe.__class__.__name__ == pipe.__class__.__name__
753+
assert loaded_pipe._blocks.__class__.__name__ == pipe._blocks.__class__.__name__
754+
assert len(loaded_pipe._blocks.sub_blocks) == len(pipe._blocks.sub_blocks)

0 commit comments

Comments
 (0)