Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOGS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Change Logs
0.8.4
+++++

* :pr:`340`: supports devices in onnx plugs
* :pr:`338`: fixes ReplayConfiguration.dump, add function to select of part of a model
* :pr:`337`: fixes extract_subset_of_nodes
* :pr:`336`: implements versioned onnx plugs
Expand Down
7 changes: 4 additions & 3 deletions _unittests/ut_tasks/try_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ def test_qwen25_vli_visual(self):
.. code-block:: bash

NEVERTEST=1 \\
QWEN25ATTENTION=BIGMASK \\
PRETRAINED=1 \\
TESTDEVICE=cuda \\
TESTDTYPE=float16 \\
Expand Down Expand Up @@ -164,9 +163,11 @@ def _config_reduction(config, task):
if qwen25_attention:
attention_options = [qwen25_attention]
elif device == "cuda" and dtype in ("float16", "bfloat16"):
attention_options = ["PACKED", "BIGMASK"]
attention_options = [
"PACKED",
]
else:
attention_options = ["LOOPMHA", "LOOPA24", "BIGMASK"]
attention_options = ["LOOPMHA", "LOOPA24"]

# fake_inputs = make_fake_with_dynamic_dimensions(inputs, dynamic_shapes)[0]
for attention in attention_options:
Expand Down
6 changes: 2 additions & 4 deletions _unittests/ut_torch_onnx/test_sbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -682,7 +682,7 @@ def test_sbs_with_loops(self):
PLUGS_Qwen25,
)
from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
qwen_sdpa_attention_loopmha_versatile,
qwen_sdpa_attention_versatile,
)

class Model(torch.nn.Module):
Expand All @@ -693,9 +693,7 @@ def forward(self, query, key, value, seq_lens):
qs = query * mask
ks = key * mask
vs = value * mask
attn_output = qwen_sdpa_attention_loopmha_versatile(
qs, ks, vs, seq_lens, 0.11, 16
)
attn_output = qwen_sdpa_attention_versatile(qs, ks, vs, seq_lens, 0.11, 16)
red = attn_output.mean(dim=-1, keepdim=True)
return attn_output - red

Expand Down
100 changes: 86 additions & 14 deletions onnx_diagnostic/export/onnx_plug.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@
import onnx
import torch
from ..helpers import max_diff, string_type
from ..helpers.torch_helper import torch_dtype_to_onnx_dtype
from ..helpers.torch_helper import (
torch_dtype_to_onnx_dtype,
onnx_dtype_to_torch_dtype,
int_device_to_torch_device,
)
from ..reference import OnnxruntimeEvaluator

TUPLE_TENSORS = Tuple[torch.Tensor, ...]
Expand Down Expand Up @@ -50,7 +54,10 @@ class EagerDirectReplacementWithOnnx:
only tensors must be counted
:param name: the name of the custom op, the function name if not specified
:param kwargs: constants parameters with their default values
:param version_selector: selects the version based on the arguments
:param version_selector: selects the version based on the arguments,
see below for an example, this allows the user to define different
onnx version depending on the inputs
:param default_opset: opset to use by default
:param verbose: verbose level

Here is an example:
Expand Down Expand Up @@ -134,6 +141,58 @@ def forward(self, x):
).model_proto

print(pretty_onnx(onx))

This shows how to define multiple versions depending on the device,
the type or the targeted onnx opset.

.. code-block:: python

def qwen_version_selector(opset: int, *args: torch.Tensor) -> Tuple[str, torch.dtype]:
first_tensor = next(a for a in args if a is not None)
dtype = first_tensor.dtype
itype = torch_dtype_to_onnx_dtype(dtype)
if dtype == torch.float32:
if opset >= 24:
return "LOOPA24", itype
return "LOOPMHA", itype
if dtype == torch.float16:
if first_tensor.is_cuda:
return "PACKED", itype
return "LOOPMHA", itype
raise AssertionError(
f"Unable to handle type {torch.dtype} (itype={itype}) "
f"on device {torch.device} with opset={opset}"
)

qwen_sdpa_attention_versatile = EagerDirectReplacementWithOnnx(
qwen_sdpa_attention,
lambda qs, *args, **kwargs: torch.empty(
(qs.shape[0], qs.shape[2], qs.shape[1], qs.shape[3]),
dtype=qs.dtype,
device=qs.device,
),
{
("PACKED", onnx.TensorProto.FLOAT16): _add_com_microsoft_opset(
PackedAttention.to_function_proto()
),
("LOOPA24", onnx.TensorProto.FLOAT): LoopAttention24.to_function_proto(),
("LOOPA24", onnx.TensorProto.FLOAT16): _update_sequence_type(
onnx.TensorProto.FLOAT16, LoopAttention24.to_function_proto()
),
("LOOPMHA", onnx.TensorProto.FLOAT): _add_com_microsoft_opset(
LoopMHAAttention.to_function_proto()
),
("LOOPMHA", onnx.TensorProto.FLOAT16): _update_sequence_type(
onnx.TensorProto.FLOAT16,
_add_com_microsoft_opset(LoopMHAAttention.to_function_proto()),
),
},
n_inputs=4,
n_outputs=1,
kwargs=dict(scaling=0.11180339887498948, num_heads=16),
name="qwen_sdpa_attention_versatile",
version_selector=qwen_version_selector,
)
"""

def __init__(
Expand All @@ -146,7 +205,8 @@ def __init__(
name: Optional[str] = None,
kwargs: Optional[Dict[str, Union[int, float]]] = None,
verbose: int = 0,
version_selector: Optional[Callable[[Any], Any]] = None,
version_selector: Optional[Callable[..., Tuple[Any, ...]]] = None,
default_opset: int = 22,
):
assert isinstance(function_proto, onnx.FunctionProto) or (
isinstance(function_proto, dict)
Expand Down Expand Up @@ -183,6 +243,7 @@ def __init__(
self.verbose = verbose
self.custom_op = self._register()
self.version_selector = version_selector
self.default_opset = default_opset
self._check_protos(params)

def _check_protos(self, params):
Expand Down Expand Up @@ -221,21 +282,18 @@ def _check_protos(self, params):
not self._function_proto_versioned or self.version_selector
), "version_selector is needed when multiple protos are given."

def get_function_proto(self, *args) -> onnx.FunctionProto:
def get_function_proto(self, opset: int, *args) -> onnx.FunctionProto:
"""Returns the correct version based on the inputs."""
if self._function_proto:
return self._function_proto
if (
len(args) == 1
and isinstance(args[0], (int, str))
and args[0] in self._function_proto_versioned
):
return self._function_proto_versioned[args[0]]
assert isinstance(
opset, int
), f"The first argument must be an integer for the onnx opset but it is {type(opset)}"
assert any(
a is not None for a in args
), f"Unexpected args={string_type(args, with_shape=True)}"
try:
key = self.version_selector(*args) # type: ignore[misc]
key = self.version_selector(opset, *args) # type: ignore[misc]
except (ValueError, AttributeError) as e:
raise AssertionError(
f"Unable to select a version, fails to get a key, available="
Expand Down Expand Up @@ -278,6 +336,8 @@ def _register(self):
input_args.append(f"int {p}={val}")
elif isinstance(val, float):
input_args.append(f"float {p}={val}")
elif isinstance(val, str):
input_args.append(f"str {p}={val}")
else:
raise NotImplementedError(
f"kwargs {p!r} has a default value of unsupported type {type(val)}"
Expand All @@ -302,6 +362,7 @@ def verify(
*args,
engine: Optional[Callable] = None,
dump_onnx_model: Optional[str] = None,
opset: int = 22,
**kwargs,
) -> VerifyResult:
"""
Expand All @@ -316,6 +377,7 @@ def verify(
:class:`onnx_diagnostic.reference.OnnxruntimeEvaluator`.
:param dump_onnx_model: to dump the onnx model used to verify
eager and onnx produce the same results
:param opset: onnx opset to use
:param kwargs: additional arguments to the function
:return: outputs of :func:`onnx_diagnostic.helpers.max_diff`
"""
Expand Down Expand Up @@ -350,7 +412,7 @@ def verify(
assert engine is None, f"Not implemented yet with engine={engine!r}"
ags, kws = self._make_args_kwargs(*args, **kwargs)
sess = OnnxruntimeEvaluator(
self.get_function_proto(*args),
self.get_function_proto(opset, *args),
whole=True,
dump_onnx_model=dump_onnx_model,
function_kwargs=kws,
Expand Down Expand Up @@ -383,7 +445,17 @@ def converter(
*args,
**kwargs,
) -> Any:
function_proto = self.get_function_proto(g.get_type(args[0]))
has_devices = [a for a in args if isinstance(a, str) and g.has_device(a)]
assert (
has_devices
), f"Missing device for any of the inputs {args}{g.get_debug_msg()}"
arg_device = has_devices[0]
fake_tensor = torch.empty(
tuple([(_ if isinstance(_, int) else 2) for _ in g.get_shape(args[0])]),
dtype=onnx_dtype_to_torch_dtype(g.get_type(args[0])),
device=int_device_to_torch_device(g.get_device(arg_device)),
)
function_proto = self.get_function_proto(g.main_opset, fake_tensor)
if not g.has_local_function(function_proto.name, domain=function_proto.domain):
g.add_function(function_proto)
ags, kws = self._make_args_kwargs(*args, **kwargs)
Expand Down Expand Up @@ -417,7 +489,7 @@ def onnx_dynamo_converter(self) -> Callable:
onnx_plug_op = onnxscript.values.Opset(domain=self.domain, version=1)

def get_proto(*args):
function_proto = self.get_function_proto(*args)
function_proto = self.get_function_proto(self.default_opset, *args)
schema = onnx_plug_op[function_proto.name]
if schema is None:
all_types = [
Expand Down
9 changes: 5 additions & 4 deletions onnx_diagnostic/helpers/onnx_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
Union,
)
import numpy as np
import numpy.typing as npt
import onnx
import onnx.helper as oh
import onnx.numpy_helper as onh
Expand All @@ -33,6 +32,8 @@
load as onnx_load,
)

TensorLike = Union[np.ndarray, "torch.Tensor"] # noqa: F821


def _make_stat(init: TensorProto) -> Dict[str, float]:
"""
Expand Down Expand Up @@ -490,7 +491,7 @@ def convert_endian(tensor: TensorProto) -> None:
tensor.raw_data = np.frombuffer(tensor.raw_data, dtype=np_dtype).byteswap().tobytes()


def from_array_ml_dtypes(arr: npt.ArrayLike, name: Optional[str] = None) -> TensorProto:
def from_array_ml_dtypes(arr: TensorLike, name: Optional[str] = None) -> TensorProto:
"""
Converts a numpy array to a tensor def assuming the dtype
is defined in ml_dtypes.
Expand Down Expand Up @@ -536,7 +537,7 @@ def from_array_ml_dtypes(arr: npt.ArrayLike, name: Optional[str] = None) -> Tens
}


def from_array_extended(tensor: npt.ArrayLike, name: Optional[str] = None) -> TensorProto:
def from_array_extended(tensor: TensorLike, name: Optional[str] = None) -> TensorProto:
"""
Converts an array into a :class:`onnx.TensorProto`.

Expand Down Expand Up @@ -603,7 +604,7 @@ def from_array_extended(tensor: npt.ArrayLike, name: Optional[str] = None) -> Te
return t


def to_array_extended(proto: TensorProto) -> npt.ArrayLike:
def to_array_extended(proto: TensorProto) -> TensorLike:
"""Converts :class:`onnx.TensorProto` into a numpy array."""
arr = onh.to_array(proto)
if proto.data_type >= onnx.TensorProto.BFLOAT16:
Expand Down
16 changes: 8 additions & 8 deletions onnx_diagnostic/helpers/ort_session.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import onnx
import numpy as np
import numpy.typing as npt
import torch
from torch._C import _from_dlpack
import onnxruntime
Expand All @@ -16,6 +15,7 @@


DEVICES = {-1: ORTC.OrtDevice(ORTC.OrtDevice.cpu(), ORTC.OrtDevice.default_memory(), 0)}
TensorLike = Union[np.ndarray, torch.Tensor]


class _InferenceSession:
Expand Down Expand Up @@ -243,16 +243,16 @@ def __init__(
)

def run(
self, output_names: Optional[List[str]], feeds: Dict[str, npt.ArrayLike]
) -> List[Optional[npt.ArrayLike]]:
self, output_names: Optional[List[str]], feeds: Dict[str, TensorLike]
) -> List[Optional[TensorLike]]:
"""Calls :meth:`onnxruntime.InferenceSession.run`."""
# sess.run does not support blfoat16
# res = self.sess.run(output_names, feeds)
return self._post_process_inplace(list(self.run_dlpack(output_names, feeds)))

def run_dlpack(
self, output_names: Optional[List[str]], feeds: Dict[str, npt.ArrayLike]
) -> Tuple[Optional[npt.ArrayLike], ...]:
self, output_names: Optional[List[str]], feeds: Dict[str, TensorLike]
) -> Tuple[Optional[TensorLike], ...]:
"""
Same as :meth:`onnxruntime.InferenceSession.run` except that
feeds is a dictionary of :class:`np.ndarray`.
Expand Down Expand Up @@ -289,13 +289,13 @@ def run_dlpack(
def _ortvalues_to_numpy_tensor(
self,
ortvalues: Union[List[ORTC.OrtValue], ORTC.OrtValueVector],
) -> Tuple[Optional[npt.ArrayLike], ...]:
) -> Tuple[Optional[TensorLike], ...]:
if len(ortvalues) == 0:
return tuple()

if self.nvtx:
self.torch.cuda.nvtx.range_push("_ortvalues_to_numpy_tensor")
res: List[Optional[npt.ArrayLike]] = [] # noqa: F823
res: List[Optional[TensorLike]] = [] # noqa: F823
for i in range(len(ortvalues)):
if not ortvalues[i].has_value():
res.append(None)
Expand Down Expand Up @@ -556,7 +556,7 @@ def investigate_onnxruntime_issue(
Union[str, Callable[[onnx.ModelProto], onnxruntime.InferenceSession]]
] = None,
# if model needs to be run.
feeds: Optional[Union[Dict[str, torch.Tensor], Dict[str, npt.ArrayLike]]] = None,
feeds: Optional[Union[Dict[str, torch.Tensor], Dict[str, TensorLike]]] = None,
verbose: int = 0,
dump_filename: Optional[str] = None,
infer_shapes: bool = True,
Expand Down
10 changes: 10 additions & 0 deletions onnx_diagnostic/helpers/torch_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -1103,3 +1103,13 @@ def study_discrepancies(
if name:
fig.savefig(name)
return ax


def int_device_to_torch_device(device_id: int) -> torch.device:
"""
Converts a device defined as an integer (coming from :meth:`torch.Tensor.get_device`)
into a ``torch.device``.
"""
if device_id < 0:
return torch.device("cpu")
return torch.device("cuda", device_id)
Loading
Loading