@@ -162,19 +162,6 @@ def create_strided_infinicore_tensor(torch_tensor, device_str):
162162 )
163163
164164
165- def is_tensor_contiguous (tensor ):
166- """Check if a tensor (PyTorch or infinicore) is contiguous"""
167- if hasattr (tensor , "is_contiguous" ):
168- return tensor .is_contiguous ()
169- elif hasattr (tensor , "stride" ):
170- # For PyTorch tensors
171- expected_stride = torch ._C ._compute_contiguous_strides (tensor .shape )
172- return list (tensor .stride ()) == expected_stride
173- else :
174- # Assume contiguous by default
175- return True
176-
177-
178165def convert_infinicore_to_torch (infini_result , torch_reference ):
179166 """
180167 Convert infinicore tensor to PyTorch tensor for comparison
@@ -193,9 +180,16 @@ def convert_infinicore_to_torch(infini_result, torch_reference):
193180 dtype = to_torch_dtype (infini_result .dtype ),
194181 device = infini_result .device .type ,
195182 )
196- temp_tensor = create_infinicore_tensor (
197- torch_result_from_infini , infini_result .device .type
198- )
183+ if infini_result .is_contiguous ():
184+ temp_tensor = create_infinicore_tensor (
185+ torch_result_from_infini , infini_result .device .type
186+ )
187+ else :
188+ rearrange_tensor (torch_result_from_infini , list (torch_reference .stride ()))
189+ temp_tensor = create_strided_infinicore_tensor (
190+ torch_result_from_infini , infini_result .device .type
191+ )
192+
199193 temp_tensor .copy_ (infini_result )
200194 return torch_result_from_infini
201195
0 commit comments