@@ -90,10 +90,7 @@ def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
         encoder_hidden_states = randn_tensor(
             (batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
         )
-        encoder_hidden_states_mask = torch.ones((batch_size, sequence_length), dtype=torch.long, device=torch_device)
-        encoder_hidden_states_mask[:, 1] = 0
-        encoder_hidden_states_mask[:, 3] = 0
-        encoder_hidden_states_mask[:, 5:] = 0
+        encoder_hidden_states_mask = torch.ones((batch_size, sequence_length)).to(torch_device, torch.long)
         timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
         orig_height = height * 2 * vae_scale_factor
         orig_width = width * 2 * vae_scale_factor
@@ -115,7 +112,7 @@ def test_infers_text_seq_len_from_mask(self, batch_size):
         inputs = self.get_dummy_inputs(batch_size=batch_size)
         model = self.model_class(**init_dict).to(torch_device)
 
-        encoder_hidden_states_mask = torch.ones_like(inputs["encoder_hidden_states_mask"])
+        encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone()
         encoder_hidden_states_mask[:, 2:] = 0
 
         rope_text_seq_len, per_sample_len, normalized_mask = compute_text_seq_len_from_mask(
@@ -159,14 +156,21 @@ def test_non_contiguous_attention_mask(self, batch_size):
         inputs = self.get_dummy_inputs(batch_size=batch_size)
         model = self.model_class(**init_dict).to(torch_device)
 
+        encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone()
+        encoder_hidden_states_mask[:, 1] = 0
+        encoder_hidden_states_mask[:, 3] = 0
+        encoder_hidden_states_mask[:, 5:] = 0
+
         inferred_rope_len, per_sample_len, normalized_mask = compute_text_seq_len_from_mask(
-            inputs["encoder_hidden_states"], inputs["encoder_hidden_states_mask"]
+            inputs["encoder_hidden_states"], encoder_hidden_states_mask
         )
         assert int(per_sample_len.max().item()) == 5
         assert inferred_rope_len == inputs["encoder_hidden_states"].shape[1]
         assert isinstance(inferred_rope_len, int)
         assert normalized_mask.dtype == torch.bool
 
+        inputs["encoder_hidden_states_mask"] = normalized_mask
+
         with torch.no_grad():
             output = model(**inputs)
 
@@ -259,6 +263,15 @@ class TestQwenImageTransformerContextParallelAttnBackends(
     # _flash_3_hub do not support.
     unsupported_attn_backends = ["flash_hub", "_flash_3_hub"]
 
+    def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
+        inputs = super().get_dummy_inputs(batch_size=batch_size)
+        encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"]
+        encoder_hidden_states_mask[:, 1] = 0
+        encoder_hidden_states_mask[:, 3] = 0
+        encoder_hidden_states_mask[:, 5:] = 0
+        inputs["encoder_hidden_states_mask"] = encoder_hidden_states_mask
+        return inputs
+
 
 class TestQwenImageTransformerLoRA(QwenImageTransformerTesterConfig, LoraTesterMixin):
     """LoRA adapter tests for QwenImage Transformer."""
@@ -363,7 +376,7 @@ def test_torch_compile_with_and_without_mask(self):
         assert output_no_mask_2.sample.shape[1] == inputs["hidden_states"].shape[1]
 
         inputs_all_ones = inputs.copy()
-        inputs_all_ones["encoder_hidden_states_mask"] = torch.ones_like(inputs["encoder_hidden_states_mask"])
+        assert inputs_all_ones["encoder_hidden_states_mask"].all().item()
 
         with torch.no_grad():
             output_all_ones = model(**inputs_all_ones)
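For reference, a minimal sketch of the contract these tests pin down for compute_text_seq_len_from_mask: the helper name, argument order, and return tuple come from the diff above, but the body below is a hypothetical reconstruction from the assertions, not the actual implementation. The key point it illustrates is that per_sample_len extends to the last active token, so the non-contiguous mask pattern used here (zeros at indices 1 and 3, and from 5 onward) yields 5, not the count of active tokens.

import torch

def text_seq_len_from_mask_sketch(encoder_hidden_states, encoder_hidden_states_mask):
    # Illustrative only; names and logic are assumptions, not the library code.
    # RoPE is sized for the full padded text length, returned as a plain int.
    rope_text_seq_len = int(encoder_hidden_states.shape[1])
    normalized_mask = encoder_hidden_states_mask.bool()
    # Per-sample length runs up to the last active token, so a mask like
    # [1, 0, 1, 0, 1, 0, 0] gives 5 (last active index 4, plus one), not 3.
    positions = torch.arange(1, normalized_mask.shape[1] + 1, device=normalized_mask.device)
    per_sample_len = (normalized_mask * positions).amax(dim=1)
    return rope_text_seq_len, per_sample_len, normalized_mask

# Mirrors the expectations in test_non_contiguous_attention_mask (dims chosen arbitrarily):
mask = torch.ones(1, 7, dtype=torch.long)
mask[:, 1] = 0
mask[:, 3] = 0
mask[:, 5:] = 0
rope_len, per_len, norm = text_seq_len_from_mask_sketch(torch.randn(1, 7, 16), mask)
assert rope_len == 7 and int(per_len.max().item()) == 5 and norm.dtype == torch.bool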