@@ -63,22 +63,39 @@ def __str__(self):
6363 if inp .init_mode != TensorInitializer .RANDOM
6464 else ""
6565 )
66- if hasattr (inp , "is_contiguous" ) and not inp .is_contiguous :
67- input_strs .append (f"strided_tensor{ inp .shape } { dtype_str } { init_str } " )
66+ # Show shape and strides for non-contiguous tensors
67+ if (
68+ hasattr (inp , "is_contiguous" )
69+ and not inp .is_contiguous
70+ and inp .strides
71+ ):
72+ strides_str = f", strides={ inp .strides } "
73+ input_strs .append (
74+ f"tensor{ inp .shape } { strides_str } { dtype_str } { init_str } "
75+ )
6876 else :
6977 input_strs .append (f"tensor{ inp .shape } { dtype_str } { init_str } " )
7078 else :
7179 input_strs .append (str (inp ))
7280
73- base_str = f"TestCase(mode={ mode_str } , inputs=[{ ', ' .join (input_strs )} ]"
81+ base_str = f"TestCase(mode={ mode_str } , inputs=[{ '; ' .join (input_strs )} ]"
7482 if self .output :
7583 dtype_str = f", dtype={ self .output .dtype } " if self .output .dtype else ""
7684 init_str = (
7785 f", init={ self .output .init_mode } "
7886 if self .output .init_mode != TensorInitializer .RANDOM
7987 else ""
8088 )
81- base_str += f", output=tensor{ self .output .shape } { dtype_str } { init_str } "
89+ # Show shape and strides for non-contiguous output tensors
90+ if (
91+ hasattr (self .output , "is_contiguous" )
92+ and not self .output .is_contiguous
93+ and self .output .strides
94+ ):
95+ strides_str = f", strides={ self .output .strides } "
96+ base_str += f", output=tensor{ self .output .shape } { strides_str } { dtype_str } { init_str } "
97+ else :
98+ base_str += f", output=tensor{ self .output .shape } { dtype_str } { init_str } "
8299 if self .kwargs :
83100 base_str += f", kwargs={ self .kwargs } "
84101 if self .description :
@@ -131,24 +148,30 @@ def run_tests(self, devices, test_func, test_type="Test"):
131148 if self .config .dtype_combinations :
132149 for dtype_combo in self .config .dtype_combinations :
133150 try :
134- test_func ( device , test_case , dtype_combo , self . config )
151+ # Print test case info first
135152 combo_str = self ._format_dtype_combo (dtype_combo )
136- print (f"✓ { test_case } with { combo_str } passed" )
153+ print (f"{ test_case } with { combo_str } " )
154+
155+ test_func (device , test_case , dtype_combo , self .config )
156+ print (f"\033 [92m✓\033 [0m Passed" )
137157 except Exception as e :
138158 combo_str = self ._format_dtype_combo (dtype_combo )
139- error_msg = f"{ test_case } with { combo_str } on { InfiniDeviceNames [ device ] } : { e } "
140- print (f"✗ { error_msg } " )
159+ error_msg = f"Error : { e } "
160+ print (f"\033 [91m✗ \033 [0m { error_msg } " )
141161 self .failed_tests .append (error_msg )
142162 if self .config .debug :
143163 raise
144164 else :
145165 for dtype in tensor_dtypes :
146166 try :
167+ # Print test case info first
168+ print (f"{ test_case } with { dtype } " )
169+
147170 test_func (device , test_case , dtype , self .config )
148- print (f"✓ { test_case } with { dtype } passed " )
171+ print (f"\033 [92m✓ \033 [0m Passed " )
149172 except Exception as e :
150- error_msg = f"{ test_case } with { dtype } on { InfiniDeviceNames [ device ] } : { e } "
151- print (f"✗ { error_msg } " )
173+ error_msg = f"Error : { e } "
174+ print (f"\033 [91m✗ \033 [0m { error_msg } " )
152175 self .failed_tests .append (error_msg )
153176 if self .config .debug :
154177 raise
@@ -214,7 +237,7 @@ def torch_operator(self, *inputs, out=None, **kwargs):
214237 raise NotImplementedError ("torch_operator not implemented" )
215238
216239 def infinicore_operator (self , * inputs , out = None , ** kwargs ):
217- """Unified Infinicore operator function - can be overridden or return None"""
240+ """Unified InfiniCore operator function - can be overridden or return None"""
218241 raise NotImplementedError ("infinicore_operator not implemented" )
219242
220243 def create_strided_tensor (
@@ -321,9 +344,7 @@ def _run_single_test(self, device, test_case, dtype_config, config, mode_name):
321344
322345 # If neither operator is implemented, skip the test
323346 if not torch_implemented and not infini_implemented :
324- print (
325- f"⚠ { self .operator_name } { mode_name } : Both operators not implemented - test skipped"
326- )
347+ print (f"⚠ Both operators not implemented - test skipped" )
327348 return
328349
329350 # If only one operator is implemented, run it without comparison
@@ -332,7 +353,7 @@ def _run_single_test(self, device, test_case, dtype_config, config, mode_name):
332353 "torch_operator" if not torch_implemented else "infinicore_operator"
333354 )
334355 print (
335- f"⚠ { self . operator_name } { mode_name } : { missing_op } not implemented - running single operator without comparison"
356+ f"⚠ { missing_op } not implemented - running single operator without comparison"
336357 )
337358
338359 # Run the available operator for benchmarking if requested
@@ -342,8 +363,9 @@ def _run_single_test(self, device, test_case, dtype_config, config, mode_name):
342363 def torch_op ():
343364 return self .torch_operator (* inputs , ** kwargs )
344365
366+ print (f" { mode_name } :" )
345367 profile_operation (
346- f "PyTorch { self . operator_name } { mode_name } " ,
368+ "PyTorch " ,
347369 torch_op ,
348370 device_str ,
349371 config .num_prerun ,
@@ -354,8 +376,9 @@ def torch_op():
354376 def infini_op ():
355377 return self .infinicore_operator (* infini_inputs , ** kwargs )
356378
379+ print (f" { mode_name } :" )
357380 profile_operation (
358- f"Infinicore { self . operator_name } { mode_name } " ,
381+ "InfiniCore " ,
359382 infini_op ,
360383 device_str ,
361384 config .num_prerun ,
@@ -388,21 +411,22 @@ def infini_op():
388411 )
389412
390413 compare_fn = create_test_comparator (
391- config , comparison_dtype , mode_name = f"{ self . operator_name } { mode_name } "
414+ config , comparison_dtype , mode_name = f"{ mode_name } "
392415 )
393416 is_valid = compare_fn (infini_result , torch_result )
394- assert is_valid , f"{ self . operator_name } { mode_name } test failed"
417+ assert is_valid , f"{ mode_name } result comparison failed"
395418
396419 if config .bench :
420+ print (f" { mode_name } :" )
397421 profile_operation (
398- f "PyTorch { self . operator_name } { mode_name } " ,
422+ "PyTorch " ,
399423 torch_op ,
400424 device_str ,
401425 config .num_prerun ,
402426 config .num_iterations ,
403427 )
404428 profile_operation (
405- f"Infinicore { self . operator_name } { mode_name } " ,
429+ "InfiniCore " ,
406430 infini_op ,
407431 device_str ,
408432 config .num_prerun ,
@@ -464,21 +488,22 @@ def infini_op_inplace():
464488 test_case , dtype_config , torch_output
465489 )
466490 compare_fn = create_test_comparator (
467- config , comparison_dtype , mode_name = f"{ self . operator_name } { mode_name } "
491+ config , comparison_dtype , mode_name = f"{ mode_name } "
468492 )
469493 is_valid = compare_fn (infini_output , torch_output )
470- assert is_valid , f"{ self . operator_name } { mode_name } test failed"
494+ assert is_valid , f"{ mode_name } result comparison failed"
471495
472496 if config .bench :
497+ print (f" { mode_name } :" )
473498 profile_operation (
474- f "PyTorch { self . operator_name } { mode_name } " ,
499+ "PyTorch " ,
475500 torch_op_inplace ,
476501 device_str ,
477502 config .num_prerun ,
478503 config .num_iterations ,
479504 )
480505 profile_operation (
481- f"Infinicore { self . operator_name } { mode_name } " ,
506+ "InfiniCore " ,
482507 infini_op_inplace ,
483508 device_str ,
484509 config .num_prerun ,
0 commit comments