|
1 | 1 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union |
2 | 2 | import onnx |
3 | 3 | import numpy as np |
4 | | -import numpy.typing as npt |
5 | 4 | import torch |
6 | 5 | from torch._C import _from_dlpack |
7 | 6 | import onnxruntime |
|
16 | 15 |
|
17 | 16 |
|
18 | 17 | DEVICES = {-1: ORTC.OrtDevice(ORTC.OrtDevice.cpu(), ORTC.OrtDevice.default_memory(), 0)} |
| 18 | +TensorLike = Union[np.ndarray, torch.Tensor] |
19 | 19 |
|
20 | 20 |
|
21 | 21 | class _InferenceSession: |
@@ -243,16 +243,16 @@ def __init__( |
243 | 243 | ) |
244 | 244 |
|
245 | 245 | def run( |
246 | | - self, output_names: Optional[List[str]], feeds: Dict[str, npt.ArrayLike] |
247 | | - ) -> List[Optional[npt.ArrayLike]]: |
| 246 | + self, output_names: Optional[List[str]], feeds: Dict[str, TensorLike] |
| 247 | + ) -> List[Optional[TensorLike]]: |
248 | 248 | """Calls :meth:`onnxruntime.InferenceSession.run`.""" |
249 | 249 | # sess.run does not support blfoat16 |
250 | 250 | # res = self.sess.run(output_names, feeds) |
251 | 251 | return self._post_process_inplace(list(self.run_dlpack(output_names, feeds))) |
252 | 252 |
|
253 | 253 | def run_dlpack( |
254 | | - self, output_names: Optional[List[str]], feeds: Dict[str, npt.ArrayLike] |
255 | | - ) -> Tuple[Optional[npt.ArrayLike], ...]: |
| 254 | + self, output_names: Optional[List[str]], feeds: Dict[str, TensorLike] |
| 255 | + ) -> Tuple[Optional[TensorLike], ...]: |
256 | 256 | """ |
257 | 257 | Same as :meth:`onnxruntime.InferenceSession.run` except that |
258 | 258 | feeds is a dictionary of :class:`np.ndarray`. |
@@ -289,13 +289,13 @@ def run_dlpack( |
289 | 289 | def _ortvalues_to_numpy_tensor( |
290 | 290 | self, |
291 | 291 | ortvalues: Union[List[ORTC.OrtValue], ORTC.OrtValueVector], |
292 | | - ) -> Tuple[Optional[npt.ArrayLike], ...]: |
| 292 | + ) -> Tuple[Optional[TensorLike], ...]: |
293 | 293 | if len(ortvalues) == 0: |
294 | 294 | return tuple() |
295 | 295 |
|
296 | 296 | if self.nvtx: |
297 | 297 | self.torch.cuda.nvtx.range_push("_ortvalues_to_numpy_tensor") |
298 | | - res: List[Optional[npt.ArrayLike]] = [] # noqa: F823 |
| 298 | + res: List[Optional[TensorLike]] = [] # noqa: F823 |
299 | 299 | for i in range(len(ortvalues)): |
300 | 300 | if not ortvalues[i].has_value(): |
301 | 301 | res.append(None) |
@@ -556,7 +556,7 @@ def investigate_onnxruntime_issue( |
556 | 556 | Union[str, Callable[[onnx.ModelProto], onnxruntime.InferenceSession]] |
557 | 557 | ] = None, |
558 | 558 | # if model needs to be run. |
559 | | - feeds: Optional[Union[Dict[str, torch.Tensor], Dict[str, npt.ArrayLike]]] = None, |
| 559 | + feeds: Optional[Union[Dict[str, torch.Tensor], Dict[str, TensorLike]]] = None, |
560 | 560 | verbose: int = 0, |
561 | 561 | dump_filename: Optional[str] = None, |
562 | 562 | infer_shapes: bool = True, |
|
0 commit comments