22import infinicore
33
44from abc import ABC , abstractmethod
5- from typing import List , Dict , Any , Tuple , Union
5+ from typing import List , Dict , Any , Tuple , Union , Callable , Optional
66
77from .datatypes import to_torch_dtype , to_infinicore_dtype
88from .devices import InfiniDeviceNames , torch_device_map
@@ -66,12 +66,12 @@ class TestCase:
6666
6767 def __init__ (self , inputs , output = None , ** kwargs ):
6868 """
69- 简化构造函数
69+ Simplified constructor
7070 Args:
71- inputs: List[TensorSpec] 或简单的形状元组
72- output: TensorSpec 或形状元组
71+ inputs: List[TensorSpec] or simple shape tuples
72+ output: TensorSpec or shape tuple
7373 """
74- # 标准化 inputs
74+ # Normalize inputs
7575 self .inputs = []
7676 for inp in inputs :
7777 if isinstance (inp , (list , tuple )):
@@ -81,7 +81,7 @@ def __init__(self, inputs, output=None, **kwargs):
8181 else :
8282 self .inputs .append (inp )
8383
84- # 标准化 output
84+ # Normalize output
8585 if isinstance (output , (list , tuple )):
8686 self .output = TensorSpec .from_tensor (output )
8787 else :
@@ -142,11 +142,11 @@ def __init__(self, test_cases, test_config):
142142 self .config = test_config
143143 self .failed_tests = [] # Track failures
144144
145- def run_tests (self , devices , test_func ):
145+ def run_tests (self , devices , test_func , test_type = "Test" ):
146146 """Run tests and track failures"""
147147 for device in devices :
148148 print (f"\n { '=' * 60 } " )
149- print (f"Testing on { InfiniDeviceNames [device ]} " )
149+ print (f"Testing { test_type } on { InfiniDeviceNames [device ]} " )
150150 print (f"{ '=' * 60 } " )
151151
152152 # Filter unsupported data types
@@ -213,15 +213,17 @@ def get_tolerance_map(self):
213213 """Return tolerance configuration"""
214214 pass
215215
216- @abstractmethod
217- def torch_operator (self , * inputs , ** kwargs ):
218- """PyTorch operator implementation"""
219- pass
216+ def has_out_of_place_test (self ):
217+ """Check if out-of-place test functions are defined"""
218+ return hasattr (self , "torch_operator_out_of_place" ) and hasattr (
219+ self , "infinicore_operator_out_of_place"
220+ )
220221
221- @abstractmethod
222- def infinicore_operator (self , * inputs , ** kwargs ):
223- """Infinicore operator implementation"""
224- pass
222+ def has_inplace_test (self ):
223+ """Check if in-place test functions are defined"""
224+ return hasattr (self , "torch_operator_inplace" ) and hasattr (
225+ self , "infinicore_operator_inplace"
226+ )
225227
226228 def create_strided_tensor (self , shape , strides , dtype , device_str ):
227229 """Create a non-contiguous tensor with specific strides"""
@@ -274,16 +276,19 @@ def prepare_inputs(self, test_case, device_str, dtype):
274276
275277 return inputs , test_case .kwargs
276278
277- def run_test (self , device , test_case , dtype , config ):
278- """Generic test execution flow with flexible inputs - output is always contiguous"""
279+ def run_out_of_place_test (self , device , test_case , dtype , config ):
280+ """Generic out-of-place test execution flow"""
281+ if not self .has_out_of_place_test ():
282+ raise NotImplementedError ("Out-of-place test functions not defined" )
283+
279284 device_str = torch_device_map [device ]
280285
281286 # Prepare inputs
282287 inputs , kwargs = self .prepare_inputs (test_case , device_str , dtype )
283288
284- # PyTorch reference result - output is always contiguous for out-of-place
289+ # PyTorch reference result
285290 def torch_op ():
286- return self .torch_operator (* inputs , ** kwargs )
291+ return self .torch_operator_out_of_place (* inputs , ** kwargs )
287292
288293 torch_result = torch_op ()
289294
@@ -310,36 +315,39 @@ def torch_op():
310315 else :
311316 infini_inputs .append (inp )
312317
313- # Infinicore result - output is always contiguous for out-of-place
318+ # Infinicore result
314319 def infini_op ():
315- return self .infinicore_operator (* infini_inputs , ** kwargs )
320+ return self .infinicore_operator_out_of_place (* infini_inputs , ** kwargs )
316321
317322 infini_result = infini_op ()
318323
319324 # Result comparison
320325 compare_fn = create_test_comparator (config , dtype )
321326 is_valid = compare_fn (infini_result , torch_result )
322- assert is_valid , f"{ self .operator_name } test failed"
327+ assert is_valid , f"{ self .operator_name } out-of-place test failed"
323328
324329 # Performance testing
325330 if config .bench :
326331 profile_operation (
327- f"PyTorch { self .operator_name } " ,
332+ f"PyTorch { self .operator_name } Out-of-place " ,
328333 torch_op ,
329334 device_str ,
330335 config .num_prerun ,
331336 config .num_iterations ,
332337 )
333338 profile_operation (
334- f"Infinicore { self .operator_name } " ,
339+ f"Infinicore { self .operator_name } Out-of-place " ,
335340 infini_op ,
336341 device_str ,
337342 config .num_prerun ,
338343 config .num_iterations ,
339344 )
340345
341346 def run_inplace_test (self , device , test_case , dtype , config ):
342- """Generic in-place operation test execution flow - supports strided output"""
347+ """Generic in-place operation test execution flow"""
348+ if not self .has_inplace_test ():
349+ raise NotImplementedError ("In-place test functions not defined" )
350+
343351 device_str = torch_device_map [device ]
344352
345353 # Prepare inputs and output
@@ -370,7 +378,7 @@ def run_inplace_test(self, device, test_case, dtype, config):
370378 torch_preallocated .zero_ ()
371379
372380 def torch_op_inplace ():
373- self .torch_operator (* inputs , out = torch_preallocated , ** kwargs )
381+ self .torch_operator_inplace (* inputs , out = torch_preallocated , ** kwargs )
374382
375383 torch_op_inplace ()
376384
@@ -392,19 +400,7 @@ def torch_op_inplace():
392400 else :
393401 infini_inputs .append (inp )
394402
395- # # Create infinicore output tensor
396- # if test_case.output.is_contiguous or test_case.output.strides is None:
397- # infini_output = infinicore.empty(
398- # output_shape, dtype=dtype, device=infinicore.device(device_str, 0)
399- # )
400- # else:
401- # infini_output = infinicore.strided_empty(
402- # output_shape,
403- # test_case.output.strides,
404- # dtype=dtype,
405- # device=infinicore.device(device_str, 0),
406- # )
407-
403+ # Create infinicore output tensor
408404 torch_dummy = torch .zeros (output_shape , dtype = output_dtype , device = device_str )
409405 if test_case .output .is_contiguous or test_case .output .strides is None :
410406 infini_output = create_infinicore_tensor (torch_dummy , device_str )
@@ -413,7 +409,9 @@ def torch_op_inplace():
413409 infini_output = create_strided_infinicore_tensor (torch_dummy , device_str )
414410
415411 def infini_op_inplace ():
416- self .infinicore_operator (* infini_inputs , out = infini_output , ** kwargs )
412+ self .infinicore_operator_inplace (
413+ * infini_inputs , out = infini_output , ** kwargs
414+ )
417415
418416 infini_op_inplace ()
419417
0 commit comments