2323 to_array_extended ,
2424 np_dtype_to_tensor_dtype ,
2525)
26- from ..helpers .torch_helper import onnx_dtype_to_torch_dtype
26+ from ..helpers .torch_helper import onnx_dtype_to_torch_dtype , torch_dtype_to_onnx_dtype
2727from ..helpers .ort_session import (
2828 InferenceSessionForTorch ,
2929 InferenceSessionForNumpy ,
@@ -48,22 +48,38 @@ def __init__(self, itype: Union[list, int]):
4848 self .dtype = onnx_dtype_to_torch_dtype (itype )
4949 else :
5050 assert itype , "The list cannot be created with an empty list."
51- self .itype = np_dtype_to_tensor_dtype (itype [0 ].dtype )
51+ self .itype = (
52+ np_dtype_to_tensor_dtype (itype [0 ].dtype )
53+ if isinstance (itype [0 ], np .ndarray )
54+ else torch_dtype_to_onnx_dtype (itype [0 ].dtype )
55+ )
5256 self .extend (itype )
5357 self .dtype = itype [0 ].dtype
5458 self .shape = "OnnxList"
5559
5660 def get_device (self ):
61+ "Returns the device of the first tensor."
5762 assert len (self ) > 0 , "Cannot access the device for an empty list."
5863 return self [0 ].get_device () if hasattr (self [0 ], "get_device" ) else - 1
5964
6065 def numpy (self ):
66+ "Creates a new list with all tensors on numpy or self it is already the case."
6167 if all (isinstance (v , np .ndarray ) for v in self ):
6268 return self
63- res = OnnxList ()
64- for v in self :
65- res .append (v .detach ().cpu ().numpy ())
66- return res
69+ return OnnxList ([v .detach ().cpu ().numpy () for v in self ])
70+
71+ def to (self , tensor_like ) -> "OnnxList" :
72+ "Creates a new list with all tensors on numpy or pytorch depending on `tensor_like`."
73+ if isinstance (tensor_like , np .ndarray ):
74+ return self
75+ import torch
76+
77+ return OnnxList (
78+ [
79+ torch .from_numpy (t ).to (tensor_like .device ) if isinstance (t , np .ndarray ) else t
80+ for t in self
81+ ]
82+ )
6783
6884
6985class OnnxruntimeEvaluator :
@@ -245,7 +261,7 @@ def _log_arg(self, a: Any) -> Any:
245261 if isinstance (a , (str , int , float )):
246262 return a
247263 if isinstance (a , OnnxList ):
248- return f"# { len (a )} []"
264+ return string_type (a )
249265 device = f"D{ a .get_device ()} :" if hasattr (a , "detach" ) else ""
250266 if hasattr (a , "shape" ):
251267 prefix = "A:" if hasattr (a , "astype" ) else "T:"
@@ -683,8 +699,12 @@ def _run(self, node: NodeProto, inputs: List[Any], results: Dict[str, Any]) -> L
683699
684700 if node .op_type [0 ] == "C" :
685701 if node .op_type == "ConcatFromSequence" :
686- # sess._run_init(feeds )
687- return list (sess .sess .run (None , self .feeds_to_numpy (feeds )))
702+ res = sess .sess .run (None , self .feeds_to_numpy (feeds )) # type: ignore[union-attr]
703+ if isinstance (inputs [0 ][0 ], np .ndarray ):
704+ return list (res )
705+ import torch
706+
707+ return [torch .from_numpy (r ).to (inputs [0 ][0 ].device ) for r in res ]
688708
689709 outputs = list (sess .run (None , feeds ))
690710 assert isinstance (outputs , list ), f"Unexpected type for outputs { type (outputs )} "
@@ -778,9 +798,12 @@ def _run_scan_or_loop(
778798 feeds = {name : results [name ] for name in sess .input_names }
779799 if node .op_type == "Loop" and any (isinstance (v , OnnxList ) for v in feeds .values ()):
780800 # This operator uses sequence. onnxruntime does not play well with sequence.
781- sess ._run_init (feeds )
782- outputs = sess .sess_ .sess .run (None , self .feeds_to_numpy (feeds ))
783- return [(OnnxList (v ) if isinstance (v , list ) else v ) for v in outputs ]
801+ sess ._run_init (feeds ) # type: ignore[union-attr]
802+ outputs = sess .sess_ .sess .run (None , self .feeds_to_numpy (feeds )) # type: ignore[union-attr]
803+ return [
804+ (OnnxList (v ).to (feeds [node .input [0 ]]) if isinstance (v , list ) else v )
805+ for v in outputs
806+ ]
784807
785808 outputs = sess .run (None , feeds )
786809 assert isinstance (outputs , list ), f"Unexpected type for outputs { type (outputs )} "
0 commit comments