Skip to content

Commit 8be05cd

Browse files
committed
revert and update
1 parent 52bc7a3 commit 8be05cd

1 file changed

Lines changed: 20 additions & 7 deletions

File tree

tests/models/transformers/test_models_transformer_qwenimage.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -90,10 +90,7 @@ def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
9090
encoder_hidden_states = randn_tensor(
9191
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
9292
)
93-
encoder_hidden_states_mask = torch.ones((batch_size, sequence_length), dtype=torch.long, device=torch_device)
94-
encoder_hidden_states_mask[:, 1] = 0
95-
encoder_hidden_states_mask[:, 3] = 0
96-
encoder_hidden_states_mask[:, 5:] = 0
93+
encoder_hidden_states_mask = torch.ones((batch_size, sequence_length)).to(torch_device, torch.long)
9794
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
9895
orig_height = height * 2 * vae_scale_factor
9996
orig_width = width * 2 * vae_scale_factor
@@ -115,7 +112,7 @@ def test_infers_text_seq_len_from_mask(self, batch_size):
115112
inputs = self.get_dummy_inputs(batch_size=batch_size)
116113
model = self.model_class(**init_dict).to(torch_device)
117114

118-
encoder_hidden_states_mask = torch.ones_like(inputs["encoder_hidden_states_mask"])
115+
encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone()
119116
encoder_hidden_states_mask[:, 2:] = 0
120117

121118
rope_text_seq_len, per_sample_len, normalized_mask = compute_text_seq_len_from_mask(
@@ -159,14 +156,21 @@ def test_non_contiguous_attention_mask(self, batch_size):
159156
inputs = self.get_dummy_inputs(batch_size=batch_size)
160157
model = self.model_class(**init_dict).to(torch_device)
161158

159+
encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone()
160+
encoder_hidden_states_mask[:, 1] = 0
161+
encoder_hidden_states_mask[:, 3] = 0
162+
encoder_hidden_states_mask[:, 5:] = 0
163+
162164
inferred_rope_len, per_sample_len, normalized_mask = compute_text_seq_len_from_mask(
163-
inputs["encoder_hidden_states"], inputs["encoder_hidden_states_mask"]
165+
inputs["encoder_hidden_states"], encoder_hidden_states_mask
164166
)
165167
assert int(per_sample_len.max().item()) == 5
166168
assert inferred_rope_len == inputs["encoder_hidden_states"].shape[1]
167169
assert isinstance(inferred_rope_len, int)
168170
assert normalized_mask.dtype == torch.bool
169171

172+
inputs["encoder_hidden_states_mask"] = normalized_mask
173+
170174
with torch.no_grad():
171175
output = model(**inputs)
172176

@@ -259,6 +263,15 @@ class TestQwenImageTransformerContextParallelAttnBackends(
259263
# _flash_3_hub do not support.
260264
unsupported_attn_backends = ["flash_hub", "_flash_3_hub"]
261265

266+
def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
267+
inputs = super().get_dummy_inputs(batch_size=batch_size)
268+
encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"]
269+
encoder_hidden_states_mask[:, 1] = 0
270+
encoder_hidden_states_mask[:, 3] = 0
271+
encoder_hidden_states_mask[:, 5:] = 0
272+
inputs["encoder_hidden_states_mask"] = encoder_hidden_states_mask
273+
return inputs
274+
262275

263276
class TestQwenImageTransformerLoRA(QwenImageTransformerTesterConfig, LoraTesterMixin):
264277
"""LoRA adapter tests for QwenImage Transformer."""
@@ -363,7 +376,7 @@ def test_torch_compile_with_and_without_mask(self):
363376
assert output_no_mask_2.sample.shape[1] == inputs["hidden_states"].shape[1]
364377

365378
inputs_all_ones = inputs.copy()
366-
inputs_all_ones["encoder_hidden_states_mask"] = torch.ones_like(inputs["encoder_hidden_states_mask"])
379+
assert inputs_all_ones["encoder_hidden_states_mask"].all().item()
367380

368381
with torch.no_grad():
369382
output_all_ones = model(**inputs_all_ones)

0 commit comments

Comments (0)