Skip to content

Commit 27a1c25

Browse files
committed
align with the latest structure
1 parent 42f0bf7 commit 27a1c25

2 files changed

Lines changed: 51 additions & 66 deletions

File tree

tests/modular_pipelines/qwen/test_modular_pipeline_qwenimage.py

Lines changed: 49 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,10 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
import unittest
1716

1817
import numpy as np
1918
import PIL
2019
import pytest
21-
import torch
2220

2321
from diffusers import ClassifierFreeGuidance
2422
from diffusers.modular_pipelines import (
@@ -34,38 +32,6 @@
3432
from ..test_modular_pipelines_common import ModularPipelineTesterMixin
3533

3634

37-
class QwenImageModularTests:
38-
pipeline_class = QwenImageModularPipeline
39-
pipeline_blocks_class = QwenImageAutoBlocks
40-
repo = "hf-internal-testing/tiny-qwenimage-modular"
41-
42-
params = frozenset(["prompt", "height", "width", "negative_prompt", "attention_kwargs", "image", "mask_image"])
43-
batch_params = frozenset(["prompt", "negative_prompt", "image", "mask_image"])
44-
45-
def get_pipeline(self, components_manager=None, torch_dtype=torch.float32):
46-
pipeline = self.pipeline_blocks_class().init_pipeline(self.repo, components_manager=components_manager)
47-
pipeline.load_components(torch_dtype=torch_dtype)
48-
pipeline.set_progress_bar_config(disable=None)
49-
return pipeline
50-
51-
def get_dummy_inputs(self, device, seed=0):
52-
if str(device).startswith("mps"):
53-
generator = torch.manual_seed(seed)
54-
else:
55-
generator = torch.Generator(device=device).manual_seed(seed)
56-
inputs = {
57-
"prompt": "dance monkey",
58-
"negative_prompt": "bad quality",
59-
"generator": generator,
60-
"num_inference_steps": 2,
61-
"height": 32,
62-
"width": 32,
63-
"max_sequence_length": 16,
64-
"output_type": "np",
65-
}
66-
return inputs
67-
68-
6935
class QwenImageModularGuiderTests:
7036
def test_guider_cfg(self, tol=1e-2):
7137
pipe = self.get_pipeline()
@@ -87,33 +53,56 @@ def test_guider_cfg(self, tol=1e-2):
8753
assert max_diff > tol, "Output with CFG must be different from normal inference"
8854

8955

90-
class QwenImageModularPipelineFastTests(
91-
QwenImageModularTests, QwenImageModularGuiderTests, ModularPipelineTesterMixin, unittest.TestCase
92-
):
93-
def __init__(self, *args, **kwargs):
94-
super().__init__(*args, **kwargs)
56+
class TestQwenImageModularPipelineFast(ModularPipelineTesterMixin, QwenImageModularGuiderTests):
57+
pipeline_class = QwenImageModularPipeline
58+
pipeline_blocks_class = QwenImageAutoBlocks
59+
repo = "hf-internal-testing/tiny-qwenimage-modular"
60+
61+
params = frozenset(["prompt", "height", "width", "negative_prompt", "attention_kwargs", "image", "mask_image"])
62+
batch_params = frozenset(["prompt", "negative_prompt", "image", "mask_image"])
63+
64+
def get_dummy_inputs(self):
65+
generator = self.get_generator()
66+
inputs = {
67+
"prompt": "dance monkey",
68+
"negative_prompt": "bad quality",
69+
"generator": generator,
70+
"num_inference_steps": 2,
71+
"height": 32,
72+
"width": 32,
73+
"max_sequence_length": 16,
74+
"output_type": "np",
75+
}
76+
return inputs
9577

9678

97-
class QwenImageEditModularPipelineFastTests(
98-
QwenImageModularTests, QwenImageModularGuiderTests, ModularPipelineTesterMixin, unittest.TestCase
99-
):
79+
class TestQwenImageEditModularPipelineFast(ModularPipelineTesterMixin, QwenImageModularGuiderTests):
10080
pipeline_class = QwenImageEditModularPipeline
10181
pipeline_blocks_class = QwenImageEditAutoBlocks
10282
repo = "hf-internal-testing/tiny-qwenimage-edit-modular"
10383

104-
def get_dummy_inputs(self, device, seed=0):
105-
inputs = super().get_dummy_inputs(device, seed)
106-
inputs.pop("max_sequence_length")
84+
params = frozenset(["prompt", "height", "width", "negative_prompt", "attention_kwargs", "image", "mask_image"])
85+
batch_params = frozenset(["prompt", "negative_prompt", "image", "mask_image"])
86+
87+
def get_dummy_inputs(self):
88+
generator = self.get_generator()
89+
inputs = {
90+
"prompt": "dance monkey",
91+
"negative_prompt": "bad quality",
92+
"generator": generator,
93+
"num_inference_steps": 2,
94+
"height": 32,
95+
"width": 32,
96+
"output_type": "np",
97+
}
10798
inputs["image"] = PIL.Image.new("RGB", (32, 32), 0)
10899
return inputs
109100

110101
def test_guider_cfg(self):
111102
super().test_guider_cfg(7e-5)
112103

113104

114-
class QwenImageEditPlusModularPipelineFastTests(
115-
QwenImageModularTests, QwenImageModularGuiderTests, ModularPipelineTesterMixin, unittest.TestCase
116-
):
105+
class QwenImageEditPlusModularPipelineFastTests(ModularPipelineTesterMixin, QwenImageModularGuiderTests):
117106
pipeline_class = QwenImageEditPlusModularPipeline
118107
pipeline_blocks_class = QwenImageEditPlusAutoBlocks
119108
repo = "hf-internal-testing/tiny-qwenimage-edit-plus-modular"
@@ -122,11 +111,18 @@ class QwenImageEditPlusModularPipelineFastTests(
122111
params = frozenset(["prompt", "height", "width", "negative_prompt", "attention_kwargs", "image"])
123112
batch_params = frozenset(["prompt", "negative_prompt", "image"])
124113

125-
def get_dummy_inputs(self, device, seed=0):
126-
inputs = super().get_dummy_inputs(device, seed)
127-
inputs.pop("max_sequence_length")
128-
image = PIL.Image.new("RGB", (32, 32), 0)
129-
inputs["image"] = [image]
114+
def get_dummy_inputs(self):
115+
generator = self.get_generator()
116+
inputs = {
117+
"prompt": "dance monkey",
118+
"negative_prompt": "bad quality",
119+
"generator": generator,
120+
"num_inference_steps": 2,
121+
"height": 32,
122+
"width": 32,
123+
"output_type": "np",
124+
}
125+
inputs["image"] = PIL.Image.new("RGB", (32, 32), 0)
130126
return inputs
131127

132128
@pytest.mark.xfail(condition=True, reason="Batch of multiple images needs to be revisited", strict=True)

tests/modular_pipelines/test_modular_pipelines_common.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -32,20 +32,9 @@ class ModularPipelineTesterMixin:
3232
# Canonical parameters that are passed to `__call__` regardless
3333
# of the type of pipeline. They are always optional and have common
3434
# sense default values.
35-
optional_params = frozenset(
36-
[
37-
"num_inference_steps",
38-
"num_images_per_prompt",
39-
"latents",
40-
"output_type",
41-
]
42-
)
35+
optional_params = frozenset(["num_inference_steps", "num_images_per_prompt", "latents", "output_type"])
4336
# this is modular-specific: the generator needs to be an intermediate input because it's mutable
44-
intermediate_params = frozenset(
45-
[
46-
"generator",
47-
]
48-
)
37+
intermediate_params = frozenset(["generator"])
4938

5039
def get_generator(self, seed=0):
5140
generator = torch.Generator("cpu").manual_seed(seed)

0 commit comments

Comments
 (0)