@@ -136,29 +136,23 @@ def get_tolerance(tolerance_map, tensor_dtype, default_atol=0, default_rtol=1e-3
136136 return tolerance ["atol" ], tolerance ["rtol" ]
137137
138138
139- def create_infinicore_tensor (torch_tensor , device_str ):
140- """Create infinicore tensor from PyTorch tensor"""
141- infini_device = infinicore .device (device_str , 0 )
142-
143- return infinicore .from_blob (
144- torch_tensor .data_ptr (),
145- list (torch_tensor .shape ),
146- dtype = to_infinicore_dtype (torch_tensor .dtype ),
147- device = infini_device ,
148- )
149-
150-
151- def create_strided_infinicore_tensor (torch_tensor , device_str ):
152- """Create strided infinicore tensor from PyTorch tensor"""
153- infini_device = infinicore .device (device_str , 0 )
154-
155- return infinicore .strided_from_blob (
156- torch_tensor .data_ptr (),
157- list (torch_tensor .shape ),
158- list (torch_tensor .stride ()),
159- dtype = to_infinicore_dtype (torch_tensor .dtype ),
160- device = infini_device ,
161- )
139+ def infinicore_tensor_from_torch (torch_tensor ):
140+ infini_device = infinicore .device (torch_tensor .device .type , 0 )
141+ if torch_tensor .is_contiguous ():
142+ return infinicore .from_blob (
143+ torch_tensor .data_ptr (),
144+ list (torch_tensor .shape ),
145+ dtype = to_infinicore_dtype (torch_tensor .dtype ),
146+ device = infini_device ,
147+ )
148+ else :
149+ return infinicore .strided_from_blob (
150+ torch_tensor .data_ptr (),
151+ list (torch_tensor .shape ),
152+ list (torch_tensor .stride ()),
153+ dtype = to_infinicore_dtype (torch_tensor .dtype ),
154+ device = infini_device ,
155+ )
162156
163157
164158def convert_infinicore_to_torch (infini_result , torch_reference ):
@@ -179,16 +173,7 @@ def convert_infinicore_to_torch(infini_result, torch_reference):
179173 dtype = to_torch_dtype (infini_result .dtype ),
180174 device = infini_result .device .type ,
181175 )
182- if infini_result .is_contiguous ():
183- temp_tensor = create_infinicore_tensor (
184- torch_result_from_infini , infini_result .device .type
185- )
186- else :
187- rearrange_tensor (torch_result_from_infini , list (torch_reference .stride ()))
188- temp_tensor = create_strided_infinicore_tensor (
189- torch_result_from_infini , infini_result .device .type
190- )
191-
176+ temp_tensor = infinicore_tensor_from_torch (torch_result_from_infini )
192177 temp_tensor .copy_ (infini_result )
193178 return torch_result_from_infini
194179
0 commit comments