Skip to content

Commit 2fa7f77

Browse files
sayakpauldg845
andauthored
fix kvae gradient checkpointing tests (#13865)
Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
1 parent 5d10b4d commit 2fa7f77

1 file changed

Lines changed: 6 additions & 3 deletions

File tree

tests/models/autoencoders/test_models_autoencoder_kl_kvae_video.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,9 @@ def test_training(self):
9797
def test_training_with_ema(self):
9898
_run_nondeterministic(super().test_training_with_ema)
9999

100+
def test_mixed_precision_training(self):
101+
_run_nondeterministic(super().test_mixed_precision_training)
102+
100103
@pytest.mark.skip(
101104
"Gradient checkpointing recomputes the forward pass, but the model uses a stateful cache_dict "
102105
"that is mutated during the first forward. On recomputation the cache is already populated, "
@@ -105,13 +108,13 @@ def test_training_with_ema(self):
105108
def test_gradient_checkpointing_equivalence(self):
106109
super().test_gradient_checkpointing_equivalence()
107110

108-
def test_layerwise_casting_training(self):
109-
_run_nondeterministic(super().test_layerwise_casting_training)
110-
111111

112112
class TestAutoencoderKLKVAEVideoMemory(AutoencoderKLKVAEVideoTesterConfig, MemoryTesterMixin):
113113
"""Memory optimization tests for AutoencoderKLKVAEVideo."""
114114

115+
def test_layerwise_casting_training(self):
116+
_run_nondeterministic(super().test_layerwise_casting_training)
117+
115118

116119
class TestAutoencoderKLKVAEVideoSlicingTiling(AutoencoderKLKVAEVideoTesterConfig, NewAutoencoderTesterMixin):
117120
"""Slicing and tiling tests for AutoencoderKLKVAEVideo."""

0 commit comments

Comments
 (0)