Skip to content

Commit 3b80bcf

Browse files
DN6sayakpaul
andauthored
[CI] Refactor LTX Transformer Tests (#13254)
* update * update --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent c33abfc commit 3b80bcf

2 files changed

Lines changed: 171 additions & 214 deletions

File tree

Lines changed: 71 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# coding=utf-8
21
# Copyright 2025 HuggingFace Inc.
32
#
43
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,26 +12,58 @@
1312
# See the License for the specific language governing permissions and
1413
# limitations under the License.
1514

16-
import unittest
17-
1815
import torch
1916

2017
from diffusers import LTXVideoTransformer3DModel
18+
from diffusers.utils.torch_utils import randn_tensor
2119

2220
from ...testing_utils import enable_full_determinism, torch_device
23-
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
21+
from ..testing_utils import (
22+
BaseModelTesterConfig,
23+
MemoryTesterMixin,
24+
ModelTesterMixin,
25+
TorchCompileTesterMixin,
26+
TrainingTesterMixin,
27+
)
2428

2529

2630
enable_full_determinism()
2731

2832

29-
class LTXTransformerTests(ModelTesterMixin, unittest.TestCase):
30-
model_class = LTXVideoTransformer3DModel
31-
main_input_name = "hidden_states"
32-
uses_custom_attn_processor = True
33+
class LTXTransformerTesterConfig(BaseModelTesterConfig):
34+
@property
35+
def model_class(self):
36+
return LTXVideoTransformer3DModel
37+
38+
@property
39+
def output_shape(self) -> tuple[int, int]:
40+
return (512, 4)
3341

3442
@property
35-
def dummy_input(self):
43+
def input_shape(self) -> tuple[int, int]:
44+
return (512, 4)
45+
46+
@property
47+
def main_input_name(self) -> str:
48+
return "hidden_states"
49+
50+
@property
51+
def generator(self):
52+
return torch.Generator("cpu").manual_seed(0)
53+
54+
def get_init_dict(self):
55+
return {
56+
"in_channels": 4,
57+
"out_channels": 4,
58+
"num_attention_heads": 2,
59+
"attention_head_dim": 8,
60+
"cross_attention_dim": 16,
61+
"num_layers": 1,
62+
"qk_norm": "rms_norm_across_heads",
63+
"caption_channels": 16,
64+
}
65+
66+
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
3667
batch_size = 2
3768
num_channels = 4
3869
num_frames = 2
@@ -41,50 +72,47 @@ def dummy_input(self):
4172
embedding_dim = 16
4273
sequence_length = 16
4374

44-
hidden_states = torch.randn((batch_size, num_frames * height * width, num_channels)).to(torch_device)
45-
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
46-
encoder_attention_mask = torch.ones((batch_size, sequence_length)).bool().to(torch_device)
47-
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
48-
4975
return {
50-
"hidden_states": hidden_states,
51-
"encoder_hidden_states": encoder_hidden_states,
52-
"timestep": timestep,
53-
"encoder_attention_mask": encoder_attention_mask,
76+
"hidden_states": randn_tensor(
77+
(batch_size, num_frames * height * width, num_channels),
78+
generator=self.generator,
79+
device=torch_device,
80+
),
81+
"encoder_hidden_states": randn_tensor(
82+
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
83+
),
84+
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
85+
"encoder_attention_mask": torch.ones((batch_size, sequence_length)).bool().to(torch_device),
5486
"num_frames": num_frames,
5587
"height": height,
5688
"width": width,
5789
}
5890

59-
@property
60-
def input_shape(self):
61-
return (512, 4)
6291

63-
@property
64-
def output_shape(self):
65-
return (512, 4)
92+
class TestLTXTransformer(LTXTransformerTesterConfig, ModelTesterMixin):
93+
"""Core model tests for LTX Video Transformer."""
6694

67-
def prepare_init_args_and_inputs_for_common(self):
68-
init_dict = {
69-
"in_channels": 4,
70-
"out_channels": 4,
71-
"num_attention_heads": 2,
72-
"attention_head_dim": 8,
73-
"cross_attention_dim": 16,
74-
"num_layers": 1,
75-
"qk_norm": "rms_norm_across_heads",
76-
"caption_channels": 16,
77-
}
78-
inputs_dict = self.dummy_input
79-
return init_dict, inputs_dict
95+
96+
class TestLTXTransformerMemory(LTXTransformerTesterConfig, MemoryTesterMixin):
97+
"""Memory optimization tests for LTX Video Transformer."""
98+
99+
100+
class TestLTXTransformerTraining(LTXTransformerTesterConfig, TrainingTesterMixin):
101+
"""Training tests for LTX Video Transformer."""
80102

81103
def test_gradient_checkpointing_is_applied(self):
82-
expected_set = {"LTXVideoTransformer3DModel"}
83-
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
104+
super().test_gradient_checkpointing_is_applied(expected_set={"LTXVideoTransformer3DModel"})
105+
106+
107+
class TestLTXTransformerCompile(LTXTransformerTesterConfig, TorchCompileTesterMixin):
108+
"""Torch compile tests for LTX Video Transformer."""
109+
84110

111+
# TODO: Add pretrained_model_name_or_path once a tiny LTX model is available on the Hub
112+
# class TestLTXTransformerBitsAndBytes(LTXTransformerTesterConfig, BitsAndBytesTesterMixin):
113+
# """BitsAndBytes quantization tests for LTX Video Transformer."""
85114

86-
class LTXTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
87-
model_class = LTXVideoTransformer3DModel
88115

89-
def prepare_init_args_and_inputs_for_common(self):
90-
return LTXTransformerTests().prepare_init_args_and_inputs_for_common()
116+
# TODO: Add pretrained_model_name_or_path once a tiny LTX model is available on the Hub
117+
# class TestLTXTransformerTorchAo(LTXTransformerTesterConfig, TorchAoTesterMixin):
118+
# """TorchAo quantization tests for LTX Video Transformer."""

0 commit comments

Comments
 (0)