1313# See the License for the specific language governing permissions and
1414# limitations under the License.
1515
16- import unittest
17-
16+ import pytest
1817import torch
1918
2019from diffusers import Ideogram4Transformer2DModel
2322 LLM_TOKEN_INDICATOR ,
2423 OUTPUT_IMAGE_INDICATOR ,
2524)
25+ from diffusers .utils .torch_utils import randn_tensor
2626
2727from ...testing_utils import enable_full_determinism , torch_device
28- from ..test_modeling_common import ModelTesterMixin
28+ from ..testing_utils import (
29+ AttentionTesterMixin ,
30+ BaseModelTesterConfig ,
31+ MemoryTesterMixin ,
32+ ModelTesterMixin ,
33+ TrainingTesterMixin ,
34+ )
2935
3036
3137enable_full_determinism ()
3238
3339
34- class Ideogram4TransformerTests (ModelTesterMixin , unittest .TestCase ):
35- model_class = Ideogram4Transformer2DModel
36- main_input_name = "hidden_states"
37- model_split_percents = [0.9 , 0.9 , 0.9 ]
38-
40+ class Ideogram4TransformerTesterConfig (BaseModelTesterConfig ):
3941 _hidden_size = 32
4042 _num_heads = 4
4143 _head_dim = _hidden_size // _num_heads # 8
@@ -44,61 +46,32 @@ class Ideogram4TransformerTests(ModelTesterMixin, unittest.TestCase):
4446 _max_text_tokens = 4
4547 _num_image_tokens = 4
4648
47- def prepare_dummy_input (self , height : int = 0 , width : int = 0 ):
48- del height , width
49- batch_size = 1
50- max_text_tokens = self ._max_text_tokens
51- num_image_tokens = self ._num_image_tokens
52- seq_len = max_text_tokens + num_image_tokens
53-
54- hidden_states = torch .zeros (batch_size , seq_len , self ._in_channels )
55- hidden_states [:, max_text_tokens :] = torch .randn (batch_size , num_image_tokens , self ._in_channels )
56-
57- encoder_hidden_states = torch .zeros (batch_size , seq_len , self ._llm_features_dim )
58- encoder_hidden_states [:, :max_text_tokens ] = torch .randn (batch_size , max_text_tokens , self ._llm_features_dim )
49+ @property
50+ def model_class (self ):
51+ return Ideogram4Transformer2DModel
5952
60- position_ids = torch .zeros (batch_size , seq_len , 3 , dtype = torch .long )
61- text_pos = torch .arange (max_text_tokens )
62- position_ids [:, :max_text_tokens , 0 ] = text_pos
63- position_ids [:, :max_text_tokens , 1 ] = text_pos
64- position_ids [:, :max_text_tokens , 2 ] = text_pos
65- # Image tokens get a 2x2 grid with the IMAGE_POSITION_OFFSET applied.
66- image_h = torch .tensor ([0 , 0 , 1 , 1 ])
67- image_w = torch .tensor ([0 , 1 , 0 , 1 ])
68- position_ids [:, max_text_tokens :, 0 ] = IMAGE_POSITION_OFFSET
69- position_ids [:, max_text_tokens :, 1 ] = image_h + IMAGE_POSITION_OFFSET
70- position_ids [:, max_text_tokens :, 2 ] = image_w + IMAGE_POSITION_OFFSET
53+ @property
54+ def main_input_name (self ) -> str :
55+ return "hidden_states"
7156
72- segment_ids = torch .ones (batch_size , seq_len , dtype = torch .long )
73- indicator = torch .empty (batch_size , seq_len , dtype = torch .long )
74- indicator [:, :max_text_tokens ] = LLM_TOKEN_INDICATOR
75- indicator [:, max_text_tokens :] = OUTPUT_IMAGE_INDICATOR
76- timestep = torch .tensor ([0.5 ])
77-
78- inputs = {
79- "hidden_states" : hidden_states .to (torch_device ),
80- "encoder_hidden_states" : encoder_hidden_states .to (torch_device ),
81- "timestep" : timestep .to (torch_device ),
82- "position_ids" : position_ids .to (torch_device ),
83- "segment_ids" : segment_ids .to (torch_device ),
84- "indicator" : indicator .to (torch_device ),
85- }
86- return inputs
57+ @property
58+ def output_shape (self ) -> tuple [int , ...]:
59+ return (1 , self ._max_text_tokens + self ._num_image_tokens , self ._in_channels )
8760
8861 @property
89- def dummy_input (self ):
90- return self .prepare_dummy_input ( )
62+ def input_shape (self ) -> tuple [ int , ...] :
63+ return ( 1 , self ._max_text_tokens + self . _num_image_tokens , self . _in_channels )
9164
9265 @property
93- def input_shape (self ):
94- return ( self . _max_text_tokens + self . _num_image_tokens , self . _in_channels )
66+ def model_split_percents (self ) -> list :
67+ return [ 0.9 , 0.9 , 0.9 ]
9568
9669 @property
97- def output_shape (self ):
98- return ( self . _max_text_tokens + self . _num_image_tokens , self . _in_channels )
70+ def generator (self ):
71+ return torch . Generator ( "cpu" ). manual_seed ( 0 )
9972
100- def prepare_init_args_and_inputs_for_common (self ):
101- init_dict = {
73+ def get_init_dict (self ) -> dict :
74+ return {
10275 "in_channels" : self ._in_channels ,
10376 "num_layers" : 2 ,
10477 "attention_head_dim" : self ._head_dim ,
@@ -110,24 +83,84 @@ def prepare_init_args_and_inputs_for_common(self):
11083 "mrope_section" : (2 , 1 , 1 ),
11184 "norm_eps" : 1e-5 ,
11285 }
113- inputs_dict = self .dummy_input
114- return init_dict , inputs_dict
86+
87+ def get_dummy_inputs (self ) -> dict [str , torch .Tensor ]:
88+ batch_size = 1
89+ max_text_tokens = self ._max_text_tokens
90+ num_image_tokens = self ._num_image_tokens
91+ seq_len = max_text_tokens + num_image_tokens
92+
93+ hidden_states = torch .zeros (
94+ batch_size , seq_len , self ._in_channels , device = torch_device , dtype = self .torch_dtype
95+ )
96+ hidden_states [:, max_text_tokens :] = randn_tensor (
97+ (batch_size , num_image_tokens , self ._in_channels ),
98+ generator = self .generator ,
99+ device = torch_device ,
100+ dtype = self .torch_dtype ,
101+ )
102+
103+ encoder_hidden_states = torch .zeros (
104+ batch_size , seq_len , self ._llm_features_dim , device = torch_device , dtype = self .torch_dtype
105+ )
106+ encoder_hidden_states [:, :max_text_tokens ] = randn_tensor (
107+ (batch_size , max_text_tokens , self ._llm_features_dim ),
108+ generator = self .generator ,
109+ device = torch_device ,
110+ dtype = self .torch_dtype ,
111+ )
112+
113+ position_ids = torch .zeros (batch_size , seq_len , 3 , dtype = torch .long , device = torch_device )
114+ text_pos = torch .arange (max_text_tokens , device = torch_device )
115+ position_ids [:, :max_text_tokens , 0 ] = text_pos
116+ position_ids [:, :max_text_tokens , 1 ] = text_pos
117+ position_ids [:, :max_text_tokens , 2 ] = text_pos
118+ # Image tokens get a 2x2 grid with the IMAGE_POSITION_OFFSET applied.
119+ image_h = torch .tensor ([0 , 0 , 1 , 1 ], device = torch_device )
120+ image_w = torch .tensor ([0 , 1 , 0 , 1 ], device = torch_device )
121+ position_ids [:, max_text_tokens :, 0 ] = IMAGE_POSITION_OFFSET
122+ position_ids [:, max_text_tokens :, 1 ] = image_h + IMAGE_POSITION_OFFSET
123+ position_ids [:, max_text_tokens :, 2 ] = image_w + IMAGE_POSITION_OFFSET
124+
125+ segment_ids = torch .ones (batch_size , seq_len , dtype = torch .long , device = torch_device )
126+ indicator = torch .empty (batch_size , seq_len , dtype = torch .long , device = torch_device )
127+ indicator [:, :max_text_tokens ] = LLM_TOKEN_INDICATOR
128+ indicator [:, max_text_tokens :] = OUTPUT_IMAGE_INDICATOR
129+ timestep = torch .tensor ([0.5 ], device = torch_device , dtype = self .torch_dtype )
130+
131+ return {
132+ "hidden_states" : hidden_states ,
133+ "encoder_hidden_states" : encoder_hidden_states ,
134+ "timestep" : timestep ,
135+ "position_ids" : position_ids ,
136+ "segment_ids" : segment_ids ,
137+ "indicator" : indicator ,
138+ }
139+
140+
141+ class TestIdeogram4Transformer (Ideogram4TransformerTesterConfig , ModelTesterMixin ):
142+ """Core model tests for Ideogram 4 Transformer."""
143+
144+ @pytest .mark .parametrize ("dtype" , [torch .float16 , torch .bfloat16 ], ids = ["fp16" , "bf16" ])
145+ def test_from_save_pretrained_dtype_inference (self , tmp_path , dtype ):
146+ # Skip: the non-persistent fp32 RoPE inv_freq buffer is truncated to fp16 by the in-memory
147+ # .to(dtype) path but kept fp32 by from_pretrained, so the two outputs diverge well beyond any
148+ # meaningful tolerance. Dtype preservation is already covered by test_from_save_pretrained_dtype
149+ # and test_keep_in_fp32_modules.
150+ pytest .skip ("Tolerance requirements too high for meaningful test" )
151+
152+
153+ class TestIdeogram4TransformerMemory (Ideogram4TransformerTesterConfig , MemoryTesterMixin ):
154+ """Memory optimization tests for Ideogram 4 Transformer."""
155+
156+
157+ class TestIdeogram4TransformerTraining (Ideogram4TransformerTesterConfig , TrainingTesterMixin ):
158+ """Training tests for Ideogram 4 Transformer."""
115159
116160 def test_gradient_checkpointing_is_applied (self ):
117161 expected_set = {"Ideogram4Transformer2DModel" }
118162 super ().test_gradient_checkpointing_is_applied (expected_set = expected_set )
119163
120- def test_forward_signature (self ):
121- # The model's forward takes packed inputs by position; skip the strict signature check used by the mixin.
122- return
123-
124- def test_output (self ):
125- init_dict , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
126- model = self .model_class (** init_dict )
127- model .to (torch_device )
128- model .eval ()
129- with torch .no_grad ():
130- output = model (** inputs_dict , return_dict = False )[0 ]
131- expected = (1 , self ._max_text_tokens + self ._num_image_tokens , self ._in_channels )
132- self .assertEqual (tuple (output .shape ), expected )
133- self .assertEqual (output .dtype , torch .float32 )
164+
165+ class TestIdeogram4TransformerAttention (Ideogram4TransformerTesterConfig , AttentionTesterMixin ):
166+ """Attention processor tests for Ideogram 4 Transformer."""
0 commit comments