Skip to content

Commit 846d897

Browse files
committed
issue/497 - temporarily fixed strided tensor creation
1 parent efd4770 commit 846d897

1 file changed

Lines changed: 19 additions & 10 deletions

File tree

test/infinicore/framework/base.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@
88
from .devices import InfiniDeviceNames, torch_device_map
99
from .utils import (
1010
create_infinicore_tensor,
11+
create_strided_infinicore_tensor,
1112
create_test_comparator,
1213
profile_operation,
14+
rearrange_tensor,
1315
synchronize_device,
1416
)
1517

@@ -390,18 +392,25 @@ def torch_op_inplace():
390392
else:
391393
infini_inputs.append(inp)
392394

393-
# Create infinicore output tensor
395+
# # Create infinicore output tensor
396+
# if test_case.output.is_contiguous or test_case.output.strides is None:
397+
# infini_output = infinicore.empty(
398+
# output_shape, dtype=dtype, device=infinicore.device(device_str, 0)
399+
# )
400+
# else:
401+
# infini_output = infinicore.strided_empty(
402+
# output_shape,
403+
# test_case.output.strides,
404+
# dtype=dtype,
405+
# device=infinicore.device(device_str, 0),
406+
# )
407+
408+
torch_dummy = torch.zeros(output_shape, dtype=output_dtype, device=device_str)
394409
if test_case.output.is_contiguous or test_case.output.strides is None:
395-
infini_output = infinicore.empty(
396-
output_shape, dtype=dtype, device=infinicore.device(device_str, 0)
397-
)
410+
infini_output = create_infinicore_tensor(torch_dummy, device_str)
398411
else:
399-
infini_output = infinicore.strided_empty(
400-
output_shape,
401-
test_case.output.strides,
402-
dtype=dtype,
403-
device=infinicore.device(device_str, 0),
404-
)
412+
rearrange_tensor(torch_dummy, list(torch_preallocated.stride()))
413+
infini_output = create_strided_infinicore_tensor(torch_dummy, device_str)
405414

406415
def infini_op_inplace():
407416
self.infinicore_operator(*infini_inputs, out=infini_output, **kwargs)

0 commit comments

Comments
 (0)