Skip to content

Commit 623b3d5

Browse files
wooway777voltjia
authored andcommitted
issue/497 - simplified infinicore tensor creation from torch
1 parent 39aad83 commit 623b3d5

3 files changed

Lines changed: 29 additions & 57 deletions

File tree

test/infinicore/framework/__init__.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
from .base import TensorSpec, TestConfig, TestRunner, TestCase, BaseOperatorTest
22
from .utils import (
3-
create_infinicore_tensor,
4-
create_strided_infinicore_tensor,
53
compare_results,
64
create_test_comparator,
75
debug,
86
get_tolerance,
7+
infinicore_tensor_from_torch,
98
profile_operation,
109
rearrange_tensor,
1110
convert_infinicore_to_torch,
@@ -24,17 +23,16 @@
2423
"BaseOperatorTest",
2524
"ParameterMapping",
2625
"create_test_cases",
27-
"create_infinicore_tensor",
28-
"create_strided_infinicore_tensor",
2926
"compare_results",
3027
"create_test_comparator",
3128
"convert_infinicore_to_torch",
3229
"debug",
30+
"get_args",
31+
"get_test_devices",
3332
"get_tolerance",
33+
"infinicore_tensor_from_torch",
3434
"profile_operation",
3535
"rearrange_tensor",
36-
"get_test_devices",
37-
"get_args",
3836
"InfiniDeviceEnum",
3937
"InfiniDeviceNames",
4038
"torch_device_map",

test/infinicore/framework/base.py

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,8 @@
77
from .datatypes import to_torch_dtype, to_infinicore_dtype
88
from .devices import InfiniDeviceNames, torch_device_map
99
from .utils import (
10-
create_infinicore_tensor,
11-
create_strided_infinicore_tensor,
1210
create_test_comparator,
11+
infinicore_tensor_from_torch,
1312
profile_operation,
1413
rearrange_tensor,
1514
synchronize_device,
@@ -356,16 +355,7 @@ def _run_single_test(self, device, test_case, dtype_config, config, mode_name):
356355
infini_inputs = []
357356
for inp in inputs:
358357
if isinstance(inp, torch.Tensor):
359-
if not inp.is_contiguous():
360-
infini_tensor = infinicore.strided_from_blob(
361-
inp.data_ptr(),
362-
list(inp.shape),
363-
list(inp.stride()),
364-
dtype=to_infinicore_dtype(inp.dtype),
365-
device=infinicore.device(device_str, 0),
366-
)
367-
else:
368-
infini_tensor = create_infinicore_tensor(inp, device_str)
358+
infini_tensor = infinicore_tensor_from_torch(inp)
369359
infini_inputs.append(infini_tensor)
370360
else:
371361
infini_inputs.append(inp)
@@ -439,13 +429,12 @@ def torch_op_inplace():
439429
torch_dummy = torch.zeros(
440430
output_shape, dtype=output_dtype, device=device_str
441431
)
442-
if test_case.output.is_contiguous or test_case.output.strides is None:
443-
infini_output = create_infinicore_tensor(torch_dummy, device_str)
444-
else:
432+
if (
433+
not test_case.output.is_contiguous
434+
and not test_case.output.strides is None
435+
):
445436
rearrange_tensor(torch_dummy, list(torch_output.stride()))
446-
infini_output = create_strided_infinicore_tensor(
447-
torch_dummy, device_str
448-
)
437+
infini_output = infinicore_tensor_from_torch(torch_dummy)
449438

450439
def infini_op_inplace():
451440
self.infinicore_operator(*infini_inputs, out=infini_output, **kwargs)

test/infinicore/framework/utils.py

Lines changed: 18 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -136,29 +136,23 @@ def get_tolerance(tolerance_map, tensor_dtype, default_atol=0, default_rtol=1e-3
136136
return tolerance["atol"], tolerance["rtol"]
137137

138138

139-
def create_infinicore_tensor(torch_tensor, device_str):
140-
"""Create infinicore tensor from PyTorch tensor"""
141-
infini_device = infinicore.device(device_str, 0)
142-
143-
return infinicore.from_blob(
144-
torch_tensor.data_ptr(),
145-
list(torch_tensor.shape),
146-
dtype=to_infinicore_dtype(torch_tensor.dtype),
147-
device=infini_device,
148-
)
149-
150-
151-
def create_strided_infinicore_tensor(torch_tensor, device_str):
152-
"""Create strided infinicore tensor from PyTorch tensor"""
153-
infini_device = infinicore.device(device_str, 0)
154-
155-
return infinicore.strided_from_blob(
156-
torch_tensor.data_ptr(),
157-
list(torch_tensor.shape),
158-
list(torch_tensor.stride()),
159-
dtype=to_infinicore_dtype(torch_tensor.dtype),
160-
device=infini_device,
161-
)
139+
def infinicore_tensor_from_torch(torch_tensor):
140+
infini_device = infinicore.device(torch_tensor.device.type, 0)
141+
if torch_tensor.is_contiguous():
142+
return infinicore.from_blob(
143+
torch_tensor.data_ptr(),
144+
list(torch_tensor.shape),
145+
dtype=to_infinicore_dtype(torch_tensor.dtype),
146+
device=infini_device,
147+
)
148+
else:
149+
return infinicore.strided_from_blob(
150+
torch_tensor.data_ptr(),
151+
list(torch_tensor.shape),
152+
list(torch_tensor.stride()),
153+
dtype=to_infinicore_dtype(torch_tensor.dtype),
154+
device=infini_device,
155+
)
162156

163157

164158
def convert_infinicore_to_torch(infini_result, torch_reference):
@@ -179,16 +173,7 @@ def convert_infinicore_to_torch(infini_result, torch_reference):
179173
dtype=to_torch_dtype(infini_result.dtype),
180174
device=infini_result.device.type,
181175
)
182-
if infini_result.is_contiguous():
183-
temp_tensor = create_infinicore_tensor(
184-
torch_result_from_infini, infini_result.device.type
185-
)
186-
else:
187-
rearrange_tensor(torch_result_from_infini, list(torch_reference.stride()))
188-
temp_tensor = create_strided_infinicore_tensor(
189-
torch_result_from_infini, infini_result.device.type
190-
)
191-
176+
temp_tensor = infinicore_tensor_from_torch(torch_result_from_infini)
192177
temp_tensor.copy_(infini_result)
193178
return torch_result_from_infini
194179

0 commit comments

Comments
 (0)