Skip to content

Commit 85ffcf1

Browse files
authored
[tests] Tests for conditional pipeline blocks (#13247)
* implement test suite for conditional blocks. * remove * another fix. * Revert "another fix." This reverts commit ab07b60.
1 parent cbf4d9a commit 85ffcf1

File tree

3 files changed

+358
-118
lines changed

3 files changed

+358
-118
lines changed
Lines changed: 242 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,242 @@
1+
# Copyright 2025 The HuggingFace Team. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
from diffusers.modular_pipelines import (
17+
AutoPipelineBlocks,
18+
ConditionalPipelineBlocks,
19+
InputParam,
20+
ModularPipelineBlocks,
21+
)
22+
23+
24+
class TextToImageBlock(ModularPipelineBlocks):
25+
model_name = "text2img"
26+
27+
@property
28+
def inputs(self):
29+
return [InputParam(name="prompt")]
30+
31+
@property
32+
def intermediate_outputs(self):
33+
return []
34+
35+
@property
36+
def description(self):
37+
return "text-to-image workflow"
38+
39+
def __call__(self, components, state):
40+
block_state = self.get_block_state(state)
41+
block_state.workflow = "text2img"
42+
self.set_block_state(state, block_state)
43+
return components, state
44+
45+
46+
class ImageToImageBlock(ModularPipelineBlocks):
47+
model_name = "img2img"
48+
49+
@property
50+
def inputs(self):
51+
return [InputParam(name="prompt"), InputParam(name="image")]
52+
53+
@property
54+
def intermediate_outputs(self):
55+
return []
56+
57+
@property
58+
def description(self):
59+
return "image-to-image workflow"
60+
61+
def __call__(self, components, state):
62+
block_state = self.get_block_state(state)
63+
block_state.workflow = "img2img"
64+
self.set_block_state(state, block_state)
65+
return components, state
66+
67+
68+
class InpaintBlock(ModularPipelineBlocks):
69+
model_name = "inpaint"
70+
71+
@property
72+
def inputs(self):
73+
return [InputParam(name="prompt"), InputParam(name="image"), InputParam(name="mask")]
74+
75+
@property
76+
def intermediate_outputs(self):
77+
return []
78+
79+
@property
80+
def description(self):
81+
return "inpaint workflow"
82+
83+
def __call__(self, components, state):
84+
block_state = self.get_block_state(state)
85+
block_state.workflow = "inpaint"
86+
self.set_block_state(state, block_state)
87+
return components, state
88+
89+
90+
class ConditionalImageBlocks(ConditionalPipelineBlocks):
91+
block_classes = [InpaintBlock, ImageToImageBlock, TextToImageBlock]
92+
block_names = ["inpaint", "img2img", "text2img"]
93+
block_trigger_inputs = ["mask", "image"]
94+
default_block_name = "text2img"
95+
96+
@property
97+
def description(self):
98+
return "Conditional image blocks for testing"
99+
100+
def select_block(self, mask=None, image=None) -> str | None:
101+
if mask is not None:
102+
return "inpaint"
103+
if image is not None:
104+
return "img2img"
105+
return None # falls back to default_block_name
106+
107+
108+
class OptionalConditionalBlocks(ConditionalPipelineBlocks):
109+
block_classes = [InpaintBlock, ImageToImageBlock]
110+
block_names = ["inpaint", "img2img"]
111+
block_trigger_inputs = ["mask", "image"]
112+
default_block_name = None # no default; block can be skipped
113+
114+
@property
115+
def description(self):
116+
return "Optional conditional blocks (skippable)"
117+
118+
def select_block(self, mask=None, image=None) -> str | None:
119+
if mask is not None:
120+
return "inpaint"
121+
if image is not None:
122+
return "img2img"
123+
return None
124+
125+
126+
class AutoImageBlocks(AutoPipelineBlocks):
127+
block_classes = [InpaintBlock, ImageToImageBlock, TextToImageBlock]
128+
block_names = ["inpaint", "img2img", "text2img"]
129+
block_trigger_inputs = ["mask", "image", None]
130+
131+
@property
132+
def description(self):
133+
return "Auto image blocks for testing"
134+
135+
136+
class TestConditionalPipelineBlocksSelectBlock:
137+
def test_select_block_with_mask(self):
138+
blocks = ConditionalImageBlocks()
139+
assert blocks.select_block(mask="something") == "inpaint"
140+
141+
def test_select_block_with_image(self):
142+
blocks = ConditionalImageBlocks()
143+
assert blocks.select_block(image="something") == "img2img"
144+
145+
def test_select_block_with_mask_and_image(self):
146+
blocks = ConditionalImageBlocks()
147+
assert blocks.select_block(mask="m", image="i") == "inpaint"
148+
149+
def test_select_block_no_triggers_returns_none(self):
150+
blocks = ConditionalImageBlocks()
151+
assert blocks.select_block() is None
152+
153+
def test_select_block_explicit_none_values(self):
154+
blocks = ConditionalImageBlocks()
155+
assert blocks.select_block(mask=None, image=None) is None
156+
157+
158+
class TestConditionalPipelineBlocksWorkflowSelection:
159+
def test_default_workflow_when_no_triggers(self):
160+
blocks = ConditionalImageBlocks()
161+
execution = blocks.get_execution_blocks()
162+
assert execution is not None
163+
assert isinstance(execution, TextToImageBlock)
164+
165+
def test_mask_trigger_selects_inpaint(self):
166+
blocks = ConditionalImageBlocks()
167+
execution = blocks.get_execution_blocks(mask=True)
168+
assert isinstance(execution, InpaintBlock)
169+
170+
def test_image_trigger_selects_img2img(self):
171+
blocks = ConditionalImageBlocks()
172+
execution = blocks.get_execution_blocks(image=True)
173+
assert isinstance(execution, ImageToImageBlock)
174+
175+
def test_mask_and_image_selects_inpaint(self):
176+
blocks = ConditionalImageBlocks()
177+
execution = blocks.get_execution_blocks(mask=True, image=True)
178+
assert isinstance(execution, InpaintBlock)
179+
180+
def test_skippable_block_returns_none(self):
181+
blocks = OptionalConditionalBlocks()
182+
execution = blocks.get_execution_blocks()
183+
assert execution is None
184+
185+
def test_skippable_block_still_selects_when_triggered(self):
186+
blocks = OptionalConditionalBlocks()
187+
execution = blocks.get_execution_blocks(image=True)
188+
assert isinstance(execution, ImageToImageBlock)
189+
190+
191+
class TestAutoPipelineBlocksSelectBlock:
192+
def test_auto_select_mask(self):
193+
blocks = AutoImageBlocks()
194+
assert blocks.select_block(mask="m") == "inpaint"
195+
196+
def test_auto_select_image(self):
197+
blocks = AutoImageBlocks()
198+
assert blocks.select_block(image="i") == "img2img"
199+
200+
def test_auto_select_default(self):
201+
blocks = AutoImageBlocks()
202+
# No trigger -> returns None -> falls back to default (text2img)
203+
assert blocks.select_block() is None
204+
205+
def test_auto_select_priority_order(self):
206+
blocks = AutoImageBlocks()
207+
assert blocks.select_block(mask="m", image="i") == "inpaint"
208+
209+
210+
class TestAutoPipelineBlocksWorkflowSelection:
211+
def test_auto_default_workflow(self):
212+
blocks = AutoImageBlocks()
213+
execution = blocks.get_execution_blocks()
214+
assert isinstance(execution, TextToImageBlock)
215+
216+
def test_auto_mask_workflow(self):
217+
blocks = AutoImageBlocks()
218+
execution = blocks.get_execution_blocks(mask=True)
219+
assert isinstance(execution, InpaintBlock)
220+
221+
def test_auto_image_workflow(self):
222+
blocks = AutoImageBlocks()
223+
execution = blocks.get_execution_blocks(image=True)
224+
assert isinstance(execution, ImageToImageBlock)
225+
226+
227+
class TestConditionalPipelineBlocksStructure:
228+
def test_block_names_accessible(self):
229+
blocks = ConditionalImageBlocks()
230+
sub = dict(blocks.sub_blocks)
231+
assert set(sub.keys()) == {"inpaint", "img2img", "text2img"}
232+
233+
def test_sub_block_types(self):
234+
blocks = ConditionalImageBlocks()
235+
sub = dict(blocks.sub_blocks)
236+
assert isinstance(sub["inpaint"], InpaintBlock)
237+
assert isinstance(sub["img2img"], ImageToImageBlock)
238+
assert isinstance(sub["text2img"], TextToImageBlock)
239+
240+
def test_description(self):
241+
blocks = ConditionalImageBlocks()
242+
assert "Conditional" in blocks.description

tests/modular_pipelines/test_modular_pipelines_common.py

Lines changed: 0 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,6 @@
1010
import diffusers
1111
from diffusers import AutoModel, ComponentsManager, ModularPipeline, ModularPipelineBlocks
1212
from diffusers.guiders import ClassifierFreeGuidance
13-
from diffusers.modular_pipelines import (
14-
ConditionalPipelineBlocks,
15-
LoopSequentialPipelineBlocks,
16-
SequentialPipelineBlocks,
17-
)
1813
from diffusers.modular_pipelines.modular_pipeline_utils import (
1914
ComponentSpec,
2015
ConfigSpec,
@@ -25,7 +20,6 @@
2520
from diffusers.utils import logging
2621

2722
from ..testing_utils import (
28-
CaptureLogger,
2923
backend_empty_cache,
3024
numpy_cosine_similarity_distance,
3125
require_accelerator,
@@ -498,117 +492,6 @@ def test_guider_cfg(self, expected_max_diff=1e-2):
498492
assert max_diff > expected_max_diff, "Output with CFG must be different from normal inference"
499493

500494

501-
class TestCustomBlockRequirements:
502-
def get_dummy_block_pipe(self):
503-
class DummyBlockOne:
504-
# keep two arbitrary deps so that we can test warnings.
505-
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
506-
507-
class DummyBlockTwo:
508-
# keep two dependencies that will be available during testing.
509-
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
510-
511-
pipe = SequentialPipelineBlocks.from_blocks_dict(
512-
{"dummy_block_one": DummyBlockOne, "dummy_block_two": DummyBlockTwo}
513-
)
514-
return pipe
515-
516-
def get_dummy_conditional_block_pipe(self):
517-
class DummyBlockOne:
518-
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
519-
520-
class DummyBlockTwo:
521-
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
522-
523-
class DummyConditionalBlocks(ConditionalPipelineBlocks):
524-
block_classes = [DummyBlockOne, DummyBlockTwo]
525-
block_names = ["block_one", "block_two"]
526-
block_trigger_inputs = []
527-
528-
def select_block(self, **kwargs):
529-
return "block_one"
530-
531-
return DummyConditionalBlocks()
532-
533-
def get_dummy_loop_block_pipe(self):
534-
class DummyBlockOne:
535-
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
536-
537-
class DummyBlockTwo:
538-
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
539-
540-
return LoopSequentialPipelineBlocks.from_blocks_dict({"block_one": DummyBlockOne, "block_two": DummyBlockTwo})
541-
542-
def test_sequential_block_requirements_save_load(self, tmp_path):
543-
pipe = self.get_dummy_block_pipe()
544-
pipe.save_pretrained(str(tmp_path))
545-
546-
config_path = tmp_path / "modular_config.json"
547-
548-
with open(config_path, "r") as f:
549-
config = json.load(f)
550-
551-
assert "requirements" in config
552-
requirements = config["requirements"]
553-
554-
expected_requirements = {
555-
"xyz": ">=0.8.0",
556-
"abc": ">=10.0.0",
557-
"transformers": ">=4.44.0",
558-
"diffusers": ">=0.2.0",
559-
}
560-
assert expected_requirements == requirements
561-
562-
def test_sequential_block_requirements_warnings(self, tmp_path):
563-
pipe = self.get_dummy_block_pipe()
564-
565-
logger = logging.get_logger("diffusers.modular_pipelines.modular_pipeline_utils")
566-
logger.setLevel(30)
567-
568-
with CaptureLogger(logger) as cap_logger:
569-
pipe.save_pretrained(str(tmp_path))
570-
571-
template = "{req} was specified in the requirements but wasn't found in the current environment"
572-
msg_xyz = template.format(req="xyz")
573-
msg_abc = template.format(req="abc")
574-
assert msg_xyz in str(cap_logger.out)
575-
assert msg_abc in str(cap_logger.out)
576-
577-
def test_conditional_block_requirements_save_load(self, tmp_path):
578-
pipe = self.get_dummy_conditional_block_pipe()
579-
pipe.save_pretrained(str(tmp_path))
580-
581-
config_path = tmp_path / "modular_config.json"
582-
with open(config_path, "r") as f:
583-
config = json.load(f)
584-
585-
assert "requirements" in config
586-
expected_requirements = {
587-
"xyz": ">=0.8.0",
588-
"abc": ">=10.0.0",
589-
"transformers": ">=4.44.0",
590-
"diffusers": ">=0.2.0",
591-
}
592-
assert expected_requirements == config["requirements"]
593-
594-
def test_loop_block_requirements_save_load(self, tmp_path):
595-
pipe = self.get_dummy_loop_block_pipe()
596-
pipe.save_pretrained(str(tmp_path))
597-
598-
config_path = tmp_path / "modular_config.json"
599-
with open(config_path, "r") as f:
600-
config = json.load(f)
601-
602-
assert "requirements" in config
603-
expected_requirements = {
604-
"xyz": ">=0.8.0",
605-
"abc": ">=10.0.0",
606-
"transformers": ">=4.44.0",
607-
"diffusers": ">=0.2.0",
608-
}
609-
assert expected_requirements == config["requirements"]
610-
611-
612495
class TestModularModelCardContent:
613496
def create_mock_block(self, name="TestBlock", description="Test block description"):
614497
class MockBlock:

0 commit comments

Comments
 (0)