Skip to content

Commit cf4b8a2

Browse files
committed
Support lists with OnnxruntimeEvaluator
1 parent 9360a96 commit cf4b8a2

4 files changed

Lines changed: 228 additions & 11 deletions

File tree

_unittests/ut_helpers/test_onnx_helper.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -402,7 +402,9 @@ def test_enumerate_results_loop(self):
402402
new_axis=0,
403403
),
404404
],
405-
)
405+
),
406+
ir_version=10,
407+
opset_imports=[oh.make_opsetid("", 22)],
406408
)
407409
res = list(enumerate_results(model, "slice_start", verbose=2))
408410
self.assertEqual(len(res), 2)

_unittests/ut_reference/test_onnxruntime_evaluator.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121

2222
TFLOAT = onnx.TensorProto.FLOAT
23+
TINT64 = onnx.TensorProto.INT64
2324

2425

2526
class TestOnnxruntimeEvaluator(ExtTestCase):
@@ -319,6 +320,122 @@ def test_function_proto_with_kwargs(self):
319320
got = sess.run(None, feeds)
320321
self.assertEqualArray(expected, got[0], atol=1e-5)
321322

323+
@hide_stdout()
324+
def test_ort_eval_loop_seq(self):
325+
x = np.array([1, 2, 3, 4, 5]).astype(np.float32)
326+
_mkv_ = oh.make_tensor_value_info
327+
model = oh.make_model(
328+
graph=oh.make_graph(
329+
name="loop_test",
330+
inputs=[
331+
oh.make_tensor_value_info("trip_count", TINT64, ["a"]),
332+
oh.make_tensor_value_info("cond", onnx.TensorProto.BOOL, []),
333+
],
334+
outputs=[oh.make_tensor_value_info("res", TFLOAT, [])],
335+
nodes=[
336+
oh.make_node("SequenceEmpty", [], ["seq_empty"], dtype=TFLOAT),
337+
oh.make_node(
338+
"Loop",
339+
inputs=["trip_count", "cond", "seq_empty"],
340+
outputs=["seq_res"],
341+
body=oh.make_graph(
342+
[
343+
oh.make_node(
344+
"Identity", inputs=["cond_in"], outputs=["cond_out"]
345+
),
346+
oh.make_node(
347+
"Constant",
348+
inputs=[],
349+
outputs=["x"],
350+
value=oh.make_tensor(
351+
name="const_tensor_x",
352+
data_type=TFLOAT,
353+
dims=x.shape,
354+
vals=x.flatten().astype(float),
355+
),
356+
),
357+
oh.make_node(
358+
"Constant",
359+
inputs=[],
360+
outputs=["one"],
361+
value=oh.make_tensor(
362+
name="const_tensor_one",
363+
data_type=TINT64,
364+
dims=(),
365+
vals=[1],
366+
),
367+
),
368+
oh.make_node(
369+
"Constant",
370+
inputs=[],
371+
outputs=["slice_start"],
372+
value=oh.make_tensor(
373+
name="const_tensor_zero",
374+
data_type=TINT64,
375+
dims=(1,),
376+
vals=[0],
377+
),
378+
),
379+
oh.make_node(
380+
"Add", inputs=["iter_count", "one"], outputs=["end"]
381+
),
382+
oh.make_node(
383+
"Constant",
384+
inputs=[],
385+
outputs=["axes"],
386+
value=oh.make_tensor(
387+
name="const_tensor_axes",
388+
data_type=TINT64,
389+
dims=(1,),
390+
vals=[0],
391+
),
392+
),
393+
oh.make_node(
394+
"Unsqueeze", inputs=["end", "axes"], outputs=["slice_end"]
395+
),
396+
oh.make_node(
397+
"Slice",
398+
inputs=["x", "slice_start", "slice_end"],
399+
outputs=["slice_out"],
400+
),
401+
oh.make_node(
402+
"SequenceInsert",
403+
inputs=["seq_in", "slice_out"],
404+
outputs=["seq_out"],
405+
),
406+
],
407+
"loop_body",
408+
[
409+
_mkv_("iter_count", TINT64, []),
410+
_mkv_("cond_in", onnx.TensorProto.BOOL, []),
411+
oh.make_tensor_sequence_value_info("seq_in", TFLOAT, None),
412+
],
413+
[
414+
_mkv_("cond_out", onnx.TensorProto.BOOL, []),
415+
oh.make_tensor_sequence_value_info("seq_out", TFLOAT, None),
416+
],
417+
),
418+
),
419+
oh.make_node(
420+
"ConcatFromSequence",
421+
inputs=["seq_res"],
422+
outputs=["res"],
423+
axis=0,
424+
new_axis=0,
425+
),
426+
],
427+
),
428+
ir_version=10,
429+
opset_imports=[oh.make_opsetid("", 22)],
430+
)
431+
ev = OnnxruntimeEvaluator(model, verbose=10)
432+
feeds = dict(trip_count=torch.tensor([3], dtype=torch.int64), cond=torch.tensor(True))
433+
got = ev.run(None, feeds)
434+
self.assertEqual((6,), got[0].shape)
435+
self.assertEqualArray(
436+
np.array([1.0, 1.0, 2.0, 1.0, 2.0, 3.0], dtype=np.float32), got[0]
437+
)
438+
322439

323440
if __name__ == "__main__":
324441
unittest.main(verbosity=2)

_unittests/ut_reference/test_torch_onnx_evaluator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1123,7 +1123,9 @@ def test_loop(self):
11231123
new_axis=0,
11241124
),
11251125
],
1126-
)
1126+
),
1127+
ir_version=10,
1128+
opset_imports=[oh.make_opsetid("", 22)],
11271129
)
11281130
self._finalize_test(
11291131
model, torch.tensor(5, dtype=torch.int64), torch.tensor(1, dtype=torch.bool)

onnx_diagnostic/reference/ort_evaluator.py

Lines changed: 105 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
FunctionProto,
77
ModelProto,
88
NodeProto,
9+
TensorProto,
910
TypeProto,
1011
ValueInfoProto,
1112
helper as oh,
@@ -16,7 +17,13 @@
1617
from onnx.defs import onnx_opset_version
1718
import onnxruntime
1819
from ..helpers import string_type
19-
from ..helpers.onnx_helper import pretty_onnx, dtype_to_tensor_dtype, to_array_extended
20+
from ..helpers.onnx_helper import (
21+
pretty_onnx,
22+
dtype_to_tensor_dtype,
23+
to_array_extended,
24+
np_dtype_to_tensor_dtype,
25+
)
26+
from ..helpers.torch_helper import onnx_dtype_to_torch_dtype
2027
from ..helpers.ort_session import (
2128
InferenceSessionForTorch,
2229
InferenceSessionForNumpy,
@@ -31,6 +38,34 @@
3138
Proto = Union[FunctionProto, ModelProto, GraphProto, NodeProto]
3239

3340

41+
class OnnxList(list):
42+
"""Defines a list for the runtime."""
43+
44+
def __init__(self, itype: Union[list, int]):
45+
super().__init__()
46+
if isinstance(itype, int):
47+
self.itype = itype
48+
self.dtype = onnx_dtype_to_torch_dtype(itype)
49+
else:
50+
assert itype, "The list cannot be created with an empty list."
51+
self.itype = np_dtype_to_tensor_dtype(itype[0].dtype)
52+
self.extend(itype)
53+
self.dtype = itype[0].dtype
54+
self.shape = "OnnxList"
55+
56+
def get_device(self):
57+
assert len(self) > 0, "Cannot access the device for an empty list."
58+
return self[0].get_device() if hasattr(self[0], "get_device") else -1
59+
60+
def numpy(self):
61+
if all(isinstance(v, np.ndarray) for v in self):
62+
return self
63+
res = OnnxList()
64+
for v in self:
65+
res.append(v.detach().cpu().numpy())
66+
return res
67+
68+
3469
class OnnxruntimeEvaluator:
3570
"""
3671
This class loads an onnx model and the executes one by one the nodes
@@ -209,6 +244,8 @@ def output_types(self) -> List[TypeProto]:
209244
def _log_arg(self, a: Any) -> Any:
210245
if isinstance(a, (str, int, float)):
211246
return a
247+
if isinstance(a, OnnxList):
248+
return f"#{len(a)}[]"
212249
device = f"D{a.get_device()}:" if hasattr(a, "detach") else ""
213250
if hasattr(a, "shape"):
214251
prefix = "A:" if hasattr(a, "astype") else "T:"
@@ -231,6 +268,12 @@ def _log(self, level: int, pattern: str, *args: Any) -> None:
231268
def _is_local_function(self, node: NodeProto) -> bool:
232269
return (node.domain, node.op_type) in self.local_functions
233270

271+
def _run_init(self, feed_inputs):
272+
if self.sess_ is None:
273+
assert self.proto, "self.proto is empty"
274+
_, self.sess_ = self._get_sess(self.proto, list(feed_inputs.values()))
275+
return self.sess_
276+
234277
def run(
235278
self,
236279
outputs: Optional[List[str]],
@@ -254,9 +297,7 @@ def run(
254297
"""
255298
if self.rt_nodes_ is None:
256299
# runs a whole
257-
if self.sess_ is None:
258-
assert self.proto, "self.proto is empty"
259-
_, self.sess_ = self._get_sess(self.proto, list(feed_inputs.values()))
300+
self._run_init(feed_inputs)
260301
assert self.sess_, "mypy not happy"
261302
return self.sess_.run(outputs, feed_inputs)
262303
if outputs is None:
@@ -283,7 +324,7 @@ def run(
283324
if node.op_type == "If" and node.domain == "":
284325
outputs = self._run_if(node, inputs, results)
285326
elif node.op_type in {"Scan", "Loop"} and node.domain == "":
286-
outputs = self._run_scan(node, inputs, results)
327+
outputs = self._run_scan_or_loop(node, inputs, results)
287328
elif self._is_local_function(node):
288329
outputs = self._run_local(node, inputs, results)
289330
else:
@@ -472,6 +513,18 @@ def _get_sess(
472513
)
473514
]
474515
prenodes = [] # type: ignore[var-annotated]
516+
elif node.op_type == "ConcatFromSequence" and node.domain == "":
517+
# We force the type to be a boolean.
518+
vinputs = [
519+
oh.make_value_info(
520+
node.input[0],
521+
type_proto=oh.make_sequence_type_proto(
522+
oh.make_tensor_type_proto(elem_type=inputs[0].itype, shape=None)
523+
),
524+
)
525+
]
526+
voutputs = [oh.make_tensor_value_info(node.output[0], inputs[0].itype, None)]
527+
prenodes = [] # type: ignore[var-annotated]
475528
else:
476529
unique_names = set()
477530
vinputs = []
@@ -535,7 +588,17 @@ def _get_sess_init_subgraph(
535588
if i == "" or i in unique_names:
536589
continue
537590
unique_names.add(i)
538-
value = oh.make_tensor_value_info(i, dtype_to_tensor_dtype(it.dtype), it.shape)
591+
if isinstance(it, OnnxList):
592+
value = oh.make_value_info(
593+
i,
594+
type_proto=oh.make_sequence_type_proto(
595+
oh.make_tensor_type_proto(
596+
elem_type=dtype_to_tensor_dtype(it.dtype), shape=None
597+
)
598+
),
599+
)
600+
else:
601+
value = oh.make_tensor_value_info(i, dtype_to_tensor_dtype(it.dtype), it.shape)
539602
vinputs.append(value)
540603

541604
reduced_set = self._get_hidden_inputs(g)
@@ -592,6 +655,14 @@ def _get_sess_local(
592655

593656
def _run(self, node: NodeProto, inputs: List[Any], results: Dict[str, Any]) -> List[Any]:
594657
"""Runs a node."""
658+
if node.op_type[0] == "S":
659+
if node.op_type == "SequenceEmpty":
660+
dtype = TensorProto.FLOAT
661+
for att in node.attribute:
662+
if att.name == "dtype":
663+
dtype = att.i
664+
return [OnnxList(itype=dtype)]
665+
595666
types = [(None if a is None else (a.dtype, a.shape)) for a in inputs]
596667
key = (id(node), *types)
597668
if key in self._cache:
@@ -609,6 +680,12 @@ def _run(self, node: NodeProto, inputs: List[Any], results: Dict[str, Any]) -> L
609680
continue
610681
feeds[i] = val
611682
assert hasattr(sess, "run"), f"Missing method run for type {type(sess)}"
683+
684+
if node.op_type[0] == "C":
685+
if node.op_type == "ConcatFromSequence":
686+
# sess._run_init(feeds )
687+
return list(sess.sess.run(None, self.feeds_to_numpy(feeds)))
688+
612689
outputs = list(sess.run(None, feeds))
613690
assert isinstance(outputs, list), f"Unexpected type for outputs {type(outputs)}"
614691
return outputs
@@ -636,7 +713,7 @@ def _run_if(
636713
assert isinstance(outputs, list), f"Unexpected type for outputs {type(outputs)}"
637714
return outputs
638715

639-
def _get_sess_scan(
716+
def _get_sess_scan_or_loop(
640717
self, node: NodeProto, branch: str, inputs: List[Any], context: Dict[str, Any]
641718
) -> Tuple[ModelProto, "OnnxruntimeEvaluator"]:
642719
g = None
@@ -671,7 +748,18 @@ def _get_sess_scan(
671748
)
672749
return onx, sess
673750

674-
def _run_scan(
751+
def feeds_to_numpy(self, feeds):
752+
new_feeds = {}
753+
for k, v in feeds.items():
754+
if hasattr(v, "detach"):
755+
new_feeds[k] = v.detach().cpu().numpy()
756+
elif isinstance(v, OnnxList):
757+
new_feeds[k] = v.numpy()
758+
else:
759+
new_feeds[k] = v
760+
return new_feeds
761+
762+
def _run_scan_or_loop(
675763
self, node: NodeProto, inputs: List[Any], results: Dict[str, Any]
676764
) -> List[Any]:
677765
"""Runs a node Scan."""
@@ -682,10 +770,18 @@ def _run_scan(
682770
if key in self._cache:
683771
sess = self._cache[key][1]
684772
else:
685-
self._cache[key] = _onx, sess = self._get_sess_scan(node, name, inputs, results)
773+
self._cache[key] = _onx, sess = self._get_sess_scan_or_loop(
774+
node, name, inputs, results
775+
)
686776

687777
assert hasattr(sess, "run"), f"Missing method run for type {type(sess)}"
688778
feeds = {name: results[name] for name in sess.input_names}
779+
if node.op_type == "Loop" and any(isinstance(v, OnnxList) for v in feeds.values()):
780+
# This operator uses sequence. onnxruntime does not play well with sequence.
781+
sess._run_init(feeds)
782+
outputs = sess.sess_.sess.run(None, self.feeds_to_numpy(feeds))
783+
return [(OnnxList(v) if isinstance(v, list) else v) for v in outputs]
784+
689785
outputs = sess.run(None, feeds)
690786
assert isinstance(outputs, list), f"Unexpected type for outputs {type(outputs)}"
691787
return outputs

0 commit comments

Comments
 (0)