Skip to content

Commit 944c29b

Browse files
sayakpaulDN6
authored andcommitted
fix to device and to dtype tests. (#13323)
1 parent 70ab180 commit 944c29b

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

tests/pipelines/test_pipelines_common.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1534,14 +1534,18 @@ def test_to_device(self):
15341534
pipe.set_progress_bar_config(disable=None)
15351535

15361536
pipe.to("cpu")
1537-
model_devices = [component.device.type for component in components.values() if hasattr(component, "device")]
1537+
model_devices = [
1538+
component.device.type for component in components.values() if getattr(component, "device", None)
1539+
]
15381540
self.assertTrue(all(device == "cpu" for device in model_devices))
15391541

15401542
output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0]
15411543
self.assertTrue(np.isnan(output_cpu).sum() == 0)
15421544

15431545
pipe.to(torch_device)
1544-
model_devices = [component.device.type for component in components.values() if hasattr(component, "device")]
1546+
model_devices = [
1547+
component.device.type for component in components.values() if getattr(component, "device", None)
1548+
]
15451549
self.assertTrue(all(device == torch_device for device in model_devices))
15461550

15471551
output_device = pipe(**self.get_dummy_inputs(torch_device))[0]
@@ -1552,11 +1556,11 @@ def test_to_dtype(self):
15521556
pipe = self.pipeline_class(**components)
15531557
pipe.set_progress_bar_config(disable=None)
15541558

1555-
model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")]
1559+
model_dtypes = [component.dtype for component in components.values() if getattr(component, "dtype", None)]
15561560
self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))
15571561

15581562
pipe.to(dtype=torch.float16)
1559-
model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")]
1563+
model_dtypes = [component.dtype for component in components.values() if getattr(component, "dtype", None)]
15601564
self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
15611565

15621566
def test_attention_slicing_forward_pass(self, expected_max_diff=1e-3):

0 commit comments

Comments
 (0)