1- # coding=utf-8
21# Copyright 2025 HuggingFace Inc.
32#
43# Licensed under the Apache License, Version 2.0 (the "License");
1312# See the License for the specific language governing permissions and
1413# limitations under the License.
1514
16- import unittest
17-
1815import torch
1916
2017from diffusers import LTXVideoTransformer3DModel
18+ from diffusers .utils .torch_utils import randn_tensor
2119
2220from ...testing_utils import enable_full_determinism , torch_device
23- from ..test_modeling_common import ModelTesterMixin , TorchCompileTesterMixin
21+ from ..testing_utils import (
22+ BaseModelTesterConfig ,
23+ MemoryTesterMixin ,
24+ ModelTesterMixin ,
25+ TorchCompileTesterMixin ,
26+ TrainingTesterMixin ,
27+ )
2428
2529
2630enable_full_determinism ()
2731
2832
29- class LTXTransformerTests (ModelTesterMixin , unittest .TestCase ):
30- model_class = LTXVideoTransformer3DModel
31- main_input_name = "hidden_states"
32- uses_custom_attn_processor = True
33+ class LTXTransformerTesterConfig (BaseModelTesterConfig ):
34+ @property
35+ def model_class (self ):
36+ return LTXVideoTransformer3DModel
37+
38+ @property
39+ def output_shape (self ) -> tuple [int , int ]:
40+ return (512 , 4 )
3341
3442 @property
35- def dummy_input (self ):
43+ def input_shape (self ) -> tuple [int , int ]:
44+ return (512 , 4 )
45+
46+ @property
47+ def main_input_name (self ) -> str :
48+ return "hidden_states"
49+
50+ @property
51+ def generator (self ):
52+ return torch .Generator ("cpu" ).manual_seed (0 )
53+
54+ def get_init_dict (self ):
55+ return {
56+ "in_channels" : 4 ,
57+ "out_channels" : 4 ,
58+ "num_attention_heads" : 2 ,
59+ "attention_head_dim" : 8 ,
60+ "cross_attention_dim" : 16 ,
61+ "num_layers" : 1 ,
62+ "qk_norm" : "rms_norm_across_heads" ,
63+ "caption_channels" : 16 ,
64+ }
65+
66+ def get_dummy_inputs (self ) -> dict [str , torch .Tensor ]:
3667 batch_size = 2
3768 num_channels = 4
3869 num_frames = 2
@@ -41,50 +72,47 @@ def dummy_input(self):
4172 embedding_dim = 16
4273 sequence_length = 16
4374
44- hidden_states = torch .randn ((batch_size , num_frames * height * width , num_channels )).to (torch_device )
45- encoder_hidden_states = torch .randn ((batch_size , sequence_length , embedding_dim )).to (torch_device )
46- encoder_attention_mask = torch .ones ((batch_size , sequence_length )).bool ().to (torch_device )
47- timestep = torch .randint (0 , 1000 , size = (batch_size ,)).to (torch_device )
48-
4975 return {
50- "hidden_states" : hidden_states ,
51- "encoder_hidden_states" : encoder_hidden_states ,
52- "timestep" : timestep ,
53- "encoder_attention_mask" : encoder_attention_mask ,
76+ "hidden_states" : randn_tensor (
77+ (batch_size , num_frames * height * width , num_channels ),
78+ generator = self .generator ,
79+ device = torch_device ,
80+ ),
81+ "encoder_hidden_states" : randn_tensor (
82+ (batch_size , sequence_length , embedding_dim ), generator = self .generator , device = torch_device
83+ ),
84+ "timestep" : torch .randint (0 , 1000 , size = (batch_size ,), generator = self .generator ).to (torch_device ),
85+ "encoder_attention_mask" : torch .ones ((batch_size , sequence_length )).bool ().to (torch_device ),
5486 "num_frames" : num_frames ,
5587 "height" : height ,
5688 "width" : width ,
5789 }
5890
59- @property
60- def input_shape (self ):
61- return (512 , 4 )
6291
63- @property
64- def output_shape (self ):
65- return (512 , 4 )
92+ class TestLTXTransformer (LTXTransformerTesterConfig , ModelTesterMixin ):
93+ """Core model tests for LTX Video Transformer."""
6694
67- def prepare_init_args_and_inputs_for_common (self ):
68- init_dict = {
69- "in_channels" : 4 ,
70- "out_channels" : 4 ,
71- "num_attention_heads" : 2 ,
72- "attention_head_dim" : 8 ,
73- "cross_attention_dim" : 16 ,
74- "num_layers" : 1 ,
75- "qk_norm" : "rms_norm_across_heads" ,
76- "caption_channels" : 16 ,
77- }
78- inputs_dict = self .dummy_input
79- return init_dict , inputs_dict
95+
96+ class TestLTXTransformerMemory (LTXTransformerTesterConfig , MemoryTesterMixin ):
97+ """Memory optimization tests for LTX Video Transformer."""
98+
99+
100+ class TestLTXTransformerTraining (LTXTransformerTesterConfig , TrainingTesterMixin ):
101+ """Training tests for LTX Video Transformer."""
80102
81103 def test_gradient_checkpointing_is_applied (self ):
82- expected_set = {"LTXVideoTransformer3DModel" }
83- super ().test_gradient_checkpointing_is_applied (expected_set = expected_set )
104+ super ().test_gradient_checkpointing_is_applied (expected_set = {"LTXVideoTransformer3DModel" })
105+
106+
107+ class TestLTXTransformerCompile (LTXTransformerTesterConfig , TorchCompileTesterMixin ):
108+ """Torch compile tests for LTX Video Transformer."""
109+
84110
111+ # TODO: Add pretrained_model_name_or_path once a tiny LTX model is available on the Hub
112+ # class TestLTXTransformerBitsAndBytes(LTXTransformerTesterConfig, BitsAndBytesTesterMixin):
113+ # """BitsAndBytes quantization tests for LTX Video Transformer."""
85114
86- class LTXTransformerCompileTests (TorchCompileTesterMixin , unittest .TestCase ):
87- model_class = LTXVideoTransformer3DModel
88115
89- def prepare_init_args_and_inputs_for_common (self ):
90- return LTXTransformerTests ().prepare_init_args_and_inputs_for_common ()
116+ # TODO: Add pretrained_model_name_or_path once a tiny LTX model is available on the Hub
117+ # class TestLTXTransformerTorchAo(LTXTransformerTesterConfig, TorchAoTesterMixin):
118+ # """TorchAo quantization tests for LTX Video Transformer."""
0 commit comments