66 FunctionProto ,
77 ModelProto ,
88 NodeProto ,
9+ TensorProto ,
910 TypeProto ,
1011 ValueInfoProto ,
1112 helper as oh ,
1617from onnx .defs import onnx_opset_version
1718import onnxruntime
1819from ..helpers import string_type
19- from ..helpers .onnx_helper import pretty_onnx , dtype_to_tensor_dtype , to_array_extended
20+ from ..helpers .onnx_helper import (
21+ pretty_onnx ,
22+ dtype_to_tensor_dtype ,
23+ to_array_extended ,
24+ np_dtype_to_tensor_dtype ,
25+ )
26+ from ..helpers .torch_helper import onnx_dtype_to_torch_dtype
2027from ..helpers .ort_session import (
2128 InferenceSessionForTorch ,
2229 InferenceSessionForNumpy ,
3138Proto = Union [FunctionProto , ModelProto , GraphProto , NodeProto ]
3239
3340
41+ class OnnxList (list ):
42+ """Defines a list for the runtime."""
43+
44+ def __init__ (self , itype : Union [list , int ]):
45+ super ().__init__ ()
46+ if isinstance (itype , int ):
47+ self .itype = itype
48+ self .dtype = onnx_dtype_to_torch_dtype (itype )
49+ else :
50+ assert itype , "The list cannot be created with an empty list."
51+ self .itype = np_dtype_to_tensor_dtype (itype [0 ].dtype )
52+ self .extend (itype )
53+ self .dtype = itype [0 ].dtype
54+ self .shape = "OnnxList"
55+
56+ def get_device (self ):
57+ assert len (self ) > 0 , "Cannot access the device for an empty list."
58+ return self [0 ].get_device () if hasattr (self [0 ], "get_device" ) else - 1
59+
60+ def numpy (self ):
61+ if all (isinstance (v , np .ndarray ) for v in self ):
62+ return self
63+ res = OnnxList ()
64+ for v in self :
65+ res .append (v .detach ().cpu ().numpy ())
66+ return res
67+
68+
3469class OnnxruntimeEvaluator :
3570 """
3671 This class loads an onnx model and the executes one by one the nodes
@@ -209,6 +244,8 @@ def output_types(self) -> List[TypeProto]:
209244 def _log_arg (self , a : Any ) -> Any :
210245 if isinstance (a , (str , int , float )):
211246 return a
247+ if isinstance (a , OnnxList ):
248+ return f"#{ len (a )} []"
212249 device = f"D{ a .get_device ()} :" if hasattr (a , "detach" ) else ""
213250 if hasattr (a , "shape" ):
214251 prefix = "A:" if hasattr (a , "astype" ) else "T:"
@@ -231,6 +268,12 @@ def _log(self, level: int, pattern: str, *args: Any) -> None:
231268 def _is_local_function (self , node : NodeProto ) -> bool :
232269 return (node .domain , node .op_type ) in self .local_functions
233270
271+ def _run_init (self , feed_inputs ):
272+ if self .sess_ is None :
273+ assert self .proto , "self.proto is empty"
274+ _ , self .sess_ = self ._get_sess (self .proto , list (feed_inputs .values ()))
275+ return self .sess_
276+
234277 def run (
235278 self ,
236279 outputs : Optional [List [str ]],
@@ -254,9 +297,7 @@ def run(
254297 """
255298 if self .rt_nodes_ is None :
256299 # runs a whole
257- if self .sess_ is None :
258- assert self .proto , "self.proto is empty"
259- _ , self .sess_ = self ._get_sess (self .proto , list (feed_inputs .values ()))
300+ self ._run_init (feed_inputs )
260301 assert self .sess_ , "mypy not happy"
261302 return self .sess_ .run (outputs , feed_inputs )
262303 if outputs is None :
@@ -283,7 +324,7 @@ def run(
283324 if node .op_type == "If" and node .domain == "" :
284325 outputs = self ._run_if (node , inputs , results )
285326 elif node .op_type in {"Scan" , "Loop" } and node .domain == "" :
286- outputs = self ._run_scan (node , inputs , results )
327+ outputs = self ._run_scan_or_loop (node , inputs , results )
287328 elif self ._is_local_function (node ):
288329 outputs = self ._run_local (node , inputs , results )
289330 else :
@@ -472,6 +513,18 @@ def _get_sess(
472513 )
473514 ]
474515 prenodes = [] # type: ignore[var-annotated]
516+ elif node .op_type == "ConcatFromSequence" and node .domain == "" :
517+ # We force the type to be a boolean.
518+ vinputs = [
519+ oh .make_value_info (
520+ node .input [0 ],
521+ type_proto = oh .make_sequence_type_proto (
522+ oh .make_tensor_type_proto (elem_type = inputs [0 ].itype , shape = None )
523+ ),
524+ )
525+ ]
526+ voutputs = [oh .make_tensor_value_info (node .output [0 ], inputs [0 ].itype , None )]
527+ prenodes = [] # type: ignore[var-annotated]
475528 else :
476529 unique_names = set ()
477530 vinputs = []
@@ -535,7 +588,17 @@ def _get_sess_init_subgraph(
535588 if i == "" or i in unique_names :
536589 continue
537590 unique_names .add (i )
538- value = oh .make_tensor_value_info (i , dtype_to_tensor_dtype (it .dtype ), it .shape )
591+ if isinstance (it , OnnxList ):
592+ value = oh .make_value_info (
593+ i ,
594+ type_proto = oh .make_sequence_type_proto (
595+ oh .make_tensor_type_proto (
596+ elem_type = dtype_to_tensor_dtype (it .dtype ), shape = None
597+ )
598+ ),
599+ )
600+ else :
601+ value = oh .make_tensor_value_info (i , dtype_to_tensor_dtype (it .dtype ), it .shape )
539602 vinputs .append (value )
540603
541604 reduced_set = self ._get_hidden_inputs (g )
@@ -592,6 +655,14 @@ def _get_sess_local(
592655
593656 def _run (self , node : NodeProto , inputs : List [Any ], results : Dict [str , Any ]) -> List [Any ]:
594657 """Runs a node."""
658+ if node .op_type [0 ] == "S" :
659+ if node .op_type == "SequenceEmpty" :
660+ dtype = TensorProto .FLOAT
661+ for att in node .attribute :
662+ if att .name == "dtype" :
663+ dtype = att .i
664+ return [OnnxList (itype = dtype )]
665+
595666 types = [(None if a is None else (a .dtype , a .shape )) for a in inputs ]
596667 key = (id (node ), * types )
597668 if key in self ._cache :
@@ -609,6 +680,12 @@ def _run(self, node: NodeProto, inputs: List[Any], results: Dict[str, Any]) -> L
609680 continue
610681 feeds [i ] = val
611682 assert hasattr (sess , "run" ), f"Missing method run for type { type (sess )} "
683+
684+ if node .op_type [0 ] == "C" :
685+ if node .op_type == "ConcatFromSequence" :
686+ # sess._run_init(feeds )
687+ return list (sess .sess .run (None , self .feeds_to_numpy (feeds )))
688+
612689 outputs = list (sess .run (None , feeds ))
613690 assert isinstance (outputs , list ), f"Unexpected type for outputs { type (outputs )} "
614691 return outputs
@@ -636,7 +713,7 @@ def _run_if(
636713 assert isinstance (outputs , list ), f"Unexpected type for outputs { type (outputs )} "
637714 return outputs
638715
639- def _get_sess_scan (
716+ def _get_sess_scan_or_loop (
640717 self , node : NodeProto , branch : str , inputs : List [Any ], context : Dict [str , Any ]
641718 ) -> Tuple [ModelProto , "OnnxruntimeEvaluator" ]:
642719 g = None
@@ -671,7 +748,18 @@ def _get_sess_scan(
671748 )
672749 return onx , sess
673750
674- def _run_scan (
751+ def feeds_to_numpy (self , feeds ):
752+ new_feeds = {}
753+ for k , v in feeds .items ():
754+ if hasattr (v , "detach" ):
755+ new_feeds [k ] = v .detach ().cpu ().numpy ()
756+ elif isinstance (v , OnnxList ):
757+ new_feeds [k ] = v .numpy ()
758+ else :
759+ new_feeds [k ] = v
760+ return new_feeds
761+
762+ def _run_scan_or_loop (
675763 self , node : NodeProto , inputs : List [Any ], results : Dict [str , Any ]
676764 ) -> List [Any ]:
677765 """Runs a node Scan."""
@@ -682,10 +770,18 @@ def _run_scan(
682770 if key in self ._cache :
683771 sess = self ._cache [key ][1 ]
684772 else :
685- self ._cache [key ] = _onx , sess = self ._get_sess_scan (node , name , inputs , results )
773+ self ._cache [key ] = _onx , sess = self ._get_sess_scan_or_loop (
774+ node , name , inputs , results
775+ )
686776
687777 assert hasattr (sess , "run" ), f"Missing method run for type { type (sess )} "
688778 feeds = {name : results [name ] for name in sess .input_names }
779+ if node .op_type == "Loop" and any (isinstance (v , OnnxList ) for v in feeds .values ()):
780+ # This operator uses sequence. onnxruntime does not play well with sequence.
781+ sess ._run_init (feeds )
782+ outputs = sess .sess_ .sess .run (None , self .feeds_to_numpy (feeds ))
783+ return [(OnnxList (v ) if isinstance (v , list ) else v ) for v in outputs ]
784+
689785 outputs = sess .run (None , feeds )
690786 assert isinstance (outputs , list ), f"Unexpected type for outputs { type (outputs )} "
691787 return outputs
0 commit comments