Skip to content

Commit 9c83ab3

Browse files
committed
remove big mask
1 parent a8036a9 commit 9c83ab3

2 files changed

Lines changed: 37 additions & 3 deletions

File tree

_unittests/ut_tasks/try_export.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@ def test_qwen25_vli_visual(self):
5252
.. code-block:: bash
5353
5454
NEVERTEST=1 \\
55-
QWEN25ATTENTION=BIGMASK \\
5655
PRETRAINED=1 \\
5756
TESTDEVICE=cuda \\
5857
TESTDTYPE=float16 \\
@@ -164,9 +163,11 @@ def _config_reduction(config, task):
164163
if qwen25_attention:
165164
attention_options = [qwen25_attention]
166165
elif device == "cuda" and dtype in ("float16", "bfloat16"):
167-
attention_options = ["PACKED", "BIGMASK"]
166+
attention_options = [
167+
"PACKED",
168+
]
168169
else:
169-
attention_options = ["LOOPMHA", "LOOPA24", "BIGMASK"]
170+
attention_options = ["LOOPMHA", "LOOPA24"]
170171

171172
# fake_inputs = make_fake_with_dynamic_dimensions(inputs, dynamic_shapes)[0]
172173
for attention in attention_options:

onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,39 @@ def qwen_sdpa_attention(
200200
scaling: float = 0,
201201
num_heads: int = 16,
202202
) -> torch.Tensor:
203+
"""
204+
The loop can be removed with the following code
205+
but it hits memory overflow for big inputs.
206+
207+
.. code-block:: python
208+
209+
# make square mask
210+
indices = torch.arange(
211+
cu_seqlens.max(), dtype=cu_seqlens.dtype, device=cu_seqlens.device
212+
)
213+
dot = (cu_seqlens.unsqueeze(1) <= indices.unsqueeze(0)).to(
214+
cu_seqlens.dtype
215+
)
216+
dot = dot.sum(dim=0)
217+
mask = dot.unsqueeze(1) - dot.unsqueeze(0)
218+
bool_mask = mask == 0
219+
bool_mask = bool_mask.unsqueeze(0).unsqueeze(0)
220+
221+
torch._check(bool_mask.shape[2] == key_states.shape[2])
222+
torch._check(bool_mask.shape[3] == key_states.shape[2])
223+
224+
attn_output, _ = attention_interface(
225+
self,
226+
query_states,
227+
key_states,
228+
value_states,
229+
attention_mask=bool_mask,
230+
scaling=self.scaling,
231+
dropout=0.0 if not self.training else self.attention_dropout,
232+
is_causal=False,
233+
**kwargs,
234+
)
235+
"""
203236
lengths = cu_seqlens[1:] - cu_seqlens[:-1]
204237
splits = [
205238
torch.split(tensor, lengths.tolist(), dim=2)

0 commit comments

Comments
 (0)