Skip to content

Commit efd4770

Browse files
committed
issue/497 - support non-contiguous tensors in result comparison
1 parent 30ac917 commit efd4770

2 files changed

Lines changed: 10 additions & 18 deletions

File tree

test/infinicore/framework/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
profile_operation,
1414
rearrange_tensor,
1515
convert_infinicore_to_torch,
16-
is_tensor_contiguous,
1716
)
1817
from .config import get_test_devices, get_args
1918
from .devices import InfiniDeviceEnum, InfiniDeviceNames, torch_device_map
@@ -36,7 +35,6 @@
3635
"get_tolerance",
3736
"profile_operation",
3837
"rearrange_tensor",
39-
"is_tensor_contiguous",
4038
"get_test_devices",
4139
"get_args",
4240
"InfiniDeviceEnum",

test/infinicore/framework/utils.py

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -162,19 +162,6 @@ def create_strided_infinicore_tensor(torch_tensor, device_str):
162162
)
163163

164164

165-
def is_tensor_contiguous(tensor):
166-
"""Check if a tensor (PyTorch or infinicore) is contiguous"""
167-
if hasattr(tensor, "is_contiguous"):
168-
return tensor.is_contiguous()
169-
elif hasattr(tensor, "stride"):
170-
# For PyTorch tensors
171-
expected_stride = torch._C._compute_contiguous_strides(tensor.shape)
172-
return list(tensor.stride()) == expected_stride
173-
else:
174-
# Assume contiguous by default
175-
return True
176-
177-
178165
def convert_infinicore_to_torch(infini_result, torch_reference):
179166
"""
180167
Convert infinicore tensor to PyTorch tensor for comparison
@@ -193,9 +180,16 @@ def convert_infinicore_to_torch(infini_result, torch_reference):
193180
dtype=to_torch_dtype(infini_result.dtype),
194181
device=infini_result.device.type,
195182
)
196-
temp_tensor = create_infinicore_tensor(
197-
torch_result_from_infini, infini_result.device.type
198-
)
183+
if infini_result.is_contiguous():
184+
temp_tensor = create_infinicore_tensor(
185+
torch_result_from_infini, infini_result.device.type
186+
)
187+
else:
188+
rearrange_tensor(torch_result_from_infini, list(torch_reference.stride()))
189+
temp_tensor = create_strided_infinicore_tensor(
190+
torch_result_from_infini, infini_result.device.type
191+
)
192+
199193
temp_tensor.copy_(infini_result)
200194
return torch_result_from_infini
201195

0 commit comments

Comments
 (0)