Skip to content

Commit 2e5b234

Browse files
issue/547 - improved test output (#550)
2 parents bf3395f + 991f534 commit 2e5b234

2 files changed

Lines changed: 54 additions & 29 deletions

File tree

test/infinicore/framework/base.py

Lines changed: 51 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -63,22 +63,39 @@ def __str__(self):
6363
if inp.init_mode != TensorInitializer.RANDOM
6464
else ""
6565
)
66-
if hasattr(inp, "is_contiguous") and not inp.is_contiguous:
67-
input_strs.append(f"strided_tensor{inp.shape}{dtype_str}{init_str}")
66+
# Show shape and strides for non-contiguous tensors
67+
if (
68+
hasattr(inp, "is_contiguous")
69+
and not inp.is_contiguous
70+
and inp.strides
71+
):
72+
strides_str = f", strides={inp.strides}"
73+
input_strs.append(
74+
f"tensor{inp.shape}{strides_str}{dtype_str}{init_str}"
75+
)
6876
else:
6977
input_strs.append(f"tensor{inp.shape}{dtype_str}{init_str}")
7078
else:
7179
input_strs.append(str(inp))
7280

73-
base_str = f"TestCase(mode={mode_str}, inputs=[{', '.join(input_strs)}]"
81+
base_str = f"TestCase(mode={mode_str}, inputs=[{'; '.join(input_strs)}]"
7482
if self.output:
7583
dtype_str = f", dtype={self.output.dtype}" if self.output.dtype else ""
7684
init_str = (
7785
f", init={self.output.init_mode}"
7886
if self.output.init_mode != TensorInitializer.RANDOM
7987
else ""
8088
)
81-
base_str += f", output=tensor{self.output.shape}{dtype_str}{init_str}"
89+
# Show shape and strides for non-contiguous output tensors
90+
if (
91+
hasattr(self.output, "is_contiguous")
92+
and not self.output.is_contiguous
93+
and self.output.strides
94+
):
95+
strides_str = f", strides={self.output.strides}"
96+
base_str += f", output=tensor{self.output.shape}{strides_str}{dtype_str}{init_str}"
97+
else:
98+
base_str += f", output=tensor{self.output.shape}{dtype_str}{init_str}"
8299
if self.kwargs:
83100
base_str += f", kwargs={self.kwargs}"
84101
if self.description:
@@ -131,24 +148,30 @@ def run_tests(self, devices, test_func, test_type="Test"):
131148
if self.config.dtype_combinations:
132149
for dtype_combo in self.config.dtype_combinations:
133150
try:
134-
test_func(device, test_case, dtype_combo, self.config)
151+
# Print test case info first
135152
combo_str = self._format_dtype_combo(dtype_combo)
136-
print(f"✓ {test_case} with {combo_str} passed")
153+
print(f"{test_case} with {combo_str}")
154+
155+
test_func(device, test_case, dtype_combo, self.config)
156+
print(f"\033[92m✓\033[0m Passed")
137157
except Exception as e:
138158
combo_str = self._format_dtype_combo(dtype_combo)
139-
error_msg = f"{test_case} with {combo_str} on {InfiniDeviceNames[device]}: {e}"
140-
print(f" {error_msg}")
159+
error_msg = f"Error: {e}"
160+
print(f"\033[91m✗\033[0m {error_msg}")
141161
self.failed_tests.append(error_msg)
142162
if self.config.debug:
143163
raise
144164
else:
145165
for dtype in tensor_dtypes:
146166
try:
167+
# Print test case info first
168+
print(f"{test_case} with {dtype}")
169+
147170
test_func(device, test_case, dtype, self.config)
148-
print(f"{test_case} with {dtype} passed")
171+
print(f"\033[92m✓\033[0m Passed")
149172
except Exception as e:
150-
error_msg = f"{test_case} with {dtype} on {InfiniDeviceNames[device]}: {e}"
151-
print(f" {error_msg}")
173+
error_msg = f"Error: {e}"
174+
print(f"\033[91m✗\033[0m {error_msg}")
152175
self.failed_tests.append(error_msg)
153176
if self.config.debug:
154177
raise
@@ -214,7 +237,7 @@ def torch_operator(self, *inputs, out=None, **kwargs):
214237
raise NotImplementedError("torch_operator not implemented")
215238

216239
def infinicore_operator(self, *inputs, out=None, **kwargs):
217-
"""Unified Infinicore operator function - can be overridden or return None"""
240+
"""Unified InfiniCore operator function - can be overridden or return None"""
218241
raise NotImplementedError("infinicore_operator not implemented")
219242

220243
def create_strided_tensor(
@@ -321,9 +344,7 @@ def _run_single_test(self, device, test_case, dtype_config, config, mode_name):
321344

322345
# If neither operator is implemented, skip the test
323346
if not torch_implemented and not infini_implemented:
324-
print(
325-
f"⚠ {self.operator_name} {mode_name}: Both operators not implemented - test skipped"
326-
)
347+
print(f"⚠ Both operators not implemented - test skipped")
327348
return
328349

329350
# If only one operator is implemented, run it without comparison
@@ -332,7 +353,7 @@ def _run_single_test(self, device, test_case, dtype_config, config, mode_name):
332353
"torch_operator" if not torch_implemented else "infinicore_operator"
333354
)
334355
print(
335-
f"⚠ {self.operator_name} {mode_name}: {missing_op} not implemented - running single operator without comparison"
356+
f"⚠ {missing_op} not implemented - running single operator without comparison"
336357
)
337358

338359
# Run the available operator for benchmarking if requested
@@ -342,8 +363,9 @@ def _run_single_test(self, device, test_case, dtype_config, config, mode_name):
342363
def torch_op():
343364
return self.torch_operator(*inputs, **kwargs)
344365

366+
print(f" {mode_name}:")
345367
profile_operation(
346-
f"PyTorch {self.operator_name} {mode_name}",
368+
"PyTorch ",
347369
torch_op,
348370
device_str,
349371
config.num_prerun,
@@ -354,8 +376,9 @@ def torch_op():
354376
def infini_op():
355377
return self.infinicore_operator(*infini_inputs, **kwargs)
356378

379+
print(f" {mode_name}:")
357380
profile_operation(
358-
f"Infinicore {self.operator_name} {mode_name}",
381+
"InfiniCore",
359382
infini_op,
360383
device_str,
361384
config.num_prerun,
@@ -388,21 +411,22 @@ def infini_op():
388411
)
389412

390413
compare_fn = create_test_comparator(
391-
config, comparison_dtype, mode_name=f"{self.operator_name} {mode_name}"
414+
config, comparison_dtype, mode_name=f"{mode_name}"
392415
)
393416
is_valid = compare_fn(infini_result, torch_result)
394-
assert is_valid, f"{self.operator_name} {mode_name} test failed"
417+
assert is_valid, f"{mode_name} result comparison failed"
395418

396419
if config.bench:
420+
print(f" {mode_name}:")
397421
profile_operation(
398-
f"PyTorch {self.operator_name} {mode_name}",
422+
"PyTorch ",
399423
torch_op,
400424
device_str,
401425
config.num_prerun,
402426
config.num_iterations,
403427
)
404428
profile_operation(
405-
f"Infinicore {self.operator_name} {mode_name}",
429+
"InfiniCore",
406430
infini_op,
407431
device_str,
408432
config.num_prerun,
@@ -464,21 +488,22 @@ def infini_op_inplace():
464488
test_case, dtype_config, torch_output
465489
)
466490
compare_fn = create_test_comparator(
467-
config, comparison_dtype, mode_name=f"{self.operator_name} {mode_name}"
491+
config, comparison_dtype, mode_name=f"{mode_name}"
468492
)
469493
is_valid = compare_fn(infini_output, torch_output)
470-
assert is_valid, f"{self.operator_name} {mode_name} test failed"
494+
assert is_valid, f"{mode_name} result comparison failed"
471495

472496
if config.bench:
497+
print(f" {mode_name}:")
473498
profile_operation(
474-
f"PyTorch {self.operator_name} {mode_name}",
499+
"PyTorch ",
475500
torch_op_inplace,
476501
device_str,
477502
config.num_prerun,
478503
config.num_iterations,
479504
)
480505
profile_operation(
481-
f"Infinicore {self.operator_name} {mode_name}",
506+
"InfiniCore",
482507
infini_op_inplace,
483508
device_str,
484509
config.num_prerun,

test/infinicore/framework/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def profile_operation(desc, func, torch_device, num_prerun, num_iterations):
3434

3535
# Timed execution
3636
elapsed = timed_op(lambda: func(), num_iterations, torch_device)
37-
print(f" {desc} time: {elapsed * 1000 :6f} ms")
37+
print(f" {desc} time: {elapsed * 1000 :6f} ms")
3838

3939

4040
def is_integer_dtype(dtype):
@@ -157,7 +157,7 @@ def add_color(text, color_code):
157157
print(
158158
f" - Min(delta) : {torch.min(delta):<{col_width[1]}} | Max(delta) : {torch.max(delta):<{col_width[2]}}"
159159
)
160-
print("-" * total_width + "\n")
160+
print("-" * total_width)
161161

162162
return diff_indices
163163

@@ -273,7 +273,7 @@ def create_test_comparator(config, dtype, tolerance_map=None, mode_name=""):
273273

274274
def compare_test_results(infini_result, torch_result):
275275
if config.debug and mode_name:
276-
print(f"\n\033[94mDEBUG INFO - {mode_name}:\033[0m")
276+
print(f"\033[94mDEBUG INFO - {mode_name}:\033[0m")
277277
return compare_results(
278278
infini_result,
279279
torch_result,

0 commit comments

Comments
 (0)