Skip to content

Commit b1f2fd1

Browse files
committed
update
1 parent d9c4f1e commit b1f2fd1

1 file changed

Lines changed: 7 additions & 5 deletions

File tree

tests/pipelines/test_pipelines_common.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1378,25 +1378,27 @@ def test_float16_inference(self, expected_max_diff=5e-2):
13781378
for component in pipe_fp16.components.values():
13791379
if hasattr(component, "set_default_attn_processor"):
13801380
component.set_default_attn_processor()
1381-
13821381
pipe_fp16.to(torch_device, torch.float16)
13831382
pipe_fp16.set_progress_bar_config(disable=None)
13841383

13851384
inputs = self.get_dummy_inputs(torch_device)
13861385
# Reset generator in case it is used inside dummy inputs
13871386
if "generator" in inputs:
13881387
inputs["generator"] = self.get_generator(0)
1389-
1390-
output = pipe(**inputs)[0].cpu()
1388+
output = pipe(**inputs)[0]
13911389

13921390
fp16_inputs = self.get_dummy_inputs(torch_device)
13931391
# Reset generator in case it is used inside dummy inputs
13941392
if "generator" in fp16_inputs:
13951393
fp16_inputs["generator"] = self.get_generator(0)
1394+
output_fp16 = pipe_fp16(**fp16_inputs)[0]
1395+
1396+
if isinstance(output, torch.Tensor):
1397+
output = output.cpu()
1398+
output_fp16 = output_fp16.cpu()
13961399

1397-
output_fp16 = pipe_fp16(**fp16_inputs)[0].cpu()
13981400
max_diff = numpy_cosine_similarity_distance(output.flatten(), output_fp16.flatten())
1399-
assert max_diff < 1e-2
1401+
assert max_diff < expected_max_diff
14001402

14011403
@unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
14021404
@require_accelerator

0 commit comments

Comments
 (0)