Skip to content

Commit b8a5de3

Browse files
committed
keep type
1 parent cf4b8a2 commit b8a5de3

3 files changed

Lines changed: 38 additions & 13 deletions

File tree

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.8.3
55
+++++
66

7+
* :pr:`329`: supports lists with OnnxruntimeEvaluator
78
* :pr:`326`: use ConcatFromSequence in LoopMHA with the loop
89
* :pr:`325`: adds plug for LoopMHA, extends the unit tests to measure the discrepancies
910
* :pr:`324`: supports FunctionProto with arguments in OnnxruntimeEvaluator

_unittests/ut_reference/test_onnxruntime_evaluator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -433,8 +433,9 @@ def test_ort_eval_loop_seq(self):
433433
got = ev.run(None, feeds)
434434
self.assertEqual((6,), got[0].shape)
435435
self.assertEqualArray(
436-
np.array([1.0, 1.0, 2.0, 1.0, 2.0, 3.0], dtype=np.float32), got[0]
436+
torch.tensor([1.0, 1.0, 2.0, 1.0, 2.0, 3.0], dtype=torch.float32), got[0]
437437
)
438+
self.assertIsInstance(got[0], torch.Tensor)
438439

439440

440441
if __name__ == "__main__":

onnx_diagnostic/reference/ort_evaluator.py

Lines changed: 35 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
to_array_extended,
2424
np_dtype_to_tensor_dtype,
2525
)
26-
from ..helpers.torch_helper import onnx_dtype_to_torch_dtype
26+
from ..helpers.torch_helper import onnx_dtype_to_torch_dtype, torch_dtype_to_onnx_dtype
2727
from ..helpers.ort_session import (
2828
InferenceSessionForTorch,
2929
InferenceSessionForNumpy,
@@ -48,22 +48,38 @@ def __init__(self, itype: Union[list, int]):
4848
self.dtype = onnx_dtype_to_torch_dtype(itype)
4949
else:
5050
assert itype, "The list cannot be created with an empty list."
51-
self.itype = np_dtype_to_tensor_dtype(itype[0].dtype)
51+
self.itype = (
52+
np_dtype_to_tensor_dtype(itype[0].dtype)
53+
if isinstance(itype[0], np.ndarray)
54+
else torch_dtype_to_onnx_dtype(itype[0].dtype)
55+
)
5256
self.extend(itype)
5357
self.dtype = itype[0].dtype
5458
self.shape = "OnnxList"
5559

5660
def get_device(self):
61+
"Returns the device of the first tensor."
5762
assert len(self) > 0, "Cannot access the device for an empty list."
5863
return self[0].get_device() if hasattr(self[0], "get_device") else -1
5964

6065
def numpy(self):
66+
"Creates a new list with all tensors on numpy or self it is already the case."
6167
if all(isinstance(v, np.ndarray) for v in self):
6268
return self
63-
res = OnnxList()
64-
for v in self:
65-
res.append(v.detach().cpu().numpy())
66-
return res
69+
return OnnxList([v.detach().cpu().numpy() for v in self])
70+
71+
def to(self, tensor_like) -> "OnnxList":
72+
"Creates a new list with all tensors on numpy or pytorch depending on `tensor_like`."
73+
if isinstance(tensor_like, np.ndarray):
74+
return self
75+
import torch
76+
77+
return OnnxList(
78+
[
79+
torch.from_numpy(t).to(tensor_like.device) if isinstance(t, np.ndarray) else t
80+
for t in self
81+
]
82+
)
6783

6884

6985
class OnnxruntimeEvaluator:
@@ -245,7 +261,7 @@ def _log_arg(self, a: Any) -> Any:
245261
if isinstance(a, (str, int, float)):
246262
return a
247263
if isinstance(a, OnnxList):
248-
return f"#{len(a)}[]"
264+
return string_type(a)
249265
device = f"D{a.get_device()}:" if hasattr(a, "detach") else ""
250266
if hasattr(a, "shape"):
251267
prefix = "A:" if hasattr(a, "astype") else "T:"
@@ -683,8 +699,12 @@ def _run(self, node: NodeProto, inputs: List[Any], results: Dict[str, Any]) -> L
683699

684700
if node.op_type[0] == "C":
685701
if node.op_type == "ConcatFromSequence":
686-
# sess._run_init(feeds )
687-
return list(sess.sess.run(None, self.feeds_to_numpy(feeds)))
702+
res = sess.sess.run(None, self.feeds_to_numpy(feeds)) # type: ignore[union-attr]
703+
if isinstance(inputs[0][0], np.ndarray):
704+
return list(res)
705+
import torch
706+
707+
return [torch.from_numpy(r).to(inputs[0][0].device) for r in res]
688708

689709
outputs = list(sess.run(None, feeds))
690710
assert isinstance(outputs, list), f"Unexpected type for outputs {type(outputs)}"
@@ -778,9 +798,12 @@ def _run_scan_or_loop(
778798
feeds = {name: results[name] for name in sess.input_names}
779799
if node.op_type == "Loop" and any(isinstance(v, OnnxList) for v in feeds.values()):
780800
# This operator uses sequence. onnxruntime does not play well with sequence.
781-
sess._run_init(feeds)
782-
outputs = sess.sess_.sess.run(None, self.feeds_to_numpy(feeds))
783-
return [(OnnxList(v) if isinstance(v, list) else v) for v in outputs]
801+
sess._run_init(feeds) # type: ignore[union-attr]
802+
outputs = sess.sess_.sess.run(None, self.feeds_to_numpy(feeds)) # type: ignore[union-attr]
803+
return [
804+
(OnnxList(v).to(feeds[node.input[0]]) if isinstance(v, list) else v)
805+
for v in outputs
806+
]
784807

785808
outputs = sess.run(None, feeds)
786809
assert isinstance(outputs, list), f"Unexpected type for outputs {type(outputs)}"

0 commit comments

Comments
 (0)