@@ -1378,25 +1378,27 @@ def test_float16_inference(self, expected_max_diff=5e-2):
13781378 for component in pipe_fp16 .components .values ():
13791379 if hasattr (component , "set_default_attn_processor" ):
13801380 component .set_default_attn_processor ()
1381-
13821381 pipe_fp16 .to (torch_device , torch .float16 )
13831382 pipe_fp16 .set_progress_bar_config (disable = None )
13841383
13851384 inputs = self .get_dummy_inputs (torch_device )
13861385 # Reset generator in case it is used inside dummy inputs
13871386 if "generator" in inputs :
13881387 inputs ["generator" ] = self .get_generator (0 )
1389-
1390- output = pipe (** inputs )[0 ].cpu ()
1388+ output = pipe (** inputs )[0 ]
13911389
13921390 fp16_inputs = self .get_dummy_inputs (torch_device )
13931391 # Reset generator in case it is used inside dummy inputs
13941392 if "generator" in fp16_inputs :
13951393 fp16_inputs ["generator" ] = self .get_generator (0 )
1394+ output_fp16 = pipe_fp16 (** fp16_inputs )[0 ]
1395+
1396+ if isinstance (output , torch .Tensor ):
1397+ output = output .cpu ()
1398+ output_fp16 = output_fp16 .cpu ()
13961399
1397- output_fp16 = pipe_fp16 (** fp16_inputs )[0 ].cpu ()
13981400 max_diff = numpy_cosine_similarity_distance (output .flatten (), output_fp16 .flatten ())
1399- assert max_diff < 1e-2
1401+ assert max_diff < expected_max_diff
14001402
14011403 @unittest .skipIf (torch_device not in ["cuda" , "xpu" ], reason = "float16 requires CUDA or XPU" )
14021404 @require_accelerator
0 commit comments