@@ -1534,14 +1534,18 @@ def test_to_device(self):
15341534 pipe .set_progress_bar_config (disable = None )
15351535
15361536 pipe .to ("cpu" )
1537- model_devices = [component .device .type for component in components .values () if hasattr (component , "device" )]
1537+ model_devices = [
1538+ component .device .type for component in components .values () if getattr (component , "device" , None )
1539+ ]
15381540 self .assertTrue (all (device == "cpu" for device in model_devices ))
15391541
15401542 output_cpu = pipe (** self .get_dummy_inputs ("cpu" ))[0 ]
15411543 self .assertTrue (np .isnan (output_cpu ).sum () == 0 )
15421544
15431545 pipe .to (torch_device )
1544- model_devices = [component .device .type for component in components .values () if hasattr (component , "device" )]
1546+ model_devices = [
1547+ component .device .type for component in components .values () if getattr (component , "device" , None )
1548+ ]
15451549 self .assertTrue (all (device == torch_device for device in model_devices ))
15461550
15471551 output_device = pipe (** self .get_dummy_inputs (torch_device ))[0 ]
@@ -1552,11 +1556,11 @@ def test_to_dtype(self):
15521556 pipe = self .pipeline_class (** components )
15531557 pipe .set_progress_bar_config (disable = None )
15541558
1555- model_dtypes = [component .dtype for component in components .values () if hasattr (component , "dtype" )]
1559+ model_dtypes = [component .dtype for component in components .values () if getattr (component , "dtype" , None )]
15561560 self .assertTrue (all (dtype == torch .float32 for dtype in model_dtypes ))
15571561
15581562 pipe .to (dtype = torch .float16 )
1559- model_dtypes = [component .dtype for component in components .values () if hasattr (component , "dtype" )]
1563+ model_dtypes = [component .dtype for component in components .values () if getattr (component , "dtype" , None )]
15601564 self .assertTrue (all (dtype == torch .float16 for dtype in model_dtypes ))
15611565
15621566 def test_attention_slicing_forward_pass (self , expected_max_diff = 1e-3 ):
0 commit comments