22import infinicore .dtype
33from infinicore .lib import _infinicore
44
5+ from .utils import to_infinicore_dtype
6+
57
68class Tensor :
7- def __init__ (self , underlying ):
9+ def __init__ (self , underlying , * , _torch_ref = None ):
810 """An internal method. Please do not use this directly."""
911
1012 self ._underlying = underlying
@@ -15,6 +17,8 @@ def __init__(self, underlying):
1517 self ._underlying .device
1618 )
1719
20+ self ._torch_ref = _torch_ref
21+
1822 @property
1923 def shape (self ):
2024 return self ._underlying .shape
@@ -86,6 +90,12 @@ def debug(self, filename=None):
8690 else :
8791 self ._underlying .debug (filename )
8892
93+ def __add__ (self , other ):
94+ return infinicore .add (self , other )
95+
96+ def __mul__ (self , other ):
97+ return infinicore .mul (self , other )
98+
8999
90100def empty (size , * , dtype = None , device = None , pin_memory = False ):
91101 return Tensor (
@@ -135,3 +145,17 @@ def strided_from_blob(data_ptr, size, strides, *, dtype=None, device=None):
135145 data_ptr , size , strides , dtype ._underlying , device ._underlying
136146 )
137147 )
148+
149+
150+ def from_torch (torch_tensor ) -> Tensor :
151+ infini_type = to_infinicore_dtype (torch_tensor .dtype )
152+ infini_device = infinicore .device (torch_tensor .device .type , 0 )
153+ return Tensor (
154+ _infinicore .from_blob (
155+ torch_tensor .data_ptr (),
156+ list (torch_tensor .shape ),
157+ dtype = infini_type ._underlying ,
158+ device = infini_device ._underlying ,
159+ ),
160+ torch_ref = torch_tensor ,
161+ )
0 commit comments