|
20 | 20 | from diffusers import AutoencoderKLHunyuanVideo |
21 | 21 | from diffusers.models.autoencoders.autoencoder_kl_hunyuan_video import prepare_causal_attention_mask |
22 | 22 |
|
23 | | -from ...testing_utils import ( |
24 | | - enable_full_determinism, |
25 | | - floats_tensor, |
26 | | - torch_device, |
27 | | -) |
28 | | -from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin |
| 23 | +from ...testing_utils import enable_full_determinism, floats_tensor, torch_device |
| 24 | +from ..test_modeling_common import ModelTesterMixin |
| 25 | +from .testing_utils import AutoencoderTesterMixin |
29 | 26 |
|
30 | 27 |
|
31 | 28 | enable_full_determinism() |
32 | 29 |
|
33 | 30 |
|
34 | | -class AutoencoderKLHunyuanVideoTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): |
| 31 | +class AutoencoderKLHunyuanVideoTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase): |
35 | 32 | model_class = AutoencoderKLHunyuanVideo |
36 | 33 | main_input_name = "sample" |
37 | 34 | base_precision = 1e-2 |
@@ -87,68 +84,6 @@ def prepare_init_args_and_inputs_for_common(self): |
87 | 84 | inputs_dict = self.dummy_input |
88 | 85 | return init_dict, inputs_dict |
89 | 86 |
|
90 | | - def test_enable_disable_tiling(self): |
91 | | - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() |
92 | | - |
93 | | - torch.manual_seed(0) |
94 | | - model = self.model_class(**init_dict).to(torch_device) |
95 | | - |
96 | | - inputs_dict.update({"return_dict": False}) |
97 | | - |
98 | | - torch.manual_seed(0) |
99 | | - output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0] |
100 | | - |
101 | | - torch.manual_seed(0) |
102 | | - model.enable_tiling() |
103 | | - output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0] |
104 | | - |
105 | | - self.assertLess( |
106 | | - (output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(), |
107 | | - 0.5, |
108 | | - "VAE tiling should not affect the inference results", |
109 | | - ) |
110 | | - |
111 | | - torch.manual_seed(0) |
112 | | - model.disable_tiling() |
113 | | - output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0] |
114 | | - |
115 | | - self.assertEqual( |
116 | | - output_without_tiling.detach().cpu().numpy().all(), |
117 | | - output_without_tiling_2.detach().cpu().numpy().all(), |
118 | | - "Without tiling outputs should match with the outputs when tiling is manually disabled.", |
119 | | - ) |
120 | | - |
121 | | - def test_enable_disable_slicing(self): |
122 | | - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() |
123 | | - |
124 | | - torch.manual_seed(0) |
125 | | - model = self.model_class(**init_dict).to(torch_device) |
126 | | - |
127 | | - inputs_dict.update({"return_dict": False}) |
128 | | - |
129 | | - torch.manual_seed(0) |
130 | | - output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0] |
131 | | - |
132 | | - torch.manual_seed(0) |
133 | | - model.enable_slicing() |
134 | | - output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0] |
135 | | - |
136 | | - self.assertLess( |
137 | | - (output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(), |
138 | | - 0.5, |
139 | | - "VAE slicing should not affect the inference results", |
140 | | - ) |
141 | | - |
142 | | - torch.manual_seed(0) |
143 | | - model.disable_slicing() |
144 | | - output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0] |
145 | | - |
146 | | - self.assertEqual( |
147 | | - output_without_slicing.detach().cpu().numpy().all(), |
148 | | - output_without_slicing_2.detach().cpu().numpy().all(), |
149 | | - "Without slicing outputs should match with the outputs when slicing is manually disabled.", |
150 | | - ) |
151 | | - |
152 | 87 | def test_gradient_checkpointing_is_applied(self): |
153 | 88 | expected_set = { |
154 | 89 | "HunyuanVideoDecoder3D", |
|
0 commit comments