Skip to content

Commit 41e1bb2

Browse files
committed
issue/497 - simplified infinicore test functions
1 parent 6c05256 commit 41e1bb2

3 files changed

Lines changed: 83 additions & 34 deletions

File tree

test/infinicore/framework/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22
from .utils import (
33
create_infinicore_tensor,
44
compare_results,
5+
create_test_comparator,
56
debug,
67
get_tolerance,
78
profile_operation,
89
rearrange_tensor,
10+
convert_infinicore_to_torch,
911
)
1012
from .config import get_test_devices, get_args
1113
from .devices import InfiniDeviceEnum, InfiniDeviceNames, torch_device_map
@@ -17,6 +19,8 @@
1719
"TestCase",
1820
"create_infinicore_tensor",
1921
"compare_results",
22+
"create_test_comparator",
23+
"convert_infinicore_to_torch",
2024
"debug",
2125
"get_tolerance",
2226
"profile_operation",

test/infinicore/framework/utils.py

Lines changed: 68 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,6 @@
44
from .datatypes import to_infinicore_dtype, to_torch_dtype
55

66

7-
def create_infinicore_tensor(torch_tensor, device_str):
8-
"""Create infinicore tensor from PyTorch tensor"""
9-
infini_device = infinicore.device(device_str, 0)
10-
11-
return infinicore.from_blob(
12-
torch_tensor.data_ptr(),
13-
list(torch_tensor.shape),
14-
dtype=to_infinicore_dtype(torch_tensor.dtype),
15-
device=infini_device,
16-
)
17-
18-
197
def synchronize_device(torch_device):
208
"""Device synchronization"""
219
if torch_device == "cuda":
@@ -149,44 +137,95 @@ def get_tolerance(tolerance_map, tensor_dtype, default_atol=0, default_rtol=1e-3
149137
return tolerance["atol"], tolerance["rtol"]
150138

151139

152-
def compare_results(
153-
infini_result, torch_result, dtype, config, device_str, tolerance_map=None
154-
):
140+
def create_infinicore_tensor(torch_tensor, device_str):
141+
"""Create infinicore tensor from PyTorch tensor"""
142+
infini_device = infinicore.device(device_str, 0)
143+
144+
return infinicore.from_blob(
145+
torch_tensor.data_ptr(),
146+
list(torch_tensor.shape),
147+
dtype=to_infinicore_dtype(torch_tensor.dtype),
148+
device=infini_device,
149+
)
150+
151+
152+
def convert_infinicore_to_torch(infini_result, torch_reference):
155153
"""
156-
Compare infinicore result with PyTorch reference result
154+
Convert infinicore tensor to PyTorch tensor for comparison
157155
158156
Args:
159157
infini_result: infinicore tensor result
160-
torch_result: PyTorch tensor reference result
158+
torch_reference: PyTorch tensor reference (for shape and device)
161159
dtype: infinicore data type
162-
config: test config
163160
device_str: torch device string
164-
device: device enum
165-
tolerance_map: optional tolerance map (defaults to config's tolerance_map)
166161
167162
Returns:
168-
bool: True if results match within tolerance
163+
torch.Tensor: PyTorch tensor with infinicore data
169164
"""
170-
# Convert infinicore result to PyTorch tensor for comparison
171165
torch_result_from_infini = torch.zeros(
172-
torch_result.shape, dtype=to_torch_dtype(dtype), device=device_str
166+
torch_reference.shape,
167+
dtype=to_torch_dtype(infini_result.dtype),
168+
device=infini_result.device.type,
169+
)
170+
temp_tensor = create_infinicore_tensor(
171+
torch_result_from_infini, infini_result.device.type
173172
)
174-
temp_tensor = create_infinicore_tensor(torch_result_from_infini, device_str)
175173
temp_tensor.copy_(infini_result)
174+
return torch_result_from_infini
176175

177-
# Retrieve tolerance - use provided map or config's map
178-
if tolerance_map is None:
179-
tolerance_map = config.tolerance_map
180-
atol, rtol = get_tolerance(tolerance_map, dtype)
176+
177+
def compare_results(
178+
infini_result, torch_result, atol=1e-5, rtol=1e-5, debug_mode=False
179+
):
180+
"""
181+
Generic function to compare infinicore result with PyTorch reference result
182+
183+
Args:
184+
infini_result: infinicore tensor result
185+
torch_result: PyTorch tensor reference result
186+
atol: absolute tolerance
187+
rtol: relative tolerance
188+
debug_mode: whether to enable debug output
189+
190+
Returns:
191+
bool: True if results match within tolerance
192+
"""
193+
# Convert infinicore result to PyTorch tensor for comparison
194+
torch_result_from_infini = convert_infinicore_to_torch(infini_result, torch_result)
181195

182196
# Debug mode: detailed comparison
183-
if config.debug:
197+
if debug_mode:
184198
debug(torch_result_from_infini, torch_result, atol=atol, rtol=rtol)
185199

186200
# Check if results match within tolerance
187201
return torch.allclose(torch_result_from_infini, torch_result, atol=atol, rtol=rtol)
188202

189203

204+
def create_test_comparator(config, dtype, tolerance_map=None):
205+
"""
206+
Create a test-specific comparison function that handles test configuration
207+
208+
Args:
209+
config: test configuration
210+
dtype: infinicore data type
211+
tolerance_map: optional tolerance map (defaults to config's tolerance_map)
212+
213+
Returns:
214+
callable: function that takes (infini_result, torch_result) and returns bool
215+
"""
216+
if tolerance_map is None:
217+
tolerance_map = config.tolerance_map
218+
219+
atol, rtol = get_tolerance(tolerance_map, dtype)
220+
221+
def compare_test_results(infini_result, torch_result):
222+
return compare_results(
223+
infini_result, torch_result, atol=atol, rtol=rtol, debug_mode=config.debug
224+
)
225+
226+
return compare_test_results
227+
228+
190229
def rearrange_tensor(tensor, new_strides):
191230
"""
192231
Given a PyTorch tensor and a list of new strides, return a new PyTorch tensor with the given strides.

test/infinicore/op/matmul.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
TestRunner,
1212
TestCase,
1313
create_infinicore_tensor,
14-
compare_results,
14+
create_test_comparator,
1515
get_args,
1616
get_test_devices,
1717
profile_operation,
@@ -92,8 +92,11 @@ def infini_matmul():
9292

9393
infini_result = infini_matmul()
9494

95-
# Validate results using common method
96-
is_valid = compare_results(infini_result, torch_result, dtype, config, device_str)
95+
# Create test-specific comparator
96+
compare_fn = create_test_comparator(config, dtype)
97+
98+
# Validate results using the test-specific comparator
99+
is_valid = compare_fn(infini_result, torch_result)
97100
assert is_valid, "Matmul test failed"
98101

99102
# Performance test
@@ -163,8 +166,11 @@ def infini_matmul_inplace():
163166
# Execute in-place operation
164167
infini_matmul_inplace()
165168

166-
# Validate results using common method
167-
is_valid = compare_results(infini_c, torch_preallocated, dtype, config, device_str)
169+
# Create test-specific comparator
170+
compare_fn = create_test_comparator(config, dtype)
171+
172+
# Validate results using the test-specific comparator
173+
is_valid = compare_fn(infini_c, torch_preallocated)
168174
assert is_valid, "In-place matmul test failed"
169175

170176
# Performance test

0 commit comments

Comments
 (0)