@@ -1528,14 +1528,16 @@ def test_fn(storage_dtype, compute_dtype):
15281528 test_fn (torch .float8_e5m2 , torch .float32 )
15291529 test_fn (torch .float8_e4m3fn , torch .bfloat16 )
15301530
1531+ @torch .no_grad ()
15311532 def test_layerwise_casting_inference (self ):
15321533 from diffusers .hooks .layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN , SUPPORTED_PYTORCH_LAYERS
15331534
15341535 torch .manual_seed (0 )
15351536 config , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
1536- model = self .model_class (** config ).eval ()
1537- model = model .to (torch_device )
1538- base_slice = model (** inputs_dict )[0 ].flatten ().detach ().cpu ().numpy ()
1537+ model = self .model_class (** config )
1538+ model .eval ()
1539+ model .to (torch_device )
1540+ base_slice = model (** inputs_dict )[0 ].detach ().flatten ().cpu ().numpy ()
15391541
15401542 def check_linear_dtype (module , storage_dtype , compute_dtype ):
15411543 patterns_to_check = DEFAULT_SKIP_MODULES_PATTERN
@@ -1706,10 +1708,6 @@ def test_group_offloading_with_disk(self, record_stream, offload_type):
17061708 if not self .model_class ._supports_group_offloading :
17071709 pytest .skip ("Model does not support group offloading." )
17081710
1709- torch .manual_seed (0 )
1710- init_dict , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
1711- model = self .model_class (** init_dict )
1712-
17131711 torch .manual_seed (0 )
17141712 init_dict , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
17151713 model = self .model_class (** init_dict )
@@ -1725,7 +1723,7 @@ def test_group_offloading_with_disk(self, record_stream, offload_type):
17251723 ** additional_kwargs ,
17261724 )
17271725 has_safetensors = glob .glob (f"{ tmpdir } /*.safetensors" )
1728- assert has_safetensors , "No safetensors found in the directory."
1726+ self . assertTrue ( len ( has_safetensors ) > 0 , "No safetensors found in the offload directory." )
17291727 _ = model (** inputs_dict )[0 ]
17301728
17311729 def test_auto_model (self , expected_max_diff = 5e-5 ):
0 commit comments