Skip to content

Commit e03b56a

Browse files
committed
issue/461 - slightly simplified infinicore test framework
1 parent 5b02f42 commit e03b56a

3 files changed

Lines changed: 10 additions & 16 deletions

File tree

test/infinicore/framework/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import torch
22
import infinicore
3-
from .devices import InfiniDeviceNames, torch_device_map
3+
from .devices import InfiniDeviceNames
44
from .utils import synchronize_device
55

66

test/infinicore/framework/utils.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,10 @@
22
import time
33
import infinicore
44
from .datatypes import to_infinicore_dtype, to_torch_dtype
5-
from .devices import torch_device_map
65

76

8-
def create_infinicore_tensor(torch_tensor, device_enum):
7+
def create_infinicore_tensor(torch_tensor, device_str):
98
"""Create infinicore tensor from PyTorch tensor"""
10-
device_str = torch_device_map[device_enum]
119
infini_device = infinicore.device(device_str, 0)
1210

1311
return infinicore.from_blob(
@@ -152,7 +150,7 @@ def get_tolerance(tolerance_map, tensor_dtype, default_atol=0, default_rtol=1e-3
152150

153151

154152
def compare_results(
155-
infini_result, torch_result, dtype, config, device_str, device, tolerance_map=None
153+
infini_result, torch_result, dtype, config, device_str, tolerance_map=None
156154
):
157155
"""
158156
Compare infinicore result with PyTorch reference result
@@ -173,7 +171,7 @@ def compare_results(
173171
torch_result_from_infini = torch.zeros(
174172
torch_result.shape, dtype=to_torch_dtype(dtype), device=device_str
175173
)
176-
temp_tensor = create_infinicore_tensor(torch_result_from_infini, device)
174+
temp_tensor = create_infinicore_tensor(torch_result_from_infini, device_str)
177175
temp_tensor.copy_(infini_result)
178176

179177
# Retrieve tolerance - use provided map or config's map

test/infinicore/op/matmul.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,8 @@ def torch_matmul():
8383
torch_result = torch_matmul()
8484

8585
# Create infinicore tensors
86-
infini_a = create_infinicore_tensor(torch_a, device)
87-
infini_b = create_infinicore_tensor(torch_b, device)
86+
infini_a = create_infinicore_tensor(torch_a, device_str)
87+
infini_b = create_infinicore_tensor(torch_b, device_str)
8888

8989
# Out-of-place matmul
9090
def infini_matmul():
@@ -93,9 +93,7 @@ def infini_matmul():
9393
infini_result = infini_matmul()
9494

9595
# Validate results using common method
96-
is_valid = compare_results(
97-
infini_result, torch_result, dtype, config, device_str, device
98-
)
96+
is_valid = compare_results(infini_result, torch_result, dtype, config, device_str)
9997
assert is_valid, "Matmul test failed"
10098

10199
# Performance test
@@ -152,8 +150,8 @@ def torch_matmul_inplace():
152150
torch_matmul_inplace()
153151

154152
# Create infinicore tensors
155-
infini_a = create_infinicore_tensor(torch_a, device)
156-
infini_b = create_infinicore_tensor(torch_b, device)
153+
infini_a = create_infinicore_tensor(torch_a, device_str)
154+
infini_b = create_infinicore_tensor(torch_b, device_str)
157155
infini_c = infinicore.empty(
158156
result_shape, dtype=dtype, device=infinicore.device(device_str, 0)
159157
)
@@ -166,9 +164,7 @@ def infini_matmul_inplace():
166164
infini_matmul_inplace()
167165

168166
# Validate results using common method
169-
is_valid = compare_results(
170-
infini_c, torch_preallocated, dtype, config, device_str, device
171-
)
167+
is_valid = compare_results(infini_c, torch_preallocated, dtype, config, device_str)
172168
assert is_valid, "In-place matmul test failed"
173169

174170
# Performance test

0 commit comments

Comments
 (0)