Skip to content

Commit a8036a9

Browse files
committed
fix default strategy
1 parent dcfc731 commit a8036a9

3 files changed

Lines changed: 22 additions & 23 deletions

File tree

_unittests/ut_torch_onnx/test_sbs.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -682,7 +682,7 @@ def test_sbs_with_loops(self):
682682
PLUGS_Qwen25,
683683
)
684684
from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
685-
qwen_sdpa_attention_loopmha_versatile,
685+
qwen_sdpa_attention_versatile,
686686
)
687687

688688
class Model(torch.nn.Module):
@@ -693,9 +693,7 @@ def forward(self, query, key, value, seq_lens):
693693
qs = query * mask
694694
ks = key * mask
695695
vs = value * mask
696-
attn_output = qwen_sdpa_attention_loopmha_versatile(
697-
qs, ks, vs, seq_lens, 0.11, 16
698-
)
696+
attn_output = qwen_sdpa_attention_versatile(qs, ks, vs, seq_lens, 0.11, 16)
699697
red = attn_output.mean(dim=-1, keepdim=True)
700698
return attn_output - red
701699

onnx_diagnostic/export/onnx_plug.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -150,20 +150,18 @@ def forward(self, x):
150150
def qwen_version_selector(opset: int, *args: torch.Tensor) -> Tuple[str, torch.dtype]:
151151
first_tensor = next(a for a in args if a is not None)
152152
dtype = first_tensor.dtype
153-
strategy = patched_Qwen2_5_VLVisionAttention.STRATEGY_FOR_ATTENTION()
154-
if strategy is not None:
155-
return strategy, dtype
153+
itype = torch_dtype_to_onnx_dtype(dtype)
156154
if dtype == torch.float32:
157155
if opset >= 24:
158-
return "LOOPA24", dtype
159-
return "LOOPMHA", dtype
156+
return "LOOPA24", itype
157+
return "LOOPMHA", itype
160158
if dtype == torch.float16:
161159
if first_tensor.is_cuda:
162-
return "PACKED", dtype
163-
return "LOOPMHA", dtype
160+
return "PACKED", itype
161+
return "LOOPMHA", itype
164162
raise AssertionError(
165-
f"Unable to handle type {torch.dtype} on "
166-
f"device {torch.device} with opset={opset}"
163+
f"Unable to handle type {torch.dtype} (itype={itype}) "
164+
f"on device {torch.device} with opset={opset}"
167165
)
168166
169167
qwen_sdpa_attention_versatile = EagerDirectReplacementWithOnnx(
@@ -338,6 +336,8 @@ def _register(self):
338336
input_args.append(f"int {p}={val}")
339337
elif isinstance(val, float):
340338
input_args.append(f"float {p}={val}")
339+
elif isinstance(val, str):
340+
input_args.append(f"str {p}={val}")
341341
else:
342342
raise NotImplementedError(
343343
f"kwargs {p!r} has a default value of unsupported type {type(val)}"
@@ -445,7 +445,7 @@ def converter(
445445
*args,
446446
**kwargs,
447447
) -> Any:
448-
has_devices = [a for a in args if g.has_device(a)]
448+
has_devices = [a for a in args if isinstance(a, str) and g.has_device(a)]
449449
assert (
450450
has_devices
451451
), f"Missing device for any of the inputs {args}{g.get_debug_msg()}"

onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import torch
66
import torch.nn.functional as F
77
from ...export.onnx_plug import EagerDirectReplacementWithOnnx
8+
from ...helpers.torch_helper import torch_dtype_to_onnx_dtype
89
from .patch_helper import _is_torchdynamo_exporting
910
from ._patch_transformers_attention import patched_sdpa_attention_forward
1011

@@ -225,18 +226,20 @@ def qwen_version_selector(opset: int, *args: torch.Tensor) -> Tuple[str, torch.d
225226
first_tensor = next(a for a in args if a is not None)
226227
dtype = first_tensor.dtype
227228
strategy = patched_Qwen2_5_VLVisionAttention.STRATEGY_FOR_ATTENTION()
229+
itype = torch_dtype_to_onnx_dtype(dtype)
228230
if strategy is not None:
229-
return strategy, dtype
231+
return strategy, itype
230232
if dtype == torch.float32:
231233
if opset >= 24:
232-
return "LOOPA24", dtype
233-
return "LOOPMHA", dtype
234+
return "LOOPA24", itype
235+
return "LOOPMHA", itype
234236
if dtype == torch.float16:
235237
if first_tensor.is_cuda:
236-
return "PACKED", dtype
237-
return "LOOPMHA", dtype
238+
return "PACKED", itype
239+
return "LOOPMHA", itype
238240
raise AssertionError(
239-
f"Unable to handle type {torch.dtype} on device {torch.device} with opset={opset}"
241+
f"Unable to handle type {torch.dtype} (itype={itype}) "
242+
f"on device {torch.device} with opset={opset}"
240243
)
241244

242245
qwen_sdpa_attention_versatile = EagerDirectReplacementWithOnnx(
@@ -558,9 +561,7 @@ class patched_Qwen2_5_VLVisionAttention:
558561
_PATCHED_CLASS_ = (
559562
transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLVisionAttention
560563
)
561-
STRATEGY_FOR_ATTENTION = lambda: os.environ.get( # noqa: E731
562-
"QWEN25ATTENTION", "PACKED"
563-
)
564+
STRATEGY_FOR_ATTENTION = lambda: os.environ.get("QWEN25ATTENTION", None) # noqa: E731
564565

565566
def forward(
566567
self,

0 commit comments

Comments
 (0)