File tree Expand file tree Collapse file tree
onnx_diagnostic/torch_export_patches/patches Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -52,7 +52,6 @@ def test_qwen25_vli_visual(self):
5252 .. code-block:: bash
5353
5454 NEVERTEST=1 \\
55- QWEN25ATTENTION=BIGMASK \\
5655 PRETRAINED=1 \\
5756 TESTDEVICE=cuda \\
5857 TESTDTYPE=float16 \\
@@ -164,9 +163,11 @@ def _config_reduction(config, task):
164163 if qwen25_attention :
165164 attention_options = [qwen25_attention ]
166165 elif device == "cuda" and dtype in ("float16" , "bfloat16" ):
167- attention_options = ["PACKED" , "BIGMASK" ]
166+ attention_options = [
167+ "PACKED" ,
168+ ]
168169 else :
169- attention_options = ["LOOPMHA" , "LOOPA24" , "BIGMASK" ]
170+ attention_options = ["LOOPMHA" , "LOOPA24" ]
170171
171172 # fake_inputs = make_fake_with_dynamic_dimensions(inputs, dynamic_shapes)[0]
172173 for attention in attention_options :
Original file line number Diff line number Diff line change @@ -200,6 +200,39 @@ def qwen_sdpa_attention(
200200 scaling : float = 0 ,
201201 num_heads : int = 16 ,
202202 ) -> torch .Tensor :
203+ """
204+ The loop below can be replaced by the following mask-based code,
205+ but that version runs out of memory on large inputs.
206+
207+ .. code-block:: python
208+
209+ # make square mask
210+ indices = torch.arange(
211+ cu_seqlens.max(), dtype=cu_seqlens.dtype, device=cu_seqlens.device
212+ )
213+ dot = (cu_seqlens.unsqueeze(1) <= indices.unsqueeze(0)).to(
214+ cu_seqlens.dtype
215+ )
216+ dot = dot.sum(dim=0)
217+ mask = dot.unsqueeze(1) - dot.unsqueeze(0)
218+ bool_mask = mask == 0
219+ bool_mask = bool_mask.unsqueeze(0).unsqueeze(0)
220+
221+ torch._check(bool_mask.shape[2] == key_states.shape[2])
222+ torch._check(bool_mask.shape[3] == key_states.shape[2])
223+
224+ attn_output, _ = attention_interface(
225+ self,
226+ query_states,
227+ key_states,
228+ value_states,
229+ attention_mask=bool_mask,
230+ scaling=self.scaling,
231+ dropout=0.0 if not self.training else self.attention_dropout,
232+ is_causal=False,
233+ **kwargs,
234+ )
235+ """
203236 lengths = cu_seqlens [1 :] - cu_seqlens [:- 1 ]
204237 splits = [
205238 torch .split (tensor , lengths .tolist (), dim = 2 )
You can’t perform that action at this time.
0 commit comments