2323 slow ,
2424 torch_device ,
2525)
26- from ..test_pipelines_common import (
26+ from ..testing_utils import (
27+ BasePipelineTesterConfig ,
2728 FasterCacheTesterMixin ,
2829 FirstBlockCacheTesterMixin ,
2930 FluxIPAdapterTesterMixin ,
3031 MagCacheTesterMixin ,
32+ MemoryTesterMixin ,
3133 PipelineTesterMixin ,
3234 PyramidAttentionBroadcastTesterMixin ,
3335 TaylorSeerCacheTesterMixin ,
3436 check_qkv_fused_layers_exist ,
3537)
3638
3739
38- class FluxPipelineFastTests (
39- PipelineTesterMixin ,
40- FluxIPAdapterTesterMixin ,
41- PyramidAttentionBroadcastTesterMixin ,
42- FasterCacheTesterMixin ,
43- FirstBlockCacheTesterMixin ,
44- TaylorSeerCacheTesterMixin ,
45- MagCacheTesterMixin ,
46- unittest .TestCase ,
47- ):
48- pipeline_class = FluxPipeline
49- params = frozenset (["prompt" , "height" , "width" , "guidance_scale" , "prompt_embeds" , "pooled_prompt_embeds" ])
50- batch_params = frozenset (["prompt" ])
40+ class FluxPipelineTesterConfig (BasePipelineTesterConfig ):
41+ @property
42+ def pipeline_class (self ):
43+ return FluxPipeline
5144
52- # there is no xformers processor for Flux
53- test_xformers_attention = False
54- test_layerwise_casting = True
55- test_group_offloading = True
45+ @property
46+ def params (self ):
47+ return frozenset (["prompt" , "height" , "width" , "guidance_scale" , "prompt_embeds" , "pooled_prompt_embeds" ])
5648
57- faster_cache_config = FasterCacheConfig (
58- spatial_attention_block_skip_range = 2 ,
59- spatial_attention_timestep_skip_range = (- 1 , 901 ),
60- unconditional_batch_skip_range = 2 ,
61- attention_weight_callback = lambda _ : 0.5 ,
62- is_guidance_distilled = True ,
63- )
49+ @property
50+ def batch_params (self ):
51+ return frozenset (["prompt" ])
52+
53+ @property
54+ def test_layerwise_casting (self ):
55+ return True
56+
57+ @property
58+ def test_group_offloading (self ):
59+ return True
6460
6561 def get_dummy_components (self , num_layers : int = 1 , num_single_layers : int = 1 ):
6662 torch .manual_seed (0 )
@@ -146,6 +142,8 @@ def get_dummy_inputs(self, device, seed=0):
146142 }
147143 return inputs
148144
145+
146+ class TestFluxPipeline (FluxPipelineTesterConfig , PipelineTesterMixin ):
149147 def test_flux_different_prompts (self ):
150148 pipe = self .pipeline_class (** self .get_dummy_components ()).to (torch_device )
151149
@@ -160,7 +158,7 @@ def test_flux_different_prompts(self):
160158
161159 # Outputs should be different here
162160 # For some reasons, they don't show large differences
163- self . assertGreater ( max_diff , 1e-6 , "Outputs should be different for different prompts." )
161+ assert max_diff > 1e-6 , "Outputs should be different for different prompts."
164162
165163 def test_fused_qkv_projections (self ):
166164 device = "cpu" # ensure determinism for the device-dependent torch.Generator
@@ -176,9 +174,8 @@ def test_fused_qkv_projections(self):
176174 # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
177175 # to the pipeline level.
178176 pipe .transformer .fuse_qkv_projections ()
179- self .assertTrue (
180- check_qkv_fused_layers_exist (pipe .transformer , ["to_qkv" ]),
181- ("Something wrong with the fused attention layers. Expected all the attention projections to be fused." ),
177+ assert check_qkv_fused_layers_exist (pipe .transformer , ["to_qkv" ]), (
178+ "Something wrong with the fused attention layers. Expected all the attention projections to be fused."
182179 )
183180
184181 inputs = self .get_dummy_inputs (device )
@@ -190,17 +187,14 @@ def test_fused_qkv_projections(self):
190187 image = pipe (** inputs ).images
191188 image_slice_disabled = image [0 , - 3 :, - 3 :, - 1 ]
192189
193- self .assertTrue (
194- np .allclose (original_image_slice , image_slice_fused , atol = 1e-3 , rtol = 1e-3 ),
195- ("Fusion of QKV projections shouldn't affect the outputs." ),
190+ assert np .allclose (original_image_slice , image_slice_fused , atol = 1e-3 , rtol = 1e-3 ), (
191+ "Fusion of QKV projections shouldn't affect the outputs."
196192 )
197- self .assertTrue (
198- np .allclose (image_slice_fused , image_slice_disabled , atol = 1e-3 , rtol = 1e-3 ),
199- ("Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled." ),
193+ assert np .allclose (image_slice_fused , image_slice_disabled , atol = 1e-3 , rtol = 1e-3 ), (
194+ "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
200195 )
201- self .assertTrue (
202- np .allclose (original_image_slice , image_slice_disabled , atol = 1e-2 , rtol = 1e-2 ),
203- ("Original outputs should match when fused QKV projections are disabled." ),
196+ assert np .allclose (original_image_slice , image_slice_disabled , atol = 1e-2 , rtol = 1e-2 ), (
197+ "Original outputs should match when fused QKV projections are disabled."
204198 )
205199
206200 def test_flux_image_output_shape (self ):
@@ -215,10 +209,8 @@ def test_flux_image_output_shape(self):
215209 inputs .update ({"height" : height , "width" : width })
216210 image = pipe (** inputs ).images [0 ]
217211 output_height , output_width , _ = image .shape
218- self .assertEqual (
219- (output_height , output_width ),
220- (expected_height , expected_width ),
221- f"Output shape { image .shape } does not match expected shape { (expected_height , expected_width )} " ,
212+ assert (output_height , output_width ) == (expected_height , expected_width ), (
213+ f"Output shape { image .shape } does not match expected shape { (expected_height , expected_width )} "
222214 )
223215
224216 def test_flux_true_cfg (self ):
@@ -230,11 +222,48 @@ def test_flux_true_cfg(self):
230222 inputs ["negative_prompt" ] = "bad quality"
231223 inputs ["true_cfg_scale" ] = 2.0
232224 true_cfg_out = pipe (** inputs , generator = torch .manual_seed (0 )).images [0 ]
233- self . assertFalse (
234- np . allclose ( no_true_cfg_out , true_cfg_out ), "Outputs should be different when true_cfg_scale is set."
225+ assert not np . allclose ( no_true_cfg_out , true_cfg_out ), (
226+ "Outputs should be different when true_cfg_scale is set."
235227 )
236228
237229
230+ class TestFluxPipelineMemory (FluxPipelineTesterConfig , MemoryTesterMixin ):
231+ """Offload / device-map / group-offload / layerwise-casting tests for Flux."""
232+
233+
234+ class TestFluxPipelineIPAdapter (FluxPipelineTesterConfig , FluxIPAdapterTesterMixin ):
235+ """IP-Adapter tests for Flux."""
236+
237+
238+ class TestFluxPipelinePAB (FluxPipelineTesterConfig , PyramidAttentionBroadcastTesterMixin ):
239+ """Pyramid Attention Broadcast cache tests for Flux."""
240+
241+
242+ class TestFluxPipelineFasterCache (FluxPipelineTesterConfig , FasterCacheTesterMixin ):
243+ """FasterCache tests for Flux."""
244+
245+ # Flux is guidance distilled, so we set `is_guidance_distilled=True`.
246+ faster_cache_config = FasterCacheConfig (
247+ spatial_attention_block_skip_range = 2 ,
248+ spatial_attention_timestep_skip_range = (- 1 , 901 ),
249+ unconditional_batch_skip_range = 2 ,
250+ attention_weight_callback = lambda _ : 0.5 ,
251+ is_guidance_distilled = True ,
252+ )
253+
254+
255+ class TestFluxPipelineFirstBlockCache (FluxPipelineTesterConfig , FirstBlockCacheTesterMixin ):
256+ """FirstBlockCache tests for Flux."""
257+
258+
259+ class TestFluxPipelineTaylorSeerCache (FluxPipelineTesterConfig , TaylorSeerCacheTesterMixin ):
260+ """TaylorSeerCache tests for Flux."""
261+
262+
263+ class TestFluxPipelineMagCache (FluxPipelineTesterConfig , MagCacheTesterMixin ):
264+ """MagCache tests for Flux."""
265+
266+
238267@nightly
239268@require_big_accelerator
240269class FluxPipelineSlowTests (unittest .TestCase ):
@@ -293,9 +322,7 @@ def test_flux_inference(self):
293322 # fmt: on
294323
295324 max_diff = numpy_cosine_similarity_distance (expected_slice .flatten (), image_slice .flatten ())
296- self .assertLess (
297- max_diff , 1e-4 , f"Image slice is different from expected slice: { image_slice } != { expected_slice } "
298- )
325+ assert max_diff < 1e-4 , f"Image slice is different from expected slice: { image_slice } != { expected_slice } "
299326
300327
301328@slow
@@ -373,6 +400,4 @@ def test_flux_ip_adapter_inference(self):
373400 # fmt: on
374401
375402 max_diff = numpy_cosine_similarity_distance (expected_slice .flatten (), image_slice .flatten ())
376- self .assertLess (
377- max_diff , 1e-4 , f"Image slice is different from expected slice: { image_slice } != { expected_slice } "
378- )
403+ assert max_diff < 1e-4 , f"Image slice is different from expected slice: { image_slice } != { expected_slice } "
0 commit comments