Skip to content

Commit a5f2870

Browse files
authored
check patch for sdpa_mask_recent_torch even if it was removed in transformers>=5.0 (#346)
* add test to check a function for old version of transformers * doc
1 parent fdfdfcd commit a5f2870

4 files changed

Lines changed: 95 additions & 1 deletion

File tree

CHANGELOGS.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ Change Logs
44
0.8.5
55
+++++
66

7+
* :pr:`346`: fix patch for sdpa_mask_recent_torch even if it was removed in transformers>=5.0
8+
79
0.8.4
810
+++++
911

_scripts/export_qwen25_vl_visual.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,12 @@
1919
2020
python export_qwen25_vl_visual.py -m Qwen/Qwen2.5-VL-7B-Instruct --device cpu --dtype float32 --exporter onnx-dynamo --pretrained --second-input
2121
22+
Merge model and data into one file:
23+
24+
.. code-block:: bash
25+
26+
tar -czvf model.tar.gz model.onnx model.data
27+
2228
Attention
2329
+++++++++
2430

_unittests/ut_torch_export_patches/test_patch_transformers.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,83 @@ def test_sdpa_mask_recent_torch(self):
6262
got = patched_sdpa_mask_recent_torch(**kwargs)
6363
self.assertEqualArray(expected, got)
6464

65+
@requires_transformers("4.99")
66+
def test_sdpa_mask_recent_torch_is_running(self):
67+
def _copy_vmap_for_bhqkv(mask_function, bh_indices=True):
68+
dimensions = [(None, None, None, 0), (None, None, 0, None)]
69+
if bh_indices:
70+
dimensions.extend([(None, 0, None, None), (0, None, None, None)])
71+
for dims in dimensions:
72+
mask_function = torch.vmap(mask_function, in_dims=dims, out_dims=0)
73+
return mask_function
74+
75+
def copy_of_sdpa_mask_recent_torch(
76+
batch_size,
77+
cache_position,
78+
kv_length,
79+
kv_offset=0,
80+
mask_function=transformers.masking_utils.causal_mask_function,
81+
attention_mask=None,
82+
local_size=None,
83+
allow_is_causal_skip=True,
84+
**kwargs,
85+
):
86+
q_length = cache_position.shape[0]
87+
padding_mask = transformers.masking_utils.prepare_padding_mask(
88+
attention_mask, kv_length, kv_offset
89+
)
90+
if allow_is_causal_skip and transformers.masking_utils._ignore_causal_mask_sdpa(
91+
padding_mask, q_length, kv_length, kv_offset, local_size
92+
):
93+
return None
94+
kv_arange = torch.arange(kv_length, device=cache_position.device)
95+
kv_arange += kv_offset
96+
if padding_mask is not None:
97+
mask_function = transformers.masking_utils.and_masks(
98+
mask_function,
99+
transformers.masking_utils.padding_mask_function(padding_mask),
100+
)
101+
102+
batch_arange = torch.arange(batch_size, device=cache_position.device)
103+
head_arange = torch.arange(1, device=cache_position.device)
104+
with transformers.masking_utils.TransformGetItemToIndex():
105+
causal_mask = _copy_vmap_for_bhqkv(mask_function)(
106+
batch_arange, head_arange, cache_position, kv_arange
107+
)
108+
return causal_mask
109+
110+
sdpa_mask_recent_torch = copy_of_sdpa_mask_recent_torch
111+
patched_sdpa_mask_recent_torch = patch_transformers.patched_sdpa_mask_recent_torch
112+
kwargs = {
113+
"batch_size": 1,
114+
"cache_position": torch.tensor([3], dtype=torch.int64),
115+
"kv_length": 4,
116+
"kv_offset": 0,
117+
"mask_function": transformers.masking_utils.causal_mask_function,
118+
"attention_mask": torch.tensor([[True, True, True, True]]),
119+
"local_size": None,
120+
"allow_is_causal_skip": True,
121+
"allow_is_bidirectional_skip": False,
122+
}
123+
expected = sdpa_mask_recent_torch(**kwargs)
124+
got = patched_sdpa_mask_recent_torch(**kwargs)
125+
self.assertEqual(expected, got)
126+
127+
kwargs = {
128+
"batch_size": 1,
129+
"cache_position": torch.tensor([3], dtype=torch.int64),
130+
"kv_length": 4,
131+
"kv_offset": 0,
132+
"mask_function": transformers.masking_utils.causal_mask_function,
133+
"attention_mask": torch.tensor([[True, True, True, True]]),
134+
"local_size": None,
135+
"allow_is_causal_skip": False,
136+
"allow_is_bidirectional_skip": False,
137+
}
138+
expected = sdpa_mask_recent_torch(**kwargs)
139+
got = patched_sdpa_mask_recent_torch(**kwargs)
140+
self.assertEqualArray(expected, got)
141+
65142
def test_sdpa_attention_forward_not_causal(self):
66143
sdpa_attention_forward = sdpa_attention.sdpa_attention_forward
67144
patched_sdpa_attention_forward = patch_transformers.patched_sdpa_attention_forward

onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import inspect
12
from typing import Callable, List, Optional, Tuple
23
import torch
34

@@ -19,6 +20,12 @@
1920
prepare_padding_mask,
2021
)
2122

23+
_prepare_padding_mask_kwargs = (
24+
dict(_slice=False)
25+
if "_slice" in inspect.signature(prepare_padding_mask).parameters
26+
else {}
27+
)
28+
2229
try:
2330
# transformers>=5.0
2431
from transformers.masking_utils import (
@@ -132,7 +139,9 @@ def patched_sdpa_mask_recent_torch(
132139
) -> Optional[torch.Tensor]:
133140
"""manual patch for function ``transformers.masking_utils.sdpa_mask_recent_torch``."""
134141
q_length = cache_position.shape[0]
135-
padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset, _slice=False)
142+
padding_mask = prepare_padding_mask(
143+
attention_mask, kv_length, kv_offset, **_prepare_padding_mask_kwargs
144+
)
136145
if allow_is_causal_skip and _ignore_causal_mask_sdpa(
137146
padding_mask, q_length, kv_length, kv_offset, local_size
138147
):

0 commit comments

Comments
 (0)