Skip to content

Commit 5b7ef9c

Browse files
wooway777ma-hang
authored andcommitted
issue/540 - support more dtypes in test framework
1 parent a5e20fc commit 5b7ef9c

3 files changed

Lines changed: 124 additions & 27 deletions

File tree

test/infinicore/framework/datatypes.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,16 @@ def to_torch_dtype(infini_dtype):
1010
return torch.float32
1111
elif infini_dtype == infinicore.bfloat16:
1212
return torch.bfloat16
13+
elif infini_dtype == infinicore.int8:
14+
return torch.int8
15+
elif infini_dtype == infinicore.int16:
16+
return torch.int16
1317
elif infini_dtype == infinicore.int32:
1418
return torch.int32
1519
elif infini_dtype == infinicore.int64:
1620
return torch.int64
21+
elif infini_dtype == infinicore.uint8:
22+
return torch.uint8
1723
else:
1824
raise ValueError(f"Unsupported infinicore dtype: {infini_dtype}")
1925

@@ -26,9 +32,15 @@ def to_infinicore_dtype(torch_dtype):
2632
return infinicore.float16
2733
elif torch_dtype == torch.bfloat16:
2834
return infinicore.bfloat16
35+
elif torch_dtype == torch.int8:
36+
return infinicore.int8
37+
elif torch_dtype == torch.int16:
38+
return infinicore.int16
2939
elif torch_dtype == torch.int32:
3040
return infinicore.int32
3141
elif torch_dtype == torch.int64:
3242
return infinicore.int64
43+
elif torch_dtype == torch.uint8:
44+
return infinicore.uint8
3345
else:
3446
raise ValueError(f"Unsupported torch dtype: {torch_dtype}")

test/infinicore/framework/tensor.py

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import torch
2+
import infinicore
23
from pathlib import Path
34
from .datatypes import to_torch_dtype
45
from .devices import torch_device_map
6+
from .utils import is_integer_dtype
57

68

79
class TensorInitializer:
@@ -38,6 +40,10 @@ def create_tensor(
3840
torch_device_str = torch_device_map[device]
3941
torch_dtype = to_torch_dtype(dtype)
4042

43+
# Handle integer types differently for random initialization
44+
if mode == TensorInitializer.RANDOM and is_integer_dtype(dtype):
45+
mode = TensorInitializer.RANDINT # Use randint for integer types
46+
4147
# Handle strided tensors - calculate required storage size
4248
if strides is not None:
4349
# Calculate the required storage size for strided tensor
@@ -61,9 +67,22 @@ def create_tensor(
6167
storage_size, dtype=torch_dtype, device=torch_device_str
6268
)
6369
elif mode == TensorInitializer.RANDINT:
70+
# For integer types, use appropriate range
71+
if is_integer_dtype(dtype):
72+
if dtype == infinicore.uint8:
73+
low, high = 0, 256
74+
elif dtype == infinicore.int8:
75+
low, high = -128, 128
76+
elif dtype == infinicore.int16:
77+
low, high = -32768, 32768
78+
else: # int32, int64, uint32
79+
low, high = -1000, 1000
80+
else:
81+
low, high = -1000, 1000
82+
6483
base_tensor = torch.randint(
65-
-2000000000,
66-
2000000000,
84+
low,
85+
high,
6786
(storage_size,),
6887
dtype=torch_dtype,
6988
device=torch_device_str,
@@ -92,9 +111,22 @@ def create_tensor(
92111
elif mode == TensorInitializer.ONES:
93112
tensor = torch.ones(shape, dtype=torch_dtype, device=torch_device_str)
94113
elif mode == TensorInitializer.RANDINT:
114+
# For integer types, use appropriate range
115+
if is_integer_dtype(dtype):
116+
if dtype == infinicore.uint8:
117+
low, high = 0, 256
118+
elif dtype == infinicore.int8:
119+
low, high = -128, 128
120+
elif dtype == infinicore.int16:
121+
low, high = -32768, 32768
122+
else: # int32, int64, uint32
123+
low, high = -1000, 1000
124+
else:
125+
low, high = -1000, 1000
126+
95127
tensor = torch.randint(
96-
-2000000000,
97-
2000000000,
128+
low,
129+
high,
98130
shape,
99131
dtype=torch_dtype,
100132
device=torch_device_str,

test/infinicore/framework/utils.py

Lines changed: 76 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -37,25 +37,52 @@ def profile_operation(desc, func, torch_device, num_prerun, num_iterations):
3737
print(f" {desc} time: {elapsed * 1000 :6f} ms")
3838

3939

40-
def debug(actual, desired, atol=0, rtol=1e-2, equal_nan=False, verbose=True):
40+
def is_integer_dtype(dtype):
41+
"""Check if dtype is integer type"""
42+
return dtype in [
43+
infinicore.int8,
44+
infinicore.int16,
45+
infinicore.int32,
46+
infinicore.int64,
47+
infinicore.uint8,
48+
]
49+
50+
51+
def is_float_dtype(dtype):
52+
"""Check if dtype is floating point type"""
53+
return dtype in [infinicore.float16, infinicore.float32, infinicore.bfloat16]
54+
55+
56+
def debug(
57+
actual, desired, atol=0, rtol=1e-2, equal_nan=False, verbose=True, dtype=None
58+
):
4159
"""
4260
Debug function to compare two tensors and print differences
4361
"""
62+
# Convert to float32 for bfloat16 comparison
4463
if actual.dtype == torch.bfloat16 or desired.dtype == torch.bfloat16:
4564
actual = actual.to(torch.float32)
4665
desired = desired.to(torch.float32)
4766

48-
print_discrepancy(actual, desired, atol, rtol, equal_nan, verbose)
67+
print_discrepancy(actual, desired, atol, rtol, equal_nan, verbose, dtype)
4968

50-
import numpy as np
69+
# Use appropriate comparison based on dtype
70+
if dtype and is_integer_dtype(dtype):
71+
# For integer types, require exact equality
72+
import numpy as np
5173

52-
np.testing.assert_allclose(
53-
actual.cpu(), desired.cpu(), rtol, atol, equal_nan, verbose=True
54-
)
74+
np.testing.assert_array_equal(actual.cpu(), desired.cpu())
75+
else:
76+
# For float types, use allclose
77+
import numpy as np
78+
79+
np.testing.assert_allclose(
80+
actual.cpu(), desired.cpu(), rtol, atol, equal_nan, verbose=True
81+
)
5582

5683

5784
def print_discrepancy(
58-
actual, expected, atol=0, rtol=1e-3, equal_nan=True, verbose=True
85+
actual, expected, atol=0, rtol=1e-3, equal_nan=True, verbose=True, dtype=None
5986
):
6087
"""Print detailed tensor differences"""
6188
if actual.shape != expected.shape:
@@ -69,13 +96,21 @@ def print_discrepancy(
6996
actual_isnan = torch.isnan(actual)
7097
expected_isnan = torch.isnan(expected)
7198

72-
# Calculate difference mask
73-
nan_mismatch = (
74-
actual_isnan ^ expected_isnan if equal_nan else actual_isnan | expected_isnan
75-
)
76-
diff_mask = nan_mismatch | (
77-
torch.abs(actual - expected) > (atol + rtol * torch.abs(expected))
78-
)
99+
# Calculate difference mask based on dtype
100+
if dtype and is_integer_dtype(dtype):
101+
# For integer types, exact equality required
102+
diff_mask = actual != expected
103+
else:
104+
# For float types, use tolerance-based comparison
105+
nan_mismatch = (
106+
actual_isnan ^ expected_isnan
107+
if equal_nan
108+
else actual_isnan | expected_isnan
109+
)
110+
diff_mask = nan_mismatch | (
111+
torch.abs(actual - expected) > (atol + rtol * torch.abs(expected))
112+
)
113+
79114
diff_indices = torch.nonzero(diff_mask, as_tuple=False)
80115
delta = actual - expected
81116

@@ -107,8 +142,9 @@ def add_color(text, color_code):
107142

108143
print(f" - Actual dtype: {actual.dtype}")
109144
print(f" - Desired dtype: {expected.dtype}")
110-
print(f" - Atol: {atol}")
111-
print(f" - Rtol: {rtol}")
145+
if not (dtype and is_integer_dtype(dtype)):
146+
print(f" - Atol: {atol}")
147+
print(f" - Rtol: {rtol}")
112148
print(
113149
f" - Mismatched elements: {len(diff_indices)} / {actual.numel()} ({len(diff_indices) / actual.numel() * 100}%)"
114150
)
@@ -130,6 +166,10 @@ def get_tolerance(tolerance_map, tensor_dtype, default_atol=0, default_rtol=1e-3
130166
"""
131167
Get tolerance settings based on data type
132168
"""
169+
# For integer types, return zero tolerance (exact match required)
170+
if is_integer_dtype(tensor_dtype):
171+
return 0, 0
172+
133173
tolerance = tolerance_map.get(
134174
tensor_dtype, {"atol": default_atol, "rtol": default_rtol}
135175
)
@@ -162,8 +202,6 @@ def convert_infinicore_to_torch(infini_result, torch_reference):
162202
Args:
163203
infini_result: infinicore tensor result
164204
torch_reference: PyTorch tensor reference (for shape and device)
165-
dtype: infinicore data type
166-
device_str: torch device string
167205
168206
Returns:
169207
torch.Tensor: PyTorch tensor with infinicore data
@@ -179,7 +217,7 @@ def convert_infinicore_to_torch(infini_result, torch_reference):
179217

180218

181219
def compare_results(
182-
infini_result, torch_result, atol=1e-5, rtol=1e-5, debug_mode=False
220+
infini_result, torch_result, atol=1e-5, rtol=1e-5, debug_mode=False, dtype=None
183221
):
184222
"""
185223
Generic function to compare infinicore result with PyTorch reference result
@@ -190,19 +228,29 @@ def compare_results(
190228
atol: absolute tolerance
191229
rtol: relative tolerance
192230
debug_mode: whether to enable debug output
231+
dtype: infinicore data type for comparison logic
193232
194233
Returns:
195234
bool: True if results match within tolerance
196235
"""
197236
# Convert infinicore result to PyTorch tensor for comparison
198237
torch_result_from_infini = convert_infinicore_to_torch(infini_result, torch_result)
199238

239+
# Choose comparison method based on dtype
240+
if dtype and is_integer_dtype(dtype):
241+
# For integer types, require exact equality
242+
result = torch.equal(torch_result_from_infini, torch_result)
243+
else:
244+
# For float types, use tolerance-based comparison
245+
result = torch.allclose(
246+
torch_result_from_infini, torch_result, atol=atol, rtol=rtol
247+
)
248+
200249
# Debug mode: detailed comparison
201250
if debug_mode:
202-
debug(torch_result_from_infini, torch_result, atol=atol, rtol=rtol)
251+
debug(torch_result_from_infini, torch_result, atol=atol, rtol=rtol, dtype=dtype)
203252

204-
# Check if results match within tolerance
205-
return torch.allclose(torch_result_from_infini, torch_result, atol=atol, rtol=rtol)
253+
return result
206254

207255

208256
def create_test_comparator(config, dtype, tolerance_map=None, mode_name=""):
@@ -227,7 +275,12 @@ def compare_test_results(infini_result, torch_result):
227275
if config.debug and mode_name:
228276
print(f"\n\033[94mDEBUG INFO - {mode_name}:\033[0m")
229277
return compare_results(
230-
infini_result, torch_result, atol=atol, rtol=rtol, debug_mode=config.debug
278+
infini_result,
279+
torch_result,
280+
atol=atol,
281+
rtol=rtol,
282+
debug_mode=config.debug,
283+
dtype=dtype,
231284
)
232285

233286
return compare_test_results

0 commit comments

Comments
 (0)