Skip to content

Commit 924cfb2

Browse files
sayakpauldg845
andauthored
[tests] Improve ideogram4 tests (#13862)
* improve ideogram4 tests * fix --------- Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
1 parent 2fa7f77 commit 924cfb2

2 files changed

Lines changed: 104 additions & 70 deletions

File tree

src/diffusers/models/transformers/transformer_ideogram4.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ def __init__(self, hidden_size: int, num_heads: int, eps: float = 1e-5) -> None:
136136
self.hidden_size = hidden_size
137137
self.num_heads = num_heads
138138
self.head_dim = hidden_size // num_heads
139+
self.use_bias = False
139140

140141
self.to_q = nn.Linear(hidden_size, hidden_size, bias=False)
141142
self.to_k = nn.Linear(hidden_size, hidden_size, bias=False)

tests/models/transformers/test_models_transformer_ideogram4.py

Lines changed: 103 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,7 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
import unittest
17-
16+
import pytest
1817
import torch
1918

2019
from diffusers import Ideogram4Transformer2DModel
@@ -23,19 +22,22 @@
2322
LLM_TOKEN_INDICATOR,
2423
OUTPUT_IMAGE_INDICATOR,
2524
)
25+
from diffusers.utils.torch_utils import randn_tensor
2626

2727
from ...testing_utils import enable_full_determinism, torch_device
28-
from ..test_modeling_common import ModelTesterMixin
28+
from ..testing_utils import (
29+
AttentionTesterMixin,
30+
BaseModelTesterConfig,
31+
MemoryTesterMixin,
32+
ModelTesterMixin,
33+
TrainingTesterMixin,
34+
)
2935

3036

3137
enable_full_determinism()
3238

3339

34-
class Ideogram4TransformerTests(ModelTesterMixin, unittest.TestCase):
35-
model_class = Ideogram4Transformer2DModel
36-
main_input_name = "hidden_states"
37-
model_split_percents = [0.9, 0.9, 0.9]
38-
40+
class Ideogram4TransformerTesterConfig(BaseModelTesterConfig):
3941
_hidden_size = 32
4042
_num_heads = 4
4143
_head_dim = _hidden_size // _num_heads # 8
@@ -44,61 +46,32 @@ class Ideogram4TransformerTests(ModelTesterMixin, unittest.TestCase):
4446
_max_text_tokens = 4
4547
_num_image_tokens = 4
4648

47-
def prepare_dummy_input(self, height: int = 0, width: int = 0):
48-
del height, width
49-
batch_size = 1
50-
max_text_tokens = self._max_text_tokens
51-
num_image_tokens = self._num_image_tokens
52-
seq_len = max_text_tokens + num_image_tokens
53-
54-
hidden_states = torch.zeros(batch_size, seq_len, self._in_channels)
55-
hidden_states[:, max_text_tokens:] = torch.randn(batch_size, num_image_tokens, self._in_channels)
56-
57-
encoder_hidden_states = torch.zeros(batch_size, seq_len, self._llm_features_dim)
58-
encoder_hidden_states[:, :max_text_tokens] = torch.randn(batch_size, max_text_tokens, self._llm_features_dim)
49+
@property
50+
def model_class(self):
51+
return Ideogram4Transformer2DModel
5952

60-
position_ids = torch.zeros(batch_size, seq_len, 3, dtype=torch.long)
61-
text_pos = torch.arange(max_text_tokens)
62-
position_ids[:, :max_text_tokens, 0] = text_pos
63-
position_ids[:, :max_text_tokens, 1] = text_pos
64-
position_ids[:, :max_text_tokens, 2] = text_pos
65-
# Image tokens get a 2x2 grid with the IMAGE_POSITION_OFFSET applied.
66-
image_h = torch.tensor([0, 0, 1, 1])
67-
image_w = torch.tensor([0, 1, 0, 1])
68-
position_ids[:, max_text_tokens:, 0] = IMAGE_POSITION_OFFSET
69-
position_ids[:, max_text_tokens:, 1] = image_h + IMAGE_POSITION_OFFSET
70-
position_ids[:, max_text_tokens:, 2] = image_w + IMAGE_POSITION_OFFSET
53+
@property
54+
def main_input_name(self) -> str:
55+
return "hidden_states"
7156

72-
segment_ids = torch.ones(batch_size, seq_len, dtype=torch.long)
73-
indicator = torch.empty(batch_size, seq_len, dtype=torch.long)
74-
indicator[:, :max_text_tokens] = LLM_TOKEN_INDICATOR
75-
indicator[:, max_text_tokens:] = OUTPUT_IMAGE_INDICATOR
76-
timestep = torch.tensor([0.5])
77-
78-
inputs = {
79-
"hidden_states": hidden_states.to(torch_device),
80-
"encoder_hidden_states": encoder_hidden_states.to(torch_device),
81-
"timestep": timestep.to(torch_device),
82-
"position_ids": position_ids.to(torch_device),
83-
"segment_ids": segment_ids.to(torch_device),
84-
"indicator": indicator.to(torch_device),
85-
}
86-
return inputs
57+
@property
58+
def output_shape(self) -> tuple[int, ...]:
59+
return (1, self._max_text_tokens + self._num_image_tokens, self._in_channels)
8760

8861
@property
89-
def dummy_input(self):
90-
return self.prepare_dummy_input()
62+
def input_shape(self) -> tuple[int, ...]:
63+
return (1, self._max_text_tokens + self._num_image_tokens, self._in_channels)
9164

9265
@property
93-
def input_shape(self):
94-
return (self._max_text_tokens + self._num_image_tokens, self._in_channels)
66+
def model_split_percents(self) -> list:
67+
return [0.9, 0.9, 0.9]
9568

9669
@property
97-
def output_shape(self):
98-
return (self._max_text_tokens + self._num_image_tokens, self._in_channels)
70+
def generator(self):
71+
return torch.Generator("cpu").manual_seed(0)
9972

100-
def prepare_init_args_and_inputs_for_common(self):
101-
init_dict = {
73+
def get_init_dict(self) -> dict:
74+
return {
10275
"in_channels": self._in_channels,
10376
"num_layers": 2,
10477
"attention_head_dim": self._head_dim,
@@ -110,24 +83,84 @@ def prepare_init_args_and_inputs_for_common(self):
11083
"mrope_section": (2, 1, 1),
11184
"norm_eps": 1e-5,
11285
}
113-
inputs_dict = self.dummy_input
114-
return init_dict, inputs_dict
86+
87+
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
88+
batch_size = 1
89+
max_text_tokens = self._max_text_tokens
90+
num_image_tokens = self._num_image_tokens
91+
seq_len = max_text_tokens + num_image_tokens
92+
93+
hidden_states = torch.zeros(
94+
batch_size, seq_len, self._in_channels, device=torch_device, dtype=self.torch_dtype
95+
)
96+
hidden_states[:, max_text_tokens:] = randn_tensor(
97+
(batch_size, num_image_tokens, self._in_channels),
98+
generator=self.generator,
99+
device=torch_device,
100+
dtype=self.torch_dtype,
101+
)
102+
103+
encoder_hidden_states = torch.zeros(
104+
batch_size, seq_len, self._llm_features_dim, device=torch_device, dtype=self.torch_dtype
105+
)
106+
encoder_hidden_states[:, :max_text_tokens] = randn_tensor(
107+
(batch_size, max_text_tokens, self._llm_features_dim),
108+
generator=self.generator,
109+
device=torch_device,
110+
dtype=self.torch_dtype,
111+
)
112+
113+
position_ids = torch.zeros(batch_size, seq_len, 3, dtype=torch.long, device=torch_device)
114+
text_pos = torch.arange(max_text_tokens, device=torch_device)
115+
position_ids[:, :max_text_tokens, 0] = text_pos
116+
position_ids[:, :max_text_tokens, 1] = text_pos
117+
position_ids[:, :max_text_tokens, 2] = text_pos
118+
# Image tokens get a 2x2 grid with the IMAGE_POSITION_OFFSET applied.
119+
image_h = torch.tensor([0, 0, 1, 1], device=torch_device)
120+
image_w = torch.tensor([0, 1, 0, 1], device=torch_device)
121+
position_ids[:, max_text_tokens:, 0] = IMAGE_POSITION_OFFSET
122+
position_ids[:, max_text_tokens:, 1] = image_h + IMAGE_POSITION_OFFSET
123+
position_ids[:, max_text_tokens:, 2] = image_w + IMAGE_POSITION_OFFSET
124+
125+
segment_ids = torch.ones(batch_size, seq_len, dtype=torch.long, device=torch_device)
126+
indicator = torch.empty(batch_size, seq_len, dtype=torch.long, device=torch_device)
127+
indicator[:, :max_text_tokens] = LLM_TOKEN_INDICATOR
128+
indicator[:, max_text_tokens:] = OUTPUT_IMAGE_INDICATOR
129+
timestep = torch.tensor([0.5], device=torch_device, dtype=self.torch_dtype)
130+
131+
return {
132+
"hidden_states": hidden_states,
133+
"encoder_hidden_states": encoder_hidden_states,
134+
"timestep": timestep,
135+
"position_ids": position_ids,
136+
"segment_ids": segment_ids,
137+
"indicator": indicator,
138+
}
139+
140+
141+
class TestIdeogram4Transformer(Ideogram4TransformerTesterConfig, ModelTesterMixin):
142+
"""Core model tests for Ideogram 4 Transformer."""
143+
144+
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
145+
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
146+
# Skip: the non-persistent fp32 RoPE inv_freq buffer is truncated to fp16 by the in-memory
147+
# .to(dtype) path but kept fp32 by from_pretrained, so the two outputs diverge well beyond any
148+
# meaningful tolerance. Dtype preservation is already covered by test_from_save_pretrained_dtype
149+
# and test_keep_in_fp32_modules.
150+
pytest.skip("Tolerance requirements too high for meaningful test")
151+
152+
153+
class TestIdeogram4TransformerMemory(Ideogram4TransformerTesterConfig, MemoryTesterMixin):
154+
"""Memory optimization tests for Ideogram 4 Transformer."""
155+
156+
157+
class TestIdeogram4TransformerTraining(Ideogram4TransformerTesterConfig, TrainingTesterMixin):
158+
"""Training tests for Ideogram 4 Transformer."""
115159

116160
def test_gradient_checkpointing_is_applied(self):
117161
expected_set = {"Ideogram4Transformer2DModel"}
118162
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
119163

120-
def test_forward_signature(self):
121-
# The model's forward takes packed inputs by position; skip the strict signature check used by the mixin.
122-
return
123-
124-
def test_output(self):
125-
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
126-
model = self.model_class(**init_dict)
127-
model.to(torch_device)
128-
model.eval()
129-
with torch.no_grad():
130-
output = model(**inputs_dict, return_dict=False)[0]
131-
expected = (1, self._max_text_tokens + self._num_image_tokens, self._in_channels)
132-
self.assertEqual(tuple(output.shape), expected)
133-
self.assertEqual(output.dtype, torch.float32)
164+
165+
class TestIdeogram4TransformerAttention(Ideogram4TransformerTesterConfig, AttentionTesterMixin):
166+
"""Attention processor tests for Ideogram 4 Transformer."""

0 commit comments

Comments
 (0)