Skip to content

Commit afe3c68

Browse files
committed
add device
1 parent d3eb796 commit afe3c68

3 files changed

Lines changed: 132 additions & 135 deletions

File tree

onnx_diagnostic/export/onnx_plug.py

Lines changed: 86 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,11 @@
44
import onnx
55
import torch
66
from ..helpers import max_diff, string_type
7-
from ..helpers.torch_helper import torch_dtype_to_onnx_dtype
7+
from ..helpers.torch_helper import (
8+
torch_dtype_to_onnx_dtype,
9+
onnx_dtype_to_torch_dtype,
10+
int_device_to_torch_device,
11+
)
812
from ..reference import OnnxruntimeEvaluator
913

1014
TUPLE_TENSORS = Tuple[torch.Tensor, ...]
@@ -50,7 +54,10 @@ class EagerDirectReplacementWithOnnx:
5054
only tensors must be counted
5155
:param name: the name of the custom op, the function name if not specified
5256
:param kwargs: constants parameters with their default values
53-
:param version_selector: selects the version based on the arguments
57+
:param version_selector: selects the version based on the arguments,
58+
see below for an example, this allows the user to define different
59+
onnx version depending on the inputs
60+
:param default_opset: opset to use by default
5461
:param verbose: verbose level
5562
5663
Here is an example:
@@ -134,6 +141,60 @@ def forward(self, x):
134141
).model_proto
135142
136143
print(pretty_onnx(onx))
144+
145+
This shows how to define multiple versions depending on the device,
146+
the type or the targetted onnx opset.
147+
148+
.. code-block:: python
149+
150+
def qwen_version_selector(opset: int, *args: torch.Tensor) -> Tuple[str, torch.dtype]:
151+
first_tensor = next(a for a in args if a is not None)
152+
dtype = first_tensor.dtype
153+
strategy = patched_Qwen2_5_VLVisionAttention.STRATEGY_FOR_ATTENTION()
154+
if strategy is not None:
155+
return strategy, dtype
156+
if dtype == torch.float32:
157+
if opset >= 24:
158+
return "LOOPA24", dtype
159+
return "LOOPMHA", dtype
160+
if dtype == torch.float16:
161+
if first_tensor.is_cuda:
162+
return "PACKED", dtype
163+
return "LOOPMHA", dtype
164+
raise AssertionError(
165+
f"Unable to handle type {torch.dtype} on "
166+
f"device {torch.device} with opset={opset}"
167+
)
168+
169+
qwen_sdpa_attention_versatile = EagerDirectReplacementWithOnnx(
170+
qwen_sdpa_attention,
171+
lambda qs, *args, **kwargs: torch.empty(
172+
(qs.shape[0], qs.shape[2], qs.shape[1], qs.shape[3]),
173+
dtype=qs.dtype,
174+
device=qs.device,
175+
),
176+
{
177+
("PACKED", onnx.TensorProto.FLOAT16): _add_com_microsoft_opset(
178+
PackedAttention.to_function_proto()
179+
),
180+
("LOOPA24", onnx.TensorProto.FLOAT): LoopAttention24.to_function_proto(),
181+
("LOOPA24", onnx.TensorProto.FLOAT16): _update_sequence_type(
182+
onnx.TensorProto.FLOAT16, LoopAttention24.to_function_proto()
183+
),
184+
("LOOPMHA", onnx.TensorProto.FLOAT): _add_com_microsoft_opset(
185+
LoopMHAAttention.to_function_proto()
186+
),
187+
("LOOPMHA", onnx.TensorProto.FLOAT16): _update_sequence_type(
188+
onnx.TensorProto.FLOAT16,
189+
_add_com_microsoft_opset(LoopMHAAttention.to_function_proto()),
190+
),
191+
},
192+
n_inputs=4,
193+
n_outputs=1,
194+
kwargs=dict(scaling=0.11180339887498948, num_heads=16),
195+
name="qwen_sdpa_attention_versatile",
196+
version_selector=qwen_version_selector,
197+
)
137198
"""
138199

139200
def __init__(
@@ -146,7 +207,8 @@ def __init__(
146207
name: Optional[str] = None,
147208
kwargs: Optional[Dict[str, Union[int, float]]] = None,
148209
verbose: int = 0,
149-
version_selector: Optional[Callable[[Any], Any]] = None,
210+
version_selector: Optional[Callable[..., Tuple[Any, ...]]] = None,
211+
default_opset: int = 22,
150212
):
151213
assert isinstance(function_proto, onnx.FunctionProto) or (
152214
isinstance(function_proto, dict)
@@ -183,6 +245,7 @@ def __init__(
183245
self.verbose = verbose
184246
self.custom_op = self._register()
185247
self.version_selector = version_selector
248+
self.default_opset = default_opset
186249
self._check_protos(params)
187250

188251
def _check_protos(self, params):
@@ -221,21 +284,18 @@ def _check_protos(self, params):
221284
not self._function_proto_versioned or self.version_selector
222285
), "version_selector is needed when multiple protos are given."
223286

224-
def get_function_proto(self, *args) -> onnx.FunctionProto:
287+
def get_function_proto(self, opset: int, *args) -> onnx.FunctionProto:
225288
"""Returns the correct version based on the inputs."""
226289
if self._function_proto:
227290
return self._function_proto
228-
if (
229-
len(args) == 1
230-
and isinstance(args[0], (int, str))
231-
and args[0] in self._function_proto_versioned
232-
):
233-
return self._function_proto_versioned[args[0]]
291+
assert isinstance(
292+
opset, int
293+
), f"The first argument must be an integer for the onnx opset but it is {type(opset)}"
234294
assert any(
235295
a is not None for a in args
236296
), f"Unexpected args={string_type(args, with_shape=True)}"
237297
try:
238-
key = self.version_selector(*args) # type: ignore[misc]
298+
key = self.version_selector(opset, *args) # type: ignore[misc]
239299
except (ValueError, AttributeError) as e:
240300
raise AssertionError(
241301
f"Unable to select a version, fails to get a key, available="
@@ -302,6 +362,7 @@ def verify(
302362
*args,
303363
engine: Optional[Callable] = None,
304364
dump_onnx_model: Optional[str] = None,
365+
opset: int = 22,
305366
**kwargs,
306367
) -> VerifyResult:
307368
"""
@@ -316,6 +377,7 @@ def verify(
316377
:class:`onnx_diagnostic.reference.OnnxruntimeEvaluator`.
317378
:param dump_onnx_model: to dump the onnx model used to verify
318379
eager and onnx produce the same results
380+
:param opset: onnx opset to use
319381
:param kwargs: additional arguments to the function
320382
:return: outputs of :func:`onnx_diagnostic.helpers.max_diff`
321383
"""
@@ -350,7 +412,7 @@ def verify(
350412
assert engine is None, f"Not implemented yet with engine={engine!r}"
351413
ags, kws = self._make_args_kwargs(*args, **kwargs)
352414
sess = OnnxruntimeEvaluator(
353-
self.get_function_proto(*args),
415+
self.get_function_proto(opset, *args),
354416
whole=True,
355417
dump_onnx_model=dump_onnx_model,
356418
function_kwargs=kws,
@@ -383,7 +445,17 @@ def converter(
383445
*args,
384446
**kwargs,
385447
) -> Any:
386-
function_proto = self.get_function_proto(g.get_type(args[0]))
448+
has_devices = [a for a in args if g.has_device(a)]
449+
assert (
450+
has_devices
451+
), f"Missing device for any of the inputs {args}{g.get_debug_msg()}"
452+
arg_device = has_devices[0]
453+
fake_tensor = torch.empty(
454+
tuple([(_ if isinstance(_, int) else 2) for _ in g.get_shape(args[0])]),
455+
dtype=onnx_dtype_to_torch_dtype(g.get_type(args[0])),
456+
device=int_device_to_torch_device(g.get_device(arg_device)),
457+
)
458+
function_proto = self.get_function_proto(g.main_opset, fake_tensor)
387459
if not g.has_local_function(function_proto.name, domain=function_proto.domain):
388460
g.add_function(function_proto)
389461
ags, kws = self._make_args_kwargs(*args, **kwargs)
@@ -417,7 +489,7 @@ def onnx_dynamo_converter(self) -> Callable:
417489
onnx_plug_op = onnxscript.values.Opset(domain=self.domain, version=1)
418490

419491
def get_proto(*args):
420-
function_proto = self.get_function_proto(*args)
492+
function_proto = self.get_function_proto(self.default_opset, *args)
421493
schema = onnx_plug_op[function_proto.name]
422494
if schema is None:
423495
all_types = [

onnx_diagnostic/helpers/torch_helper.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1103,3 +1103,13 @@ def study_discrepancies(
11031103
if name:
11041104
fig.savefig(name)
11051105
return ax
1106+
1107+
1108+
def int_device_to_torch_device(device_id: int) -> torch.device:
1109+
"""
1110+
Converts a device defined as an integer (coming from :meth:`torch.get_device`)
1111+
into a ``torch.device``.
1112+
"""
1113+
if device_id < 0:
1114+
return torch.device("cpu")
1115+
return torch.device("cuda", device_id)

0 commit comments

Comments
 (0)