File tree Expand file tree Collapse file tree
tests/models/transformers Expand file tree Collapse file tree Original file line number Diff line number Diff line change 1313# See the License for the specific language governing permissions and
1414# limitations under the License.
1515
16+ import gc
1617import os
1718import unittest
1819
@@ -87,6 +88,25 @@ def prepare_init_args_and_inputs_for_common(self):
8788 inputs_dict = self .dummy_input
8889 return init_dict , inputs_dict
8990
91+ def setUp (self ):
92+ gc .collect ()
93+ if torch .cuda .is_available ():
94+ torch .cuda .empty_cache ()
95+ torch .cuda .synchronize ()
96+ torch .manual_seed (0 )
97+ if torch .cuda .is_available ():
98+ torch .cuda .manual_seed_all (0 )
99+
100+ def tearDown (self ):
101+ super ().tearDown ()
102+ gc .collect ()
103+ if torch .cuda .is_available ():
104+ torch .cuda .empty_cache ()
105+ torch .cuda .synchronize ()
106+ torch .manual_seed (0 )
107+ if torch .cuda .is_available ():
108+ torch .cuda .manual_seed_all (0 )
109+
90110 def test_gradient_checkpointing_is_applied (self ):
91111 expected_set = {"ZImageTransformer2DModel" }
92112 super ().test_gradient_checkpointing_is_applied (expected_set = expected_set )
You can’t perform that action at this time.
0 commit comments