Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 51 additions & 26 deletions test/infinicore/framework/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,22 +63,39 @@ def __str__(self):
if inp.init_mode != TensorInitializer.RANDOM
else ""
)
if hasattr(inp, "is_contiguous") and not inp.is_contiguous:
input_strs.append(f"strided_tensor{inp.shape}{dtype_str}{init_str}")
# Show shape and strides for non-contiguous tensors
if (
hasattr(inp, "is_contiguous")
and not inp.is_contiguous
and inp.strides
):
strides_str = f", strides={inp.strides}"
input_strs.append(
f"tensor{inp.shape}{strides_str}{dtype_str}{init_str}"
)
else:
input_strs.append(f"tensor{inp.shape}{dtype_str}{init_str}")
else:
input_strs.append(str(inp))

base_str = f"TestCase(mode={mode_str}, inputs=[{', '.join(input_strs)}]"
base_str = f"TestCase(mode={mode_str}, inputs=[{'; '.join(input_strs)}]"
if self.output:
dtype_str = f", dtype={self.output.dtype}" if self.output.dtype else ""
init_str = (
f", init={self.output.init_mode}"
if self.output.init_mode != TensorInitializer.RANDOM
else ""
)
base_str += f", output=tensor{self.output.shape}{dtype_str}{init_str}"
# Show shape and strides for non-contiguous output tensors
if (
hasattr(self.output, "is_contiguous")
and not self.output.is_contiguous
and self.output.strides
):
strides_str = f", strides={self.output.strides}"
base_str += f", output=tensor{self.output.shape}{strides_str}{dtype_str}{init_str}"
else:
base_str += f", output=tensor{self.output.shape}{dtype_str}{init_str}"
if self.kwargs:
base_str += f", kwargs={self.kwargs}"
if self.description:
Expand Down Expand Up @@ -131,24 +148,30 @@ def run_tests(self, devices, test_func, test_type="Test"):
if self.config.dtype_combinations:
for dtype_combo in self.config.dtype_combinations:
try:
test_func(device, test_case, dtype_combo, self.config)
# Print test case info first
combo_str = self._format_dtype_combo(dtype_combo)
print(f"✓ {test_case} with {combo_str} passed")
print(f"{test_case} with {combo_str}")

test_func(device, test_case, dtype_combo, self.config)
print(f"\033[92m✓\033[0m Passed")
except Exception as e:
combo_str = self._format_dtype_combo(dtype_combo)
error_msg = f"{test_case} with {combo_str} on {InfiniDeviceNames[device]}: {e}"
print(f" {error_msg}")
error_msg = f"Error: {e}"
print(f"\033[91m✗\033[0m {error_msg}")
self.failed_tests.append(error_msg)
if self.config.debug:
raise
else:
for dtype in tensor_dtypes:
try:
# Print test case info first
print(f"{test_case} with {dtype}")

test_func(device, test_case, dtype, self.config)
print(f"✓ {test_case} with {dtype} passed")
print(f"\033[92m✓\033[0m Passed")
except Exception as e:
error_msg = f"{test_case} with {dtype} on {InfiniDeviceNames[device]}: {e}"
print(f" {error_msg}")
error_msg = f"Error: {e}"
print(f"\033[91m✗\033[0m {error_msg}")
self.failed_tests.append(error_msg)
if self.config.debug:
raise
Expand Down Expand Up @@ -214,7 +237,7 @@ def torch_operator(self, *inputs, out=None, **kwargs):
raise NotImplementedError("torch_operator not implemented")

def infinicore_operator(self, *inputs, out=None, **kwargs):
"""Unified Infinicore operator function - can be overridden or return None"""
"""Unified InfiniCore operator function - can be overridden or return None"""
raise NotImplementedError("infinicore_operator not implemented")

def create_strided_tensor(
Expand Down Expand Up @@ -321,9 +344,7 @@ def _run_single_test(self, device, test_case, dtype_config, config, mode_name):

# If neither operator is implemented, skip the test
if not torch_implemented and not infini_implemented:
print(
f"⚠ {self.operator_name} {mode_name}: Both operators not implemented - test skipped"
)
print(f"⚠ Both operators not implemented - test skipped")
return

# If only one operator is implemented, run it without comparison
Expand All @@ -332,7 +353,7 @@ def _run_single_test(self, device, test_case, dtype_config, config, mode_name):
"torch_operator" if not torch_implemented else "infinicore_operator"
)
print(
f"⚠ {self.operator_name} {mode_name}: {missing_op} not implemented - running single operator without comparison"
f"⚠ {missing_op} not implemented - running single operator without comparison"
)

# Run the available operator for benchmarking if requested
Expand All @@ -342,8 +363,9 @@ def _run_single_test(self, device, test_case, dtype_config, config, mode_name):
def torch_op():
return self.torch_operator(*inputs, **kwargs)

print(f" {mode_name}:")
profile_operation(
f"PyTorch {self.operator_name} {mode_name}",
"PyTorch ",
torch_op,
device_str,
config.num_prerun,
Expand All @@ -354,8 +376,9 @@ def torch_op():
def infini_op():
return self.infinicore_operator(*infini_inputs, **kwargs)

print(f" {mode_name}:")
profile_operation(
f"Infinicore {self.operator_name} {mode_name}",
"InfiniCore",
infini_op,
device_str,
config.num_prerun,
Expand Down Expand Up @@ -388,21 +411,22 @@ def infini_op():
)

compare_fn = create_test_comparator(
config, comparison_dtype, mode_name=f"{self.operator_name} {mode_name}"
config, comparison_dtype, mode_name=f"{mode_name}"
)
is_valid = compare_fn(infini_result, torch_result)
assert is_valid, f"{self.operator_name} {mode_name} test failed"
assert is_valid, f"{mode_name} result comparison failed"

if config.bench:
print(f" {mode_name}:")
profile_operation(
f"PyTorch {self.operator_name} {mode_name}",
"PyTorch ",
torch_op,
device_str,
config.num_prerun,
config.num_iterations,
)
profile_operation(
f"Infinicore {self.operator_name} {mode_name}",
"InfiniCore",
infini_op,
device_str,
config.num_prerun,
Expand Down Expand Up @@ -464,21 +488,22 @@ def infini_op_inplace():
test_case, dtype_config, torch_output
)
compare_fn = create_test_comparator(
config, comparison_dtype, mode_name=f"{self.operator_name} {mode_name}"
config, comparison_dtype, mode_name=f"{mode_name}"
)
is_valid = compare_fn(infini_output, torch_output)
assert is_valid, f"{self.operator_name} {mode_name} test failed"
assert is_valid, f"{mode_name} result comparison failed"

if config.bench:
print(f" {mode_name}:")
profile_operation(
f"PyTorch {self.operator_name} {mode_name}",
"PyTorch ",
torch_op_inplace,
device_str,
config.num_prerun,
config.num_iterations,
)
profile_operation(
f"Infinicore {self.operator_name} {mode_name}",
"InfiniCore",
infini_op_inplace,
device_str,
config.num_prerun,
Expand Down
6 changes: 3 additions & 3 deletions test/infinicore/framework/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def profile_operation(desc, func, torch_device, num_prerun, num_iterations):

# Timed execution
elapsed = timed_op(lambda: func(), num_iterations, torch_device)
print(f" {desc} time: {elapsed * 1000 :6f} ms")
print(f" {desc} time: {elapsed * 1000 :6f} ms")


def debug(actual, desired, atol=0, rtol=1e-2, equal_nan=False, verbose=True):
Expand Down Expand Up @@ -121,7 +121,7 @@ def add_color(text, color_code):
print(
f" - Min(delta) : {torch.min(delta):<{col_width[1]}} | Max(delta) : {torch.max(delta):<{col_width[2]}}"
)
print("-" * total_width + "\n")
print("-" * total_width)

return diff_indices

Expand Down Expand Up @@ -225,7 +225,7 @@ def create_test_comparator(config, dtype, tolerance_map=None, mode_name=""):

def compare_test_results(infini_result, torch_result):
if config.debug and mode_name:
print(f"\n\033[94mDEBUG INFO - {mode_name}:\033[0m")
print(f"\033[94mDEBUG INFO - {mode_name}:\033[0m")
return compare_results(
infini_result, torch_result, atol=atol, rtol=rtol, debug_mode=config.debug
)
Expand Down