|
8 | 8 | from .devices import InfiniDeviceNames, torch_device_map |
9 | 9 | from .utils import ( |
10 | 10 | create_infinicore_tensor, |
| 11 | + create_strided_infinicore_tensor, |
11 | 12 | create_test_comparator, |
12 | 13 | profile_operation, |
| 14 | + rearrange_tensor, |
13 | 15 | synchronize_device, |
14 | 16 | ) |
15 | 17 |
|
@@ -390,18 +392,25 @@ def torch_op_inplace(): |
390 | 392 | else: |
391 | 393 | infini_inputs.append(inp) |
392 | 394 |
|
393 | | - # Create infinicore output tensor |
| 395 | + # # Create infinicore output tensor |
| 396 | + # if test_case.output.is_contiguous or test_case.output.strides is None: |
| 397 | + # infini_output = infinicore.empty( |
| 398 | + # output_shape, dtype=dtype, device=infinicore.device(device_str, 0) |
| 399 | + # ) |
| 400 | + # else: |
| 401 | + # infini_output = infinicore.strided_empty( |
| 402 | + # output_shape, |
| 403 | + # test_case.output.strides, |
| 404 | + # dtype=dtype, |
| 405 | + # device=infinicore.device(device_str, 0), |
| 406 | + # ) |
| 407 | + |
| 408 | + torch_dummy = torch.zeros(output_shape, dtype=output_dtype, device=device_str) |
394 | 409 | if test_case.output.is_contiguous or test_case.output.strides is None: |
395 | | - infini_output = infinicore.empty( |
396 | | - output_shape, dtype=dtype, device=infinicore.device(device_str, 0) |
397 | | - ) |
| 410 | + infini_output = create_infinicore_tensor(torch_dummy, device_str) |
398 | 411 | else: |
399 | | - infini_output = infinicore.strided_empty( |
400 | | - output_shape, |
401 | | - test_case.output.strides, |
402 | | - dtype=dtype, |
403 | | - device=infinicore.device(device_str, 0), |
404 | | - ) |
| 412 | + rearrange_tensor(torch_dummy, list(torch_preallocated.stride())) |
| 413 | + infini_output = create_strided_infinicore_tensor(torch_dummy, device_str) |
405 | 414 |
|
406 | 415 | def infini_op_inplace(): |
407 | 416 | self.infinicore_operator(*infini_inputs, out=infini_output, **kwargs) |
|
0 commit comments