@@ -37,25 +37,52 @@ def profile_operation(desc, func, torch_device, num_prerun, num_iterations):
3737 print (f" { desc } time: { elapsed * 1000 :6f} ms" )
3838
3939
40- def debug (actual , desired , atol = 0 , rtol = 1e-2 , equal_nan = False , verbose = True ):
40+ def is_integer_dtype (dtype ):
41+ """Check if dtype is integer type"""
42+ return dtype in [
43+ infinicore .int8 ,
44+ infinicore .int16 ,
45+ infinicore .int32 ,
46+ infinicore .int64 ,
47+ infinicore .uint8 ,
48+ ]
49+
50+
51+ def is_float_dtype (dtype ):
52+ """Check if dtype is floating point type"""
53+ return dtype in [infinicore .float16 , infinicore .float32 , infinicore .bfloat16 ]
54+
55+
56+ def debug (
57+ actual , desired , atol = 0 , rtol = 1e-2 , equal_nan = False , verbose = True , dtype = None
58+ ):
4159 """
4260 Debug function to compare two tensors and print differences
4361 """
62+ # Convert to float32 for bfloat16 comparison
4463 if actual .dtype == torch .bfloat16 or desired .dtype == torch .bfloat16 :
4564 actual = actual .to (torch .float32 )
4665 desired = desired .to (torch .float32 )
4766
48- print_discrepancy (actual , desired , atol , rtol , equal_nan , verbose )
67+ print_discrepancy (actual , desired , atol , rtol , equal_nan , verbose , dtype )
4968
50- import numpy as np
69+ # Use appropriate comparison based on dtype
70+ if dtype and is_integer_dtype (dtype ):
71+ # For integer types, require exact equality
72+ import numpy as np
5173
52- np .testing .assert_allclose (
53- actual .cpu (), desired .cpu (), rtol , atol , equal_nan , verbose = True
54- )
74+ np .testing .assert_array_equal (actual .cpu (), desired .cpu ())
75+ else :
76+ # For float types, use allclose
77+ import numpy as np
78+
79+ np .testing .assert_allclose (
80+ actual .cpu (), desired .cpu (), rtol , atol , equal_nan , verbose = True
81+ )
5582
5683
5784def print_discrepancy (
58- actual , expected , atol = 0 , rtol = 1e-3 , equal_nan = True , verbose = True
85+ actual , expected , atol = 0 , rtol = 1e-3 , equal_nan = True , verbose = True , dtype = None
5986):
6087 """Print detailed tensor differences"""
6188 if actual .shape != expected .shape :
@@ -69,13 +96,21 @@ def print_discrepancy(
6996 actual_isnan = torch .isnan (actual )
7097 expected_isnan = torch .isnan (expected )
7198
72- # Calculate difference mask
73- nan_mismatch = (
74- actual_isnan ^ expected_isnan if equal_nan else actual_isnan | expected_isnan
75- )
76- diff_mask = nan_mismatch | (
77- torch .abs (actual - expected ) > (atol + rtol * torch .abs (expected ))
78- )
99+ # Calculate difference mask based on dtype
100+ if dtype and is_integer_dtype (dtype ):
101+ # For integer types, exact equality required
102+ diff_mask = actual != expected
103+ else :
104+ # For float types, use tolerance-based comparison
105+ nan_mismatch = (
106+ actual_isnan ^ expected_isnan
107+ if equal_nan
108+ else actual_isnan | expected_isnan
109+ )
110+ diff_mask = nan_mismatch | (
111+ torch .abs (actual - expected ) > (atol + rtol * torch .abs (expected ))
112+ )
113+
79114 diff_indices = torch .nonzero (diff_mask , as_tuple = False )
80115 delta = actual - expected
81116
@@ -107,8 +142,9 @@ def add_color(text, color_code):
107142
108143 print (f" - Actual dtype: { actual .dtype } " )
109144 print (f" - Desired dtype: { expected .dtype } " )
110- print (f" - Atol: { atol } " )
111- print (f" - Rtol: { rtol } " )
145+ if not (dtype and is_integer_dtype (dtype )):
146+ print (f" - Atol: { atol } " )
147+ print (f" - Rtol: { rtol } " )
112148 print (
113149 f" - Mismatched elements: { len (diff_indices )} / { actual .numel ()} ({ len (diff_indices ) / actual .numel () * 100 } %)"
114150 )
@@ -130,6 +166,10 @@ def get_tolerance(tolerance_map, tensor_dtype, default_atol=0, default_rtol=1e-3
130166 """
131167 Get tolerance settings based on data type
132168 """
169+ # For integer types, return zero tolerance (exact match required)
170+ if is_integer_dtype (tensor_dtype ):
171+ return 0 , 0
172+
133173 tolerance = tolerance_map .get (
134174 tensor_dtype , {"atol" : default_atol , "rtol" : default_rtol }
135175 )
@@ -162,8 +202,6 @@ def convert_infinicore_to_torch(infini_result, torch_reference):
162202 Args:
163203 infini_result: infinicore tensor result
164204 torch_reference: PyTorch tensor reference (for shape and device)
165- dtype: infinicore data type
166- device_str: torch device string
167205
168206 Returns:
169207 torch.Tensor: PyTorch tensor with infinicore data
@@ -179,7 +217,7 @@ def convert_infinicore_to_torch(infini_result, torch_reference):
179217
180218
181219def compare_results (
182- infini_result , torch_result , atol = 1e-5 , rtol = 1e-5 , debug_mode = False
220+ infini_result , torch_result , atol = 1e-5 , rtol = 1e-5 , debug_mode = False , dtype = None
183221):
184222 """
185223 Generic function to compare infinicore result with PyTorch reference result
@@ -190,19 +228,29 @@ def compare_results(
190228 atol: absolute tolerance
191229 rtol: relative tolerance
192230 debug_mode: whether to enable debug output
231+ dtype: infinicore data type for comparison logic
193232
194233 Returns:
195234 bool: True if results match within tolerance
196235 """
197236 # Convert infinicore result to PyTorch tensor for comparison
198237 torch_result_from_infini = convert_infinicore_to_torch (infini_result , torch_result )
199238
239+ # Choose comparison method based on dtype
240+ if dtype and is_integer_dtype (dtype ):
241+ # For integer types, require exact equality
242+ result = torch .equal (torch_result_from_infini , torch_result )
243+ else :
244+ # For float types, use tolerance-based comparison
245+ result = torch .allclose (
246+ torch_result_from_infini , torch_result , atol = atol , rtol = rtol
247+ )
248+
200249 # Debug mode: detailed comparison
201250 if debug_mode :
202- debug (torch_result_from_infini , torch_result , atol = atol , rtol = rtol )
251+ debug (torch_result_from_infini , torch_result , atol = atol , rtol = rtol , dtype = dtype )
203252
204- # Check if results match within tolerance
205- return torch .allclose (torch_result_from_infini , torch_result , atol = atol , rtol = rtol )
253+ return result
206254
207255
208256def create_test_comparator (config , dtype , tolerance_map = None , mode_name = "" ):
@@ -227,7 +275,12 @@ def compare_test_results(infini_result, torch_result):
227275 if config .debug and mode_name :
228276 print (f"\n \033 [94mDEBUG INFO - { mode_name } :\033 [0m" )
229277 return compare_results (
230- infini_result , torch_result , atol = atol , rtol = rtol , debug_mode = config .debug
278+ infini_result ,
279+ torch_result ,
280+ atol = atol ,
281+ rtol = rtol ,
282+ debug_mode = config .debug ,
283+ dtype = dtype ,
231284 )
232285
233286 return compare_test_results
0 commit comments