Skip to content

Commit a9d2eb6

Browse files
committed
fix
1 parent 7c015dc commit a9d2eb6

1 file changed

Lines changed: 5 additions & 3 deletions

File tree

onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -262,12 +262,14 @@ def qwen_version_selector(opset: int, *args: torch.Tensor) -> Tuple[str, torch.d
262262
itype = torch_dtype_to_onnx_dtype(dtype)
263263
if strategy is not None:
264264
return strategy, itype
265-
if dtype == torch.float32:
265+
if dtype == torch.float32 or itype == onnx.TensorProto.FLOAT:
266266
if opset >= 24:
267267
return "LOOPA24", itype
268268
return "LOOPMHA", itype
269-
if dtype == torch.float16:
270-
if first_tensor.is_cuda:
269+
if dtype == torch.float16 or itype == onnx.TensorProto.FLOAT16:
270+
# first_tensor may be a SymbolicTensor (onnx).
271+
# is_cuda is not available.
272+
if hasattr(first_tensor, "is_cuda") and first_tensor.is_cuda:
271273
return "PACKED", itype
272274
return "LOOPMHA", itype
273275
raise AssertionError(

0 commit comments

Comments
 (0)