Skip to content

Commit 8ee24fc

Browse files
committed
up
1 parent 76dbf63 commit 8ee24fc

1 file changed

Lines changed: 20 additions & 0 deletions

File tree

tests/models/transformers/test_models_transformer_z_image.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16+
import gc
1617
import os
1718
import unittest
1819

@@ -87,6 +88,25 @@ def prepare_init_args_and_inputs_for_common(self):
8788
inputs_dict = self.dummy_input
8889
return init_dict, inputs_dict
8990

91+
def setUp(self):
92+
gc.collect()
93+
if torch.cuda.is_available():
94+
torch.cuda.empty_cache()
95+
torch.cuda.synchronize()
96+
torch.manual_seed(0)
97+
if torch.cuda.is_available():
98+
torch.cuda.manual_seed_all(0)
99+
100+
def tearDown(self):
101+
super().tearDown()
102+
gc.collect()
103+
if torch.cuda.is_available():
104+
torch.cuda.empty_cache()
105+
torch.cuda.synchronize()
106+
torch.manual_seed(0)
107+
if torch.cuda.is_available():
108+
torch.cuda.manual_seed_all(0)
109+
90110
def test_gradient_checkpointing_is_applied(self):
91111
expected_set = {"ZImageTransformer2DModel"}
92112
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

0 commit comments

Comments
 (0)