-
Notifications
You must be signed in to change notification settings - Fork 6.9k
[CI] Hunyuan Transformer Tests Refactor #13342
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
+379
−346
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -13,51 +13,97 @@ | |
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import unittest | ||
|
|
||
| import torch | ||
|
|
||
| from diffusers import HunyuanDiT2DModel | ||
| from diffusers.utils.torch_utils import randn_tensor | ||
|
|
||
| from ...testing_utils import ( | ||
| enable_full_determinism, | ||
| torch_device, | ||
| from ...testing_utils import enable_full_determinism, torch_device | ||
| from ..testing_utils import ( | ||
| BaseModelTesterConfig, | ||
| ModelTesterMixin, | ||
| TrainingTesterMixin, | ||
|
Comment on lines
+22
to
+25
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think we need compilation, bitsandbytes, and torchao tests in this one -- the model isn't used that much anyway? |
||
| ) | ||
| from ..test_modeling_common import ModelTesterMixin | ||
|
|
||
|
|
||
| enable_full_determinism() | ||
|
|
||
|
|
||
| class HunyuanDiTTests(ModelTesterMixin, unittest.TestCase): | ||
| model_class = HunyuanDiT2DModel | ||
| main_input_name = "hidden_states" | ||
| class HunyuanDiTTesterConfig(BaseModelTesterConfig): | ||
| @property | ||
| def model_class(self): | ||
| return HunyuanDiT2DModel | ||
|
|
||
| @property | ||
| def pretrained_model_name_or_path(self): | ||
| return "hf-internal-testing/tiny-hunyuan-dit-pipe" | ||
|
|
||
| @property | ||
| def pretrained_model_kwargs(self): | ||
| return {"subfolder": "transformer"} | ||
|
|
||
| @property | ||
| def main_input_name(self) -> str: | ||
| return "hidden_states" | ||
|
|
||
| @property | ||
| def output_shape(self) -> tuple: | ||
| return (8, 8, 8) | ||
|
|
||
| @property | ||
| def input_shape(self) -> tuple: | ||
| return (4, 8, 8) | ||
|
|
||
| @property | ||
| def dummy_input(self): | ||
| batch_size = 2 | ||
| def generator(self): | ||
| return torch.Generator("cpu").manual_seed(0) | ||
|
|
||
| def get_init_dict(self) -> dict: | ||
| return { | ||
| "sample_size": 8, | ||
| "patch_size": 2, | ||
| "in_channels": 4, | ||
| "num_layers": 1, | ||
| "attention_head_dim": 8, | ||
| "num_attention_heads": 2, | ||
| "cross_attention_dim": 8, | ||
| "cross_attention_dim_t5": 8, | ||
| "pooled_projection_dim": 4, | ||
| "hidden_size": 16, | ||
| "text_len": 4, | ||
| "text_len_t5": 4, | ||
| "activation_fn": "gelu-approximate", | ||
| } | ||
|
|
||
| def get_dummy_inputs(self, batch_size: int = 2) -> dict[str, torch.Tensor]: | ||
| num_channels = 4 | ||
| height = width = 8 | ||
| embedding_dim = 8 | ||
| sequence_length = 4 | ||
| sequence_length_t5 = 4 | ||
|
|
||
| hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device) | ||
| encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) | ||
| hidden_states = randn_tensor( | ||
| (batch_size, num_channels, height, width), generator=self.generator, device=torch_device | ||
| ) | ||
| encoder_hidden_states = randn_tensor( | ||
| (batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device | ||
| ) | ||
| text_embedding_mask = torch.ones(size=(batch_size, sequence_length)).to(torch_device) | ||
| encoder_hidden_states_t5 = torch.randn((batch_size, sequence_length_t5, embedding_dim)).to(torch_device) | ||
| encoder_hidden_states_t5 = randn_tensor( | ||
| (batch_size, sequence_length_t5, embedding_dim), generator=self.generator, device=torch_device | ||
| ) | ||
| text_embedding_mask_t5 = torch.ones(size=(batch_size, sequence_length_t5)).to(torch_device) | ||
| timestep = torch.randint(0, 1000, size=(batch_size,), dtype=encoder_hidden_states.dtype).to(torch_device) | ||
| timestep = torch.randint(0, 1000, size=(batch_size,), generator=self.generator).float().to(torch_device) | ||
|
|
||
| original_size = [1024, 1024] | ||
| target_size = [16, 16] | ||
| crops_coords_top_left = [0, 0] | ||
| add_time_ids = list(original_size + target_size + crops_coords_top_left) | ||
| add_time_ids = torch.tensor([add_time_ids, add_time_ids], dtype=encoder_hidden_states.dtype).to(torch_device) | ||
| add_time_ids = torch.tensor([add_time_ids] * batch_size, dtype=torch.float32).to(torch_device) | ||
| style = torch.zeros(size=(batch_size,), dtype=int).to(torch_device) | ||
| image_rotary_emb = [ | ||
| torch.ones(size=(1, 8), dtype=encoder_hidden_states.dtype), | ||
| torch.zeros(size=(1, 8), dtype=encoder_hidden_states.dtype), | ||
| torch.ones(size=(1, 8), dtype=torch.float32), | ||
| torch.zeros(size=(1, 8), dtype=torch.float32), | ||
| ] | ||
|
|
||
| return { | ||
|
|
@@ -72,42 +118,14 @@ def dummy_input(self): | |
| "image_rotary_emb": image_rotary_emb, | ||
| } | ||
|
|
||
| @property | ||
| def input_shape(self): | ||
| return (4, 8, 8) | ||
|
|
||
| @property | ||
| def output_shape(self): | ||
| return (8, 8, 8) | ||
|
|
||
| def prepare_init_args_and_inputs_for_common(self): | ||
| init_dict = { | ||
| "sample_size": 8, | ||
| "patch_size": 2, | ||
| "in_channels": 4, | ||
| "num_layers": 1, | ||
| "attention_head_dim": 8, | ||
| "num_attention_heads": 2, | ||
| "cross_attention_dim": 8, | ||
| "cross_attention_dim_t5": 8, | ||
| "pooled_projection_dim": 4, | ||
| "hidden_size": 16, | ||
| "text_len": 4, | ||
| "text_len_t5": 4, | ||
| "activation_fn": "gelu-approximate", | ||
| } | ||
| inputs_dict = self.dummy_input | ||
| return init_dict, inputs_dict | ||
|
|
||
| class TestHunyuanDiT(HunyuanDiTTesterConfig, ModelTesterMixin): | ||
| def test_output(self): | ||
| super().test_output( | ||
| expected_output_shape=(self.dummy_input[self.main_input_name].shape[0],) + self.output_shape | ||
| ) | ||
| batch_size = self.get_dummy_inputs()[self.main_input_name].shape[0] | ||
| super().test_output(expected_output_shape=(batch_size,) + self.output_shape) | ||
|
|
||
| @unittest.skip("HunyuanDIT use a custom processor HunyuanAttnProcessor2_0") | ||
| def test_set_xformers_attn_processor_for_determinism(self): | ||
| pass | ||
|
|
||
| @unittest.skip("HunyuanDIT use a custom processor HunyuanAttnProcessor2_0") | ||
| def test_set_attn_processor_for_determinism(self): | ||
| pass | ||
| class TestHunyuanDiTTraining(HunyuanDiTTesterConfig, TrainingTesterMixin): | ||
| def test_gradient_checkpointing_is_applied(self): | ||
| expected_set = {"HunyuanDiT2DModel"} | ||
| super().test_gradient_checkpointing_is_applied(expected_set=expected_set) | ||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.