2525
2626from ...models .unets .test_models_unet_2d_condition import create_ip_adapter_state_dict
2727from ...testing_utils import enable_full_determinism , floats_tensor , torch_device
28- from ..test_modular_pipelines_common import ModularPipelineTesterMixin
28+ from ..test_modular_pipelines_common import ModularGuiderTesterMixin , ModularPipelineTesterMixin
2929
3030
3131enable_full_determinism ()
@@ -37,13 +37,11 @@ class SDXLModularTesterMixin:
3737 """
3838
3939 def _test_stable_diffusion_xl_euler (self , expected_image_shape , expected_slice , expected_max_diff = 1e-2 ):
40- sd_pipe = self .get_pipeline ()
41- sd_pipe = sd_pipe .to (torch_device )
42- sd_pipe .set_progress_bar_config (disable = None )
40+ sd_pipe = self .get_pipeline ().to (torch_device )
4341
4442 inputs = self .get_dummy_inputs ()
4543 image = sd_pipe (** inputs , output = "images" )
46- image_slice = image [0 , - 3 :, - 3 :, - 1 ]
44+ image_slice = image [0 , - 3 :, - 3 :, - 1 ]. cpu ()
4745
4846 assert image .shape == expected_image_shape
4947 max_diff = torch .abs (image_slice .flatten () - expected_slice ).max ()
@@ -110,7 +108,7 @@ def test_ip_adapter(self, expected_max_diff: float = 1e-4, expected_pipe_slice=N
110108 pipe = blocks .init_pipeline (self .repo )
111109 pipe .load_components (torch_dtype = torch .float32 )
112110 pipe = pipe .to (torch_device )
113- pipe . set_progress_bar_config ( disable = None )
111+
114112 cross_attention_dim = pipe .unet .config .get ("cross_attention_dim" )
115113
116114 # forward pass without ip adapter
@@ -219,9 +217,7 @@ def test_controlnet(self, expected_max_diff: float = 1e-4, expected_pipe_slice=N
219217 # compare against static slices and that can be shaky (with a VVVV low probability).
220218 expected_max_diff = 9e-4 if torch_device == "cpu" else expected_max_diff
221219
222- pipe = self .get_pipeline ()
223- pipe = pipe .to (torch_device )
224- pipe .set_progress_bar_config (disable = None )
220+ pipe = self .get_pipeline ().to (torch_device )
225221
226222 # forward pass without controlnet
227223 inputs = self .get_dummy_inputs ()
@@ -251,9 +247,7 @@ def test_controlnet(self, expected_max_diff: float = 1e-4, expected_pipe_slice=N
251247 assert max_diff_with_controlnet_scale > 1e-2 , "Output with controlnet must be different from normal inference"
252248
253249 def test_controlnet_cfg (self ):
254- pipe = self .get_pipeline ()
255- pipe = pipe .to (torch_device )
256- pipe .set_progress_bar_config (disable = None )
250+ pipe = self .get_pipeline ().to (torch_device )
257251
258252 # forward pass with CFG not applied
259253 guider = ClassifierFreeGuidance (guidance_scale = 1.0 )
@@ -273,35 +267,11 @@ def test_controlnet_cfg(self):
273267 assert max_diff > 1e-2 , "Output with CFG must be different from normal inference"
274268
275269
276- class SDXLModularGuiderTesterMixin :
277- def test_guider_cfg (self ):
278- pipe = self .get_pipeline ()
279- pipe = pipe .to (torch_device )
280- pipe .set_progress_bar_config (disable = None )
281-
282- # forward pass with CFG not applied
283- guider = ClassifierFreeGuidance (guidance_scale = 1.0 )
284- pipe .update_components (guider = guider )
285-
286- inputs = self .get_dummy_inputs ()
287- out_no_cfg = pipe (** inputs , output = "images" )
288-
289- # forward pass with CFG applied
290- guider = ClassifierFreeGuidance (guidance_scale = 7.5 )
291- pipe .update_components (guider = guider )
292- inputs = self .get_dummy_inputs ()
293- out_cfg = pipe (** inputs , output = "images" )
294-
295- assert out_cfg .shape == out_no_cfg .shape
296- max_diff = np .abs (out_cfg - out_no_cfg ).max ()
297- assert max_diff > 1e-2 , "Output with CFG must be different from normal inference"
298-
299-
300270class TestSDXLModularPipelineFast (
301271 SDXLModularTesterMixin ,
302272 SDXLModularIPAdapterTesterMixin ,
303273 SDXLModularControlNetTesterMixin ,
304- SDXLModularGuiderTesterMixin ,
274+ ModularGuiderTesterMixin ,
305275 ModularPipelineTesterMixin ,
306276):
307277 """Test cases for Stable Diffusion XL modular pipeline fast tests."""
@@ -335,18 +305,7 @@ def test_stable_diffusion_xl_euler(self):
335305 self ._test_stable_diffusion_xl_euler (
336306 expected_image_shape = self .expected_image_output_shape ,
337307 expected_slice = torch .tensor (
338- [
339- 0.5966781 ,
340- 0.62939394 ,
341- 0.48465094 ,
342- 0.51573336 ,
343- 0.57593524 ,
344- 0.47035995 ,
345- 0.53410417 ,
346- 0.51436996 ,
347- 0.47313565 ,
348- ],
349- device = torch_device ,
308+ [0.3886 , 0.4685 , 0.4953 , 0.4217 , 0.4317 , 0.3945 , 0.4847 , 0.4704 , 0.4731 ],
350309 ),
351310 expected_max_diff = 1e-2 ,
352311 )
@@ -359,7 +318,7 @@ class TestSDXLImg2ImgModularPipelineFast(
359318 SDXLModularTesterMixin ,
360319 SDXLModularIPAdapterTesterMixin ,
361320 SDXLModularControlNetTesterMixin ,
362- SDXLModularGuiderTesterMixin ,
321+ ModularGuiderTesterMixin ,
363322 ModularPipelineTesterMixin ,
364323):
365324 """Test cases for Stable Diffusion XL image-to-image modular pipeline fast tests."""
@@ -400,20 +359,7 @@ def get_dummy_inputs(self, seed=0):
400359 def test_stable_diffusion_xl_euler (self ):
401360 self ._test_stable_diffusion_xl_euler (
402361 expected_image_shape = self .expected_image_output_shape ,
403- expected_slice = torch .tensor (
404- [
405- 0.56943184 ,
406- 0.4702148 ,
407- 0.48048905 ,
408- 0.6235963 ,
409- 0.551138 ,
410- 0.49629188 ,
411- 0.60031277 ,
412- 0.5688907 ,
413- 0.43996853 ,
414- ],
415- device = torch_device ,
416- ),
362+ expected_slice = torch .tensor ([0.5246 , 0.4466 , 0.444 , 0.3246 , 0.4443 , 0.5108 , 0.5225 , 0.559 , 0.5147 ]),
417363 expected_max_diff = 1e-2 ,
418364 )
419365
@@ -425,7 +371,7 @@ class SDXLInpaintingModularPipelineFastTests(
425371 SDXLModularTesterMixin ,
426372 SDXLModularIPAdapterTesterMixin ,
427373 SDXLModularControlNetTesterMixin ,
428- SDXLModularGuiderTesterMixin ,
374+ ModularGuiderTesterMixin ,
429375 ModularPipelineTesterMixin ,
430376):
431377 """Test cases for Stable Diffusion XL inpainting modular pipeline fast tests."""
0 commit comments