Skip to content

Commit 8f74903

Browse files
committed
doc
1 parent 9c83ab3 commit 8f74903

4 files changed

Lines changed: 15 additions & 14 deletions

File tree

onnx_diagnostic/export/onnx_plug.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def forward(self, x):
143143
print(pretty_onnx(onx))
144144
145145
This shows how to define multiple versions depending on the device,
146-
the type or the targetted onnx opset.
146+
the type or the targeted onnx opset.
147147
148148
.. code-block:: python
149149

onnx_diagnostic/helpers/onnx_helper.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
Union,
1818
)
1919
import numpy as np
20-
import numpy.typing as npt
2120
import onnx
2221
import onnx.helper as oh
2322
import onnx.numpy_helper as onh
@@ -33,6 +32,8 @@
3332
load as onnx_load,
3433
)
3534

35+
TensorLike = Union[np.ndarray, "torch.Tensor"] # noqa: F821
36+
3637

3738
def _make_stat(init: TensorProto) -> Dict[str, float]:
3839
"""
@@ -490,7 +491,7 @@ def convert_endian(tensor: TensorProto) -> None:
490491
tensor.raw_data = np.frombuffer(tensor.raw_data, dtype=np_dtype).byteswap().tobytes()
491492

492493

493-
def from_array_ml_dtypes(arr: npt.ArrayLike, name: Optional[str] = None) -> TensorProto:
494+
def from_array_ml_dtypes(arr: TensorLike, name: Optional[str] = None) -> TensorProto:
494495
"""
495496
Converts a numpy array to a tensor def assuming the dtype
496497
is defined in ml_dtypes.
@@ -536,7 +537,7 @@ def from_array_ml_dtypes(arr: npt.ArrayLike, name: Optional[str] = None) -> Tens
536537
}
537538

538539

539-
def from_array_extended(tensor: npt.ArrayLike, name: Optional[str] = None) -> TensorProto:
540+
def from_array_extended(tensor: TensorLike, name: Optional[str] = None) -> TensorProto:
540541
"""
541542
Converts an array into a :class:`onnx.TensorProto`.
542543
@@ -603,7 +604,7 @@ def from_array_extended(tensor: npt.ArrayLike, name: Optional[str] = None) -> Te
603604
return t
604605

605606

606-
def to_array_extended(proto: TensorProto) -> npt.ArrayLike:
607+
def to_array_extended(proto: TensorProto) -> TensorLike:
607608
"""Converts :class:`onnx.TensorProto` into a numpy array."""
608609
arr = onh.to_array(proto)
609610
if proto.data_type >= onnx.TensorProto.BFLOAT16:

onnx_diagnostic/helpers/ort_session.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
22
import onnx
33
import numpy as np
4-
import numpy.typing as npt
54
import torch
65
from torch._C import _from_dlpack
76
import onnxruntime
@@ -16,6 +15,7 @@
1615

1716

1817
DEVICES = {-1: ORTC.OrtDevice(ORTC.OrtDevice.cpu(), ORTC.OrtDevice.default_memory(), 0)}
18+
TensorLike = Union[np.ndarray, torch.Tensor]
1919

2020

2121
class _InferenceSession:
@@ -243,16 +243,16 @@ def __init__(
243243
)
244244

245245
def run(
246-
self, output_names: Optional[List[str]], feeds: Dict[str, npt.ArrayLike]
247-
) -> List[Optional[npt.ArrayLike]]:
246+
self, output_names: Optional[List[str]], feeds: Dict[str, TensorLike]
247+
) -> List[Optional[TensorLike]]:
248248
"""Calls :meth:`onnxruntime.InferenceSession.run`."""
249249
# sess.run does not support blfoat16
250250
# res = self.sess.run(output_names, feeds)
251251
return self._post_process_inplace(list(self.run_dlpack(output_names, feeds)))
252252

253253
def run_dlpack(
254-
self, output_names: Optional[List[str]], feeds: Dict[str, npt.ArrayLike]
255-
) -> Tuple[Optional[npt.ArrayLike], ...]:
254+
self, output_names: Optional[List[str]], feeds: Dict[str, TensorLike]
255+
) -> Tuple[Optional[TensorLike], ...]:
256256
"""
257257
Same as :meth:`onnxruntime.InferenceSession.run` except that
258258
feeds is a dictionary of :class:`np.ndarray`.
@@ -289,13 +289,13 @@ def run_dlpack(
289289
def _ortvalues_to_numpy_tensor(
290290
self,
291291
ortvalues: Union[List[ORTC.OrtValue], ORTC.OrtValueVector],
292-
) -> Tuple[Optional[npt.ArrayLike], ...]:
292+
) -> Tuple[Optional[TensorLike], ...]:
293293
if len(ortvalues) == 0:
294294
return tuple()
295295

296296
if self.nvtx:
297297
self.torch.cuda.nvtx.range_push("_ortvalues_to_numpy_tensor")
298-
res: List[Optional[npt.ArrayLike]] = [] # noqa: F823
298+
res: List[Optional[TensorLike]] = [] # noqa: F823
299299
for i in range(len(ortvalues)):
300300
if not ortvalues[i].has_value():
301301
res.append(None)
@@ -556,7 +556,7 @@ def investigate_onnxruntime_issue(
556556
Union[str, Callable[[onnx.ModelProto], onnxruntime.InferenceSession]]
557557
] = None,
558558
# if model needs to be run.
559-
feeds: Optional[Union[Dict[str, torch.Tensor], Dict[str, npt.ArrayLike]]] = None,
559+
feeds: Optional[Union[Dict[str, torch.Tensor], Dict[str, TensorLike]]] = None,
560560
verbose: int = 0,
561561
dump_filename: Optional[str] = None,
562562
infer_shapes: bool = True,

onnx_diagnostic/helpers/torch_helper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1107,7 +1107,7 @@ def study_discrepancies(
11071107

11081108
def int_device_to_torch_device(device_id: int) -> torch.device:
11091109
"""
1110-
Converts a device defined as an integer (coming from :meth:`torch.get_device`)
1110+
Converts a device defined as an integer (coming from :meth:`torch.Tensor.get_device`)
11111111
into a ``torch.device``.
11121112
"""
11131113
if device_id < 0:

0 commit comments

Comments
 (0)