2323 slow ,
2424 torch_device ,
2525)
26- from ..testing_utils import (
27- BasePipelineTesterConfig ,
26+ from ..test_pipelines_common import (
2827 FasterCacheTesterMixin ,
2928 FirstBlockCacheTesterMixin ,
3029 FluxIPAdapterTesterMixin ,
3130 MagCacheTesterMixin ,
32- MemoryTesterMixin ,
3331 PipelineTesterMixin ,
3432 PyramidAttentionBroadcastTesterMixin ,
3533 TaylorSeerCacheTesterMixin ,
3634 check_qkv_fused_layers_exist ,
3735)
3836
3937
40- class FluxPipelineTesterConfig (BasePipelineTesterConfig ):
41- @property
42- def pipeline_class (self ):
43- return FluxPipeline
44-
45- @property
46- def params (self ):
47- return frozenset (["prompt" , "height" , "width" , "guidance_scale" , "prompt_embeds" , "pooled_prompt_embeds" ])
48-
49- @property
50- def batch_params (self ):
51- return frozenset (["prompt" ])
38+ class FluxPipelineFastTests (
39+ PipelineTesterMixin ,
40+ FluxIPAdapterTesterMixin ,
41+ PyramidAttentionBroadcastTesterMixin ,
42+ FasterCacheTesterMixin ,
43+ FirstBlockCacheTesterMixin ,
44+ TaylorSeerCacheTesterMixin ,
45+ MagCacheTesterMixin ,
46+ unittest .TestCase ,
47+ ):
48+ pipeline_class = FluxPipeline
49+ params = frozenset (["prompt" , "height" , "width" , "guidance_scale" , "prompt_embeds" , "pooled_prompt_embeds" ])
50+ batch_params = frozenset (["prompt" ])
5251
53- @property
54- def test_layerwise_casting (self ):
55- return True
52+ # there is no xformers processor for Flux
53+ test_xformers_attention = False
54+ test_layerwise_casting = True
55+ test_group_offloading = True
5656
57- @property
58- def test_group_offloading (self ):
59- return True
57+ faster_cache_config = FasterCacheConfig (
58+ spatial_attention_block_skip_range = 2 ,
59+ spatial_attention_timestep_skip_range = (- 1 , 901 ),
60+ unconditional_batch_skip_range = 2 ,
61+ attention_weight_callback = lambda _ : 0.5 ,
62+ is_guidance_distilled = True ,
63+ )
6064
6165 def get_dummy_components (self , num_layers : int = 1 , num_single_layers : int = 1 ):
6266 torch .manual_seed (0 )
@@ -142,8 +146,6 @@ def get_dummy_inputs(self, device, seed=0):
142146 }
143147 return inputs
144148
145-
146- class TestFluxPipeline (FluxPipelineTesterConfig , PipelineTesterMixin ):
147149 def test_flux_different_prompts (self ):
148150 pipe = self .pipeline_class (** self .get_dummy_components ()).to (torch_device )
149151
@@ -158,7 +160,7 @@ def test_flux_different_prompts(self):
158160
159161 # Outputs should be different here
160162 # For some reasons, they don't show large differences
161- assert max_diff > 1e-6 , "Outputs should be different for different prompts."
163+ self . assertGreater ( max_diff , 1e-6 , "Outputs should be different for different prompts." )
162164
163165 def test_fused_qkv_projections (self ):
164166 device = "cpu" # ensure determinism for the device-dependent torch.Generator
@@ -174,8 +176,9 @@ def test_fused_qkv_projections(self):
174176 # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
175177 # to the pipeline level.
176178 pipe .transformer .fuse_qkv_projections ()
177- assert check_qkv_fused_layers_exist (pipe .transformer , ["to_qkv" ]), (
178- "Something wrong with the fused attention layers. Expected all the attention projections to be fused."
179+ self .assertTrue (
180+ check_qkv_fused_layers_exist (pipe .transformer , ["to_qkv" ]),
181+ ("Something wrong with the fused attention layers. Expected all the attention projections to be fused." ),
179182 )
180183
181184 inputs = self .get_dummy_inputs (device )
@@ -187,14 +190,17 @@ def test_fused_qkv_projections(self):
187190 image = pipe (** inputs ).images
188191 image_slice_disabled = image [0 , - 3 :, - 3 :, - 1 ]
189192
190- assert np .allclose (original_image_slice , image_slice_fused , atol = 1e-3 , rtol = 1e-3 ), (
191- "Fusion of QKV projections shouldn't affect the outputs."
193+ self .assertTrue (
194+ np .allclose (original_image_slice , image_slice_fused , atol = 1e-3 , rtol = 1e-3 ),
195+ ("Fusion of QKV projections shouldn't affect the outputs." ),
192196 )
193- assert np .allclose (image_slice_fused , image_slice_disabled , atol = 1e-3 , rtol = 1e-3 ), (
194- "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
197+ self .assertTrue (
198+ np .allclose (image_slice_fused , image_slice_disabled , atol = 1e-3 , rtol = 1e-3 ),
199+ ("Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled." ),
195200 )
196- assert np .allclose (original_image_slice , image_slice_disabled , atol = 1e-2 , rtol = 1e-2 ), (
197- "Original outputs should match when fused QKV projections are disabled."
201+ self .assertTrue (
202+ np .allclose (original_image_slice , image_slice_disabled , atol = 1e-2 , rtol = 1e-2 ),
203+ ("Original outputs should match when fused QKV projections are disabled." ),
198204 )
199205
200206 def test_flux_image_output_shape (self ):
@@ -209,8 +215,10 @@ def test_flux_image_output_shape(self):
209215 inputs .update ({"height" : height , "width" : width })
210216 image = pipe (** inputs ).images [0 ]
211217 output_height , output_width , _ = image .shape
212- assert (output_height , output_width ) == (expected_height , expected_width ), (
213- f"Output shape { image .shape } does not match expected shape { (expected_height , expected_width )} "
218+ self .assertEqual (
219+ (output_height , output_width ),
220+ (expected_height , expected_width ),
221+ f"Output shape { image .shape } does not match expected shape { (expected_height , expected_width )} " ,
214222 )
215223
216224 def test_flux_true_cfg (self ):
@@ -222,48 +230,11 @@ def test_flux_true_cfg(self):
222230 inputs ["negative_prompt" ] = "bad quality"
223231 inputs ["true_cfg_scale" ] = 2.0
224232 true_cfg_out = pipe (** inputs , generator = torch .manual_seed (0 )).images [0 ]
225- assert not np . allclose ( no_true_cfg_out , true_cfg_out ), (
226- "Outputs should be different when true_cfg_scale is set."
233+ self . assertFalse (
234+ np . allclose ( no_true_cfg_out , true_cfg_out ), "Outputs should be different when true_cfg_scale is set."
227235 )
228236
229237
230- class TestFluxPipelineMemory (FluxPipelineTesterConfig , MemoryTesterMixin ):
231- """Offload / device-map / group-offload / layerwise-casting tests for Flux."""
232-
233-
234- class TestFluxPipelineIPAdapter (FluxPipelineTesterConfig , FluxIPAdapterTesterMixin ):
235- """IP-Adapter tests for Flux."""
236-
237-
238- class TestFluxPipelinePAB (FluxPipelineTesterConfig , PyramidAttentionBroadcastTesterMixin ):
239- """Pyramid Attention Broadcast cache tests for Flux."""
240-
241-
242- class TestFluxPipelineFasterCache (FluxPipelineTesterConfig , FasterCacheTesterMixin ):
243- """FasterCache tests for Flux."""
244-
245- # Flux is guidance distilled, so we set `is_guidance_distilled=True`.
246- faster_cache_config = FasterCacheConfig (
247- spatial_attention_block_skip_range = 2 ,
248- spatial_attention_timestep_skip_range = (- 1 , 901 ),
249- unconditional_batch_skip_range = 2 ,
250- attention_weight_callback = lambda _ : 0.5 ,
251- is_guidance_distilled = True ,
252- )
253-
254-
255- class TestFluxPipelineFirstBlockCache (FluxPipelineTesterConfig , FirstBlockCacheTesterMixin ):
256- """FirstBlockCache tests for Flux."""
257-
258-
259- class TestFluxPipelineTaylorSeerCache (FluxPipelineTesterConfig , TaylorSeerCacheTesterMixin ):
260- """TaylorSeerCache tests for Flux."""
261-
262-
263- class TestFluxPipelineMagCache (FluxPipelineTesterConfig , MagCacheTesterMixin ):
264- """MagCache tests for Flux."""
265-
266-
267238@nightly
268239@require_big_accelerator
269240class FluxPipelineSlowTests (unittest .TestCase ):
@@ -322,7 +293,9 @@ def test_flux_inference(self):
322293 # fmt: on
323294
324295 max_diff = numpy_cosine_similarity_distance (expected_slice .flatten (), image_slice .flatten ())
325- assert max_diff < 1e-4 , f"Image slice is different from expected slice: { image_slice } != { expected_slice } "
296+ self .assertLess (
297+ max_diff , 1e-4 , f"Image slice is different from expected slice: { image_slice } != { expected_slice } "
298+ )
326299
327300
328301@slow
@@ -400,4 +373,6 @@ def test_flux_ip_adapter_inference(self):
400373 # fmt: on
401374
402375 max_diff = numpy_cosine_similarity_distance (expected_slice .flatten (), image_slice .flatten ())
403- assert max_diff < 1e-4 , f"Image slice is different from expected slice: { image_slice } != { expected_slice } "
376+ self .assertLess (
377+ max_diff , 1e-4 , f"Image slice is different from expected slice: { image_slice } != { expected_slice } "
378+ )
0 commit comments