diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py index aab593c93643..1db643a60f81 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py @@ -888,6 +888,8 @@ class HunyuanVideoTransformer3DModel( _no_split_modules = [ "HunyuanVideoTransformerBlock", "HunyuanVideoSingleTransformerBlock", + "HunyuanVideoTokenReplaceTransformerBlock", + "HunyuanVideoTokenReplaceSingleTransformerBlock", "HunyuanVideoPatchEmbed", "HunyuanVideoTokenRefiner", ] diff --git a/tests/models/transformers/test_models_transformer_hunyuan_1_5.py b/tests/models/transformers/test_models_transformer_hunyuan_1_5.py index 57080bc5b0b4..02eec91a1db5 100644 --- a/tests/models/transformers/test_models_transformer_hunyuan_1_5.py +++ b/tests/models/transformers/test_models_transformer_hunyuan_1_5.py @@ -12,71 +12,53 @@ # See the License for the specific language governing permissions and # limitations under the License. -import unittest - import torch from diffusers import HunyuanVideo15Transformer3DModel +from diffusers.utils.torch_utils import randn_tensor from ...testing_utils import enable_full_determinism, torch_device -from ..test_modeling_common import ModelTesterMixin +from ..testing_utils import ( + BaseModelTesterConfig, + ModelTesterMixin, + TrainingTesterMixin, +) enable_full_determinism() -class HunyuanVideo15Transformer3DTests(ModelTesterMixin, unittest.TestCase): - model_class = HunyuanVideo15Transformer3DModel - main_input_name = "hidden_states" - uses_custom_attn_processor = True - model_split_percents = [0.99, 0.99, 0.99] - +class HunyuanVideo15TransformerTesterConfig(BaseModelTesterConfig): text_embed_dim = 16 text_embed_2_dim = 8 image_embed_dim = 12 @property - def dummy_input(self): - batch_size = 1 - num_channels = 4 - num_frames = 1 - height = 8 - width = 8 - sequence_length = 6 - sequence_length_2 = 4 - image_sequence_length = 3 + def model_class(self): + return HunyuanVideo15Transformer3DModel - hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device) - timestep = torch.tensor([1.0]).to(torch_device) - encoder_hidden_states = torch.randn((batch_size, sequence_length, self.text_embed_dim), device=torch_device) - encoder_hidden_states_2 = torch.randn( - (batch_size, sequence_length_2, self.text_embed_2_dim), device=torch_device - ) - encoder_attention_mask = torch.ones((batch_size, sequence_length), device=torch_device) - encoder_attention_mask_2 = torch.ones((batch_size, sequence_length_2), device=torch_device) - # All zeros for inducing T2V path in the model. - image_embeds = torch.zeros((batch_size, image_sequence_length, self.image_embed_dim), device=torch_device) + @property + def main_input_name(self) -> str: + return "hidden_states" - return { - "hidden_states": hidden_states, - "timestep": timestep, - "encoder_hidden_states": encoder_hidden_states, - "encoder_attention_mask": encoder_attention_mask, - "encoder_hidden_states_2": encoder_hidden_states_2, - "encoder_attention_mask_2": encoder_attention_mask_2, - "image_embeds": image_embeds, - } + @property + def model_split_percents(self) -> list: + return [0.99, 0.99, 0.99] @property - def input_shape(self): + def output_shape(self) -> tuple: return (4, 1, 8, 8) @property - def output_shape(self): + def input_shape(self) -> tuple: return (4, 1, 8, 8) - def prepare_init_args_and_inputs_for_common(self): - init_dict = { + @property + def generator(self): + return torch.Generator("cpu").manual_seed(0) + + def get_init_dict(self) -> dict: + return { "in_channels": 4, "out_channels": 4, "num_attention_heads": 2, @@ -93,9 +75,40 @@ def prepare_init_args_and_inputs_for_common(self): "target_size": 16, "task_type": "t2v", } - inputs_dict = self.dummy_input - return init_dict, inputs_dict + def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]: + num_channels = 4 + num_frames = 1 + height = 8 + width = 8 + sequence_length = 6 + sequence_length_2 = 4 + image_sequence_length = 3 + + return { + "hidden_states": randn_tensor( + (batch_size, num_channels, num_frames, height, width), generator=self.generator, device=torch_device + ), + "timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size), + "encoder_hidden_states": randn_tensor( + (batch_size, sequence_length, self.text_embed_dim), generator=self.generator, device=torch_device + ), + "encoder_hidden_states_2": randn_tensor( + (batch_size, sequence_length_2, self.text_embed_2_dim), generator=self.generator, device=torch_device + ), + "encoder_attention_mask": torch.ones((batch_size, sequence_length), device=torch_device), + "encoder_attention_mask_2": torch.ones((batch_size, sequence_length_2), device=torch_device), + "image_embeds": torch.zeros( + (batch_size, image_sequence_length, self.image_embed_dim), device=torch_device + ), + } + + +class TestHunyuanVideo15Transformer(HunyuanVideo15TransformerTesterConfig, ModelTesterMixin): + pass + + +class TestHunyuanVideo15TransformerTraining(HunyuanVideo15TransformerTesterConfig, TrainingTesterMixin): def test_gradient_checkpointing_is_applied(self): expected_set = {"HunyuanVideo15Transformer3DModel"} super().test_gradient_checkpointing_is_applied(expected_set=expected_set) diff --git a/tests/models/transformers/test_models_transformer_hunyuan_dit.py b/tests/models/transformers/test_models_transformer_hunyuan_dit.py index d82a62d58ec3..1c08244b620c 100644 --- a/tests/models/transformers/test_models_transformer_hunyuan_dit.py +++ b/tests/models/transformers/test_models_transformer_hunyuan_dit.py @@ -13,51 +13,97 @@ # See the License for the specific language governing permissions and # limitations under the License. -import unittest - import torch from diffusers import HunyuanDiT2DModel +from diffusers.utils.torch_utils import randn_tensor -from ...testing_utils import ( - enable_full_determinism, - torch_device, +from ...testing_utils import enable_full_determinism, torch_device +from ..testing_utils import ( + BaseModelTesterConfig, + ModelTesterMixin, + TrainingTesterMixin, ) -from ..test_modeling_common import ModelTesterMixin enable_full_determinism() -class HunyuanDiTTests(ModelTesterMixin, unittest.TestCase): - model_class = HunyuanDiT2DModel - main_input_name = "hidden_states" +class HunyuanDiTTesterConfig(BaseModelTesterConfig): + @property + def model_class(self): + return HunyuanDiT2DModel + + @property + def pretrained_model_name_or_path(self): + return "hf-internal-testing/tiny-hunyuan-dit-pipe" + + @property + def pretrained_model_kwargs(self): + return {"subfolder": "transformer"} + + @property + def main_input_name(self) -> str: + return "hidden_states" + + @property + def output_shape(self) -> tuple: + return (8, 8, 8) + + @property + def input_shape(self) -> tuple: + return (4, 8, 8) @property - def dummy_input(self): - batch_size = 2 + def generator(self): + return torch.Generator("cpu").manual_seed(0) + + def get_init_dict(self) -> dict: + return { + "sample_size": 8, + "patch_size": 2, + "in_channels": 4, + "num_layers": 1, + "attention_head_dim": 8, + "num_attention_heads": 2, + "cross_attention_dim": 8, + "cross_attention_dim_t5": 8, + "pooled_projection_dim": 4, + "hidden_size": 16, + "text_len": 4, + "text_len_t5": 4, + "activation_fn": "gelu-approximate", + } + + def get_dummy_inputs(self, batch_size: int = 2) -> dict[str, torch.Tensor]: num_channels = 4 height = width = 8 embedding_dim = 8 sequence_length = 4 sequence_length_t5 = 4 - hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device) - encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) + hidden_states = randn_tensor( + (batch_size, num_channels, height, width), generator=self.generator, device=torch_device + ) + encoder_hidden_states = randn_tensor( + (batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device + ) text_embedding_mask = torch.ones(size=(batch_size, sequence_length)).to(torch_device) - encoder_hidden_states_t5 = torch.randn((batch_size, sequence_length_t5, embedding_dim)).to(torch_device) + encoder_hidden_states_t5 = randn_tensor( + (batch_size, sequence_length_t5, embedding_dim), generator=self.generator, device=torch_device + ) text_embedding_mask_t5 = torch.ones(size=(batch_size, sequence_length_t5)).to(torch_device) - timestep = torch.randint(0, 1000, size=(batch_size,), dtype=encoder_hidden_states.dtype).to(torch_device) + timestep = torch.randint(0, 1000, size=(batch_size,), generator=self.generator).float().to(torch_device) original_size = [1024, 1024] target_size = [16, 16] crops_coords_top_left = [0, 0] add_time_ids = list(original_size + target_size + crops_coords_top_left) - add_time_ids = torch.tensor([add_time_ids, add_time_ids], dtype=encoder_hidden_states.dtype).to(torch_device) + add_time_ids = torch.tensor([add_time_ids] * batch_size, dtype=torch.float32).to(torch_device) style = torch.zeros(size=(batch_size,), dtype=int).to(torch_device) image_rotary_emb = [ - torch.ones(size=(1, 8), dtype=encoder_hidden_states.dtype), - torch.zeros(size=(1, 8), dtype=encoder_hidden_states.dtype), + torch.ones(size=(1, 8), dtype=torch.float32), + torch.zeros(size=(1, 8), dtype=torch.float32), ] return { @@ -72,42 +118,14 @@ def dummy_input(self): "image_rotary_emb": image_rotary_emb, } - @property - def input_shape(self): - return (4, 8, 8) - - @property - def output_shape(self): - return (8, 8, 8) - - def prepare_init_args_and_inputs_for_common(self): - init_dict = { - "sample_size": 8, - "patch_size": 2, - "in_channels": 4, - "num_layers": 1, - "attention_head_dim": 8, - "num_attention_heads": 2, - "cross_attention_dim": 8, - "cross_attention_dim_t5": 8, - "pooled_projection_dim": 4, - "hidden_size": 16, - "text_len": 4, - "text_len_t5": 4, - "activation_fn": "gelu-approximate", - } - inputs_dict = self.dummy_input - return init_dict, inputs_dict +class TestHunyuanDiT(HunyuanDiTTesterConfig, ModelTesterMixin): def test_output(self): - super().test_output( - expected_output_shape=(self.dummy_input[self.main_input_name].shape[0],) + self.output_shape - ) + batch_size = self.get_dummy_inputs()[self.main_input_name].shape[0] + super().test_output(expected_output_shape=(batch_size,) + self.output_shape) - @unittest.skip("HunyuanDIT use a custom processor HunyuanAttnProcessor2_0") - def test_set_xformers_attn_processor_for_determinism(self): - pass - @unittest.skip("HunyuanDIT use a custom processor HunyuanAttnProcessor2_0") - def test_set_attn_processor_for_determinism(self): - pass +class TestHunyuanDiTTraining(HunyuanDiTTesterConfig, TrainingTesterMixin): + def test_gradient_checkpointing_is_applied(self): + expected_set = {"HunyuanDiT2DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) diff --git a/tests/models/transformers/test_models_transformer_hunyuan_video.py b/tests/models/transformers/test_models_transformer_hunyuan_video.py index 385a5eefd58b..90c716a336a5 100644 --- a/tests/models/transformers/test_models_transformer_hunyuan_video.py +++ b/tests/models/transformers/test_models_transformer_hunyuan_video.py @@ -12,64 +12,59 @@ # See the License for the specific language governing permissions and # limitations under the License. -import unittest - import torch from diffusers import HunyuanVideoTransformer3DModel - -from ...testing_utils import ( - enable_full_determinism, - torch_device, +from diffusers.utils.torch_utils import randn_tensor + +from ...testing_utils import enable_full_determinism, torch_device +from ..testing_utils import ( + BaseModelTesterConfig, + BitsAndBytesTesterMixin, + ModelTesterMixin, + TorchAoTesterMixin, + TorchCompileTesterMixin, + TrainingTesterMixin, ) -from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin enable_full_determinism() -class HunyuanVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase): - model_class = HunyuanVideoTransformer3DModel - main_input_name = "hidden_states" - uses_custom_attn_processor = True +# ======================== HunyuanVideo Text-to-Video ======================== + +class HunyuanVideoTransformerTesterConfig(BaseModelTesterConfig): @property - def dummy_input(self): - batch_size = 1 - num_channels = 4 - num_frames = 1 - height = 16 - width = 16 - text_encoder_embedding_dim = 16 - pooled_projection_dim = 8 - sequence_length = 12 + def model_class(self): + return HunyuanVideoTransformer3DModel - hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device) - timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device) - encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device) - pooled_projections = torch.randn((batch_size, pooled_projection_dim)).to(torch_device) - encoder_attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device) - guidance = torch.randint(0, 1000, size=(batch_size,)).to(torch_device, dtype=torch.float32) + @property + def pretrained_model_name_or_path(self): + return "hf-internal-testing/tiny-random-hunyuanvideo" - return { - "hidden_states": hidden_states, - "timestep": timestep, - "encoder_hidden_states": encoder_hidden_states, - "pooled_projections": pooled_projections, - "encoder_attention_mask": encoder_attention_mask, - "guidance": guidance, - } + @property + def pretrained_model_kwargs(self): + return {"subfolder": "transformer"} + + @property + def main_input_name(self) -> str: + return "hidden_states" @property - def input_shape(self): + def output_shape(self) -> tuple: return (4, 1, 16, 16) @property - def output_shape(self): + def input_shape(self) -> tuple: return (4, 1, 16, 16) - def prepare_init_args_and_inputs_for_common(self): - init_dict = { + @property + def generator(self): + return torch.Generator("cpu").manual_seed(0) + + def get_init_dict(self) -> dict: + return { "in_channels": 4, "out_channels": 4, "num_attention_heads": 2, @@ -85,136 +80,106 @@ def prepare_init_args_and_inputs_for_common(self): "rope_axes_dim": (2, 4, 4), "image_condition_type": None, } - inputs_dict = self.dummy_input - return init_dict, inputs_dict - - def test_gradient_checkpointing_is_applied(self): - expected_set = {"HunyuanVideoTransformer3DModel"} - super().test_gradient_checkpointing_is_applied(expected_set=expected_set) - - -class HunyuanTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase): - model_class = HunyuanVideoTransformer3DModel - - def prepare_init_args_and_inputs_for_common(self): - return HunyuanVideoTransformer3DTests().prepare_init_args_and_inputs_for_common() - - -class HunyuanSkyreelsImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase): - model_class = HunyuanVideoTransformer3DModel - main_input_name = "hidden_states" - uses_custom_attn_processor = True @property - def dummy_input(self): - batch_size = 1 - num_channels = 8 + def torch_dtype(self): + return None + + def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]: + num_channels = 4 num_frames = 1 height = 16 width = 16 text_encoder_embedding_dim = 16 pooled_projection_dim = 8 sequence_length = 12 - - hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device) - timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device) - encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device) - pooled_projections = torch.randn((batch_size, pooled_projection_dim)).to(torch_device) - encoder_attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device) - guidance = torch.randint(0, 1000, size=(batch_size,)).to(torch_device, dtype=torch.float32) + dtype = self.torch_dtype return { - "hidden_states": hidden_states, - "timestep": timestep, - "encoder_hidden_states": encoder_hidden_states, - "pooled_projections": pooled_projections, - "encoder_attention_mask": encoder_attention_mask, - "guidance": guidance, + "hidden_states": randn_tensor( + (batch_size, num_channels, num_frames, height, width), + generator=self.generator, + device=torch_device, + dtype=dtype, + ), + "timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to( + torch_device, dtype=dtype or torch.float32 + ), + "encoder_hidden_states": randn_tensor( + (batch_size, sequence_length, text_encoder_embedding_dim), + generator=self.generator, + device=torch_device, + dtype=dtype, + ), + "pooled_projections": randn_tensor( + (batch_size, pooled_projection_dim), + generator=self.generator, + device=torch_device, + dtype=dtype, + ), + "encoder_attention_mask": torch.ones((batch_size, sequence_length)).to(torch_device), + "guidance": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to( + torch_device, dtype=dtype or torch.float32 + ), } - @property - def input_shape(self): - return (8, 1, 16, 16) - - @property - def output_shape(self): - return (4, 1, 16, 16) - def prepare_init_args_and_inputs_for_common(self): - init_dict = { - "in_channels": 8, - "out_channels": 4, - "num_attention_heads": 2, - "attention_head_dim": 10, - "num_layers": 1, - "num_single_layers": 1, - "num_refiner_layers": 1, - "patch_size": 1, - "patch_size_t": 1, - "guidance_embeds": True, - "text_embed_dim": 16, - "pooled_projection_dim": 8, - "rope_axes_dim": (2, 4, 4), - "image_condition_type": None, - } - inputs_dict = self.dummy_input - return init_dict, inputs_dict +class TestHunyuanVideoTransformer(HunyuanVideoTransformerTesterConfig, ModelTesterMixin): + pass - def test_output(self): - super().test_output(expected_output_shape=(1, *self.output_shape)) +class TestHunyuanVideoTransformerTraining(HunyuanVideoTransformerTesterConfig, TrainingTesterMixin): def test_gradient_checkpointing_is_applied(self): expected_set = {"HunyuanVideoTransformer3DModel"} super().test_gradient_checkpointing_is_applied(expected_set=expected_set) -class HunyuanSkyreelsImageToVideoCompileTests(TorchCompileTesterMixin, unittest.TestCase): - model_class = HunyuanVideoTransformer3DModel +class TestHunyuanVideoTransformerCompile(HunyuanVideoTransformerTesterConfig, TorchCompileTesterMixin): + pass - def prepare_init_args_and_inputs_for_common(self): - return HunyuanSkyreelsImageToVideoTransformer3DTests().prepare_init_args_and_inputs_for_common() +class TestHunyuanVideoTransformerBitsAndBytes(HunyuanVideoTransformerTesterConfig, BitsAndBytesTesterMixin): + """BitsAndBytes quantization tests for HunyuanVideo Transformer.""" + + @property + def torch_dtype(self): + return torch.float16 -class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase): - model_class = HunyuanVideoTransformer3DModel - main_input_name = "hidden_states" - uses_custom_attn_processor = True + +class TestHunyuanVideoTransformerTorchAo(HunyuanVideoTransformerTesterConfig, TorchAoTesterMixin): + """TorchAO quantization tests for HunyuanVideo Transformer.""" @property - def dummy_input(self): - batch_size = 1 - num_channels = 2 * 4 + 1 - num_frames = 1 - height = 16 - width = 16 - text_encoder_embedding_dim = 16 - pooled_projection_dim = 8 - sequence_length = 12 + def torch_dtype(self): + return torch.bfloat16 - hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device) - timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device) - encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device) - pooled_projections = torch.randn((batch_size, pooled_projection_dim)).to(torch_device) - encoder_attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device) - return { - "hidden_states": hidden_states, - "timestep": timestep, - "encoder_hidden_states": encoder_hidden_states, - "pooled_projections": pooled_projections, - "encoder_attention_mask": encoder_attention_mask, - } +# ======================== HunyuanVideo Image-to-Video (Latent Concat) ======================== + +class HunyuanVideoI2VTransformerTesterConfig(BaseModelTesterConfig): @property - def input_shape(self): - return (8, 1, 16, 16) + def model_class(self): + return HunyuanVideoTransformer3DModel + + @property + def main_input_name(self) -> str: + return "hidden_states" @property - def output_shape(self): + def output_shape(self) -> tuple: return (4, 1, 16, 16) - def prepare_init_args_and_inputs_for_common(self): - init_dict = { + @property + def input_shape(self) -> tuple: + return (8, 1, 16, 16) + + @property + def generator(self): + return torch.Generator("cpu").manual_seed(0) + + def get_init_dict(self) -> dict: + return { "in_channels": 2 * 4 + 1, "out_channels": 4, "num_attention_heads": 2, @@ -230,66 +195,64 @@ def prepare_init_args_and_inputs_for_common(self): "rope_axes_dim": (2, 4, 4), "image_condition_type": "latent_concat", } - inputs_dict = self.dummy_input - return init_dict, inputs_dict - def test_output(self): - super().test_output(expected_output_shape=(1, *self.output_shape)) + def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]: + num_channels = 2 * 4 + 1 + num_frames = 1 + height = 16 + width = 16 + text_encoder_embedding_dim = 16 + pooled_projection_dim = 8 + sequence_length = 12 - def test_gradient_checkpointing_is_applied(self): - expected_set = {"HunyuanVideoTransformer3DModel"} - super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + return { + "hidden_states": randn_tensor( + (batch_size, num_channels, num_frames, height, width), generator=self.generator, device=torch_device + ), + "timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device), + "encoder_hidden_states": randn_tensor( + (batch_size, sequence_length, text_encoder_embedding_dim), + generator=self.generator, + device=torch_device, + ), + "pooled_projections": randn_tensor( + (batch_size, pooled_projection_dim), generator=self.generator, device=torch_device + ), + "encoder_attention_mask": torch.ones((batch_size, sequence_length)).to(torch_device), + } -class HunyuanImageToVideoCompileTests(TorchCompileTesterMixin, unittest.TestCase): - model_class = HunyuanVideoTransformer3DModel +class TestHunyuanVideoI2VTransformer(HunyuanVideoI2VTransformerTesterConfig, ModelTesterMixin): + def test_output(self): + super().test_output(expected_output_shape=(1, *self.output_shape)) - def prepare_init_args_and_inputs_for_common(self): - return HunyuanVideoImageToVideoTransformer3DTests().prepare_init_args_and_inputs_for_common() +# ======================== HunyuanVideo Token Replace Image-to-Video ======================== -class HunyuanVideoTokenReplaceImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase): - model_class = HunyuanVideoTransformer3DModel - main_input_name = "hidden_states" - uses_custom_attn_processor = True +class HunyuanVideoTokenReplaceTransformerTesterConfig(BaseModelTesterConfig): @property - def dummy_input(self): - batch_size = 1 - num_channels = 2 - num_frames = 1 - height = 16 - width = 16 - text_encoder_embedding_dim = 16 - pooled_projection_dim = 8 - sequence_length = 12 + def model_class(self): + return HunyuanVideoTransformer3DModel - hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device) - timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device) - encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device) - pooled_projections = torch.randn((batch_size, pooled_projection_dim)).to(torch_device) - encoder_attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device) - guidance = torch.randint(0, 1000, size=(batch_size,)).to(torch_device, dtype=torch.float32) + @property + def main_input_name(self) -> str: + return "hidden_states" - return { - "hidden_states": hidden_states, - "timestep": timestep, - "encoder_hidden_states": encoder_hidden_states, - "pooled_projections": pooled_projections, - "encoder_attention_mask": encoder_attention_mask, - "guidance": guidance, - } + @property + def output_shape(self) -> tuple: + return (4, 1, 16, 16) @property - def input_shape(self): + def input_shape(self) -> tuple: return (8, 1, 16, 16) @property - def output_shape(self): - return (4, 1, 16, 16) + def generator(self): + return torch.Generator("cpu").manual_seed(0) - def prepare_init_args_and_inputs_for_common(self): - init_dict = { + def get_init_dict(self) -> dict: + return { "in_channels": 2, "out_channels": 4, "num_attention_heads": 2, @@ -305,19 +268,36 @@ def prepare_init_args_and_inputs_for_common(self): "rope_axes_dim": (2, 4, 4), "image_condition_type": "token_replace", } - inputs_dict = self.dummy_input - return init_dict, inputs_dict - - def test_output(self): - super().test_output(expected_output_shape=(1, *self.output_shape)) - def test_gradient_checkpointing_is_applied(self): - expected_set = {"HunyuanVideoTransformer3DModel"} - super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]: + num_channels = 2 + num_frames = 1 + height = 16 + width = 16 + text_encoder_embedding_dim = 16 + pooled_projection_dim = 8 + sequence_length = 12 + return { + "hidden_states": randn_tensor( + (batch_size, num_channels, num_frames, height, width), generator=self.generator, device=torch_device + ), + "timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device), + "encoder_hidden_states": randn_tensor( + (batch_size, sequence_length, text_encoder_embedding_dim), + generator=self.generator, + device=torch_device, + ), + "pooled_projections": randn_tensor( + (batch_size, pooled_projection_dim), generator=self.generator, device=torch_device + ), + "encoder_attention_mask": torch.ones((batch_size, sequence_length)).to(torch_device), + "guidance": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to( + torch_device, dtype=torch.float32 + ), + } -class HunyuanVideoTokenReplaceCompileTests(TorchCompileTesterMixin, unittest.TestCase): - model_class = HunyuanVideoTransformer3DModel - def prepare_init_args_and_inputs_for_common(self): - return HunyuanVideoTokenReplaceImageToVideoTransformer3DTests().prepare_init_args_and_inputs_for_common() +class TestHunyuanVideoTokenReplaceTransformer(HunyuanVideoTokenReplaceTransformerTesterConfig, ModelTesterMixin): + def test_output(self): + super().test_output(expected_output_shape=(1, *self.output_shape)) diff --git a/tests/models/transformers/test_models_transformer_hunyuan_video_framepack.py b/tests/models/transformers/test_models_transformer_hunyuan_video_framepack.py index 00a2b27e02b6..272b7145326d 100644 --- a/tests/models/transformers/test_models_transformer_hunyuan_video_framepack.py +++ b/tests/models/transformers/test_models_transformer_hunyuan_video_framepack.py @@ -12,84 +12,49 @@ # See the License for the specific language governing permissions and # limitations under the License. -import unittest - import torch from diffusers import HunyuanVideoFramepackTransformer3DModel +from diffusers.utils.torch_utils import randn_tensor -from ...testing_utils import ( - enable_full_determinism, - torch_device, +from ...testing_utils import enable_full_determinism, torch_device +from ..testing_utils import ( + BaseModelTesterConfig, + ModelTesterMixin, + TrainingTesterMixin, ) -from ..test_modeling_common import ModelTesterMixin enable_full_determinism() -class HunyuanVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase): - model_class = HunyuanVideoFramepackTransformer3DModel - main_input_name = "hidden_states" - uses_custom_attn_processor = True - model_split_percents = [0.5, 0.7, 0.9] - +class HunyuanVideoFramepackTransformerTesterConfig(BaseModelTesterConfig): @property - def dummy_input(self): - batch_size = 1 - num_channels = 4 - num_frames = 3 - height = 4 - width = 4 - text_encoder_embedding_dim = 16 - image_encoder_embedding_dim = 16 - pooled_projection_dim = 8 - sequence_length = 12 + def model_class(self): + return HunyuanVideoFramepackTransformer3DModel - hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device) - encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device) - pooled_projections = torch.randn((batch_size, pooled_projection_dim)).to(torch_device) - encoder_attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device) - image_embeds = torch.randn((batch_size, sequence_length, image_encoder_embedding_dim)).to(torch_device) - indices_latents = torch.ones((3,)).to(torch_device) - latents_clean = torch.randn((batch_size, num_channels, num_frames - 1, height, width)).to(torch_device) - indices_latents_clean = torch.ones((num_frames - 1,)).to(torch_device) - latents_history_2x = torch.randn((batch_size, num_channels, num_frames - 1, height, width)).to(torch_device) - indices_latents_history_2x = torch.ones((num_frames - 1,)).to(torch_device) - latents_history_4x = torch.randn((batch_size, num_channels, (num_frames - 1) * 4, height, width)).to( - torch_device - ) - indices_latents_history_4x = torch.ones(((num_frames - 1) * 4,)).to(torch_device) - timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device) - guidance = torch.randint(0, 1000, size=(batch_size,)).to(torch_device) + @property + def main_input_name(self) -> str: + return "hidden_states" - return { - "hidden_states": hidden_states, - "timestep": timestep, - "encoder_hidden_states": encoder_hidden_states, - "pooled_projections": pooled_projections, - "encoder_attention_mask": encoder_attention_mask, - "guidance": guidance, - "image_embeds": image_embeds, - "indices_latents": indices_latents, - "latents_clean": latents_clean, - "indices_latents_clean": indices_latents_clean, - "latents_history_2x": latents_history_2x, - "indices_latents_history_2x": indices_latents_history_2x, - "latents_history_4x": latents_history_4x, - "indices_latents_history_4x": indices_latents_history_4x, - } + @property + def model_split_percents(self) -> list: + return [0.5, 0.7, 0.9] @property - def input_shape(self): + def output_shape(self) -> tuple: return (4, 3, 4, 4) @property - def output_shape(self): + def input_shape(self) -> tuple: return (4, 3, 4, 4) - def prepare_init_args_and_inputs_for_common(self): - init_dict = { + @property + def generator(self): + return torch.Generator("cpu").manual_seed(0) + + def get_init_dict(self) -> dict: + return { "in_channels": 4, "out_channels": 4, "num_attention_heads": 2, @@ -108,9 +73,64 @@ def prepare_init_args_and_inputs_for_common(self): "image_proj_dim": 16, "has_clean_x_embedder": True, } - inputs_dict = self.dummy_input - return init_dict, inputs_dict + def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]: + num_channels = 4 + num_frames = 3 + height = 4 + width = 4 + text_encoder_embedding_dim = 16 + image_encoder_embedding_dim = 16 + pooled_projection_dim = 8 + sequence_length = 12 + + return { + "hidden_states": randn_tensor( + (batch_size, num_channels, num_frames, height, width), generator=self.generator, device=torch_device + ), + "timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device), + "encoder_hidden_states": randn_tensor( + (batch_size, sequence_length, text_encoder_embedding_dim), + generator=self.generator, + device=torch_device, + ), + "pooled_projections": randn_tensor( + (batch_size, pooled_projection_dim), generator=self.generator, device=torch_device + ), + "encoder_attention_mask": torch.ones((batch_size, sequence_length)).to(torch_device), + "guidance": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device), + "image_embeds": randn_tensor( + (batch_size, sequence_length, image_encoder_embedding_dim), + generator=self.generator, + device=torch_device, + ), + "indices_latents": torch.ones((num_frames,)).to(torch_device), + "latents_clean": randn_tensor( + (batch_size, num_channels, num_frames - 1, height, width), + generator=self.generator, + device=torch_device, + ), + "indices_latents_clean": torch.ones((num_frames - 1,)).to(torch_device), + "latents_history_2x": randn_tensor( + (batch_size, num_channels, num_frames - 1, height, width), + generator=self.generator, + device=torch_device, + ), + "indices_latents_history_2x": torch.ones((num_frames - 1,)).to(torch_device), + "latents_history_4x": randn_tensor( + (batch_size, num_channels, (num_frames - 1) * 4, height, width), + generator=self.generator, + device=torch_device, + ), + "indices_latents_history_4x": torch.ones(((num_frames - 1) * 4,)).to(torch_device), + } + + +class TestHunyuanVideoFramepackTransformer(HunyuanVideoFramepackTransformerTesterConfig, ModelTesterMixin): + pass + + +class TestHunyuanVideoFramepackTransformerTraining(HunyuanVideoFramepackTransformerTesterConfig, TrainingTesterMixin): def test_gradient_checkpointing_is_applied(self): expected_set = {"HunyuanVideoFramepackTransformer3DModel"} super().test_gradient_checkpointing_is_applied(expected_set=expected_set)