66
77from .datatypes import to_torch_dtype , to_infinicore_dtype
88from .devices import InfiniDeviceNames , torch_device_map
9+ from .tensor import TensorSpec , TensorInitializer
910from .utils import (
1011 create_test_comparator ,
1112 infinicore_tensor_from_torch ,
1516)
1617
1718
18- class TensorSpec :
19- """Tensor specification supporting various input types and per-tensor dtype"""
20-
21- def __init__ (
22- self ,
23- shape = None ,
24- dtype = None ,
25- strides = None ,
26- value = None ,
27- is_scalar = False ,
28- is_contiguous = True ,
29- ):
30- self .shape = shape
31- self .dtype = dtype
32- self .strides = strides
33- self .value = value
34- self .is_scalar = is_scalar
35- self .is_contiguous = is_contiguous
36-
37- @classmethod
38- def from_tensor (cls , shape , dtype = None , strides = None , is_contiguous = True ):
39- return cls (
40- shape = shape ,
41- dtype = dtype ,
42- strides = strides ,
43- is_scalar = False ,
44- is_contiguous = is_contiguous ,
45- )
46-
47- @classmethod
48- def from_scalar (cls , value , dtype = None ):
49- return cls (value = value , dtype = dtype , is_scalar = True )
50-
51- @classmethod
52- def from_strided_tensor (cls , shape , strides , dtype = None ):
53- return cls (
54- shape = shape ,
55- dtype = dtype ,
56- strides = strides ,
57- is_scalar = False ,
58- is_contiguous = False ,
59- )
60-
61-
6219class TestCase :
6320 """Test case"""
6421
@@ -101,17 +58,27 @@ def __str__(self):
10158 input_strs .append (f"scalar({ inp .value } { dtype_str } )" )
10259 elif hasattr (inp , "shape" ):
10360 dtype_str = f", dtype={ inp .dtype } " if inp .dtype else ""
61+ init_str = (
62+ f", init={ inp .init_mode } "
63+ if inp .init_mode != TensorInitializer .RANDOM
64+ else ""
65+ )
10466 if hasattr (inp , "is_contiguous" ) and not inp .is_contiguous :
105- input_strs .append (f"strided_tensor{ inp .shape } { dtype_str } " )
67+ input_strs .append (f"strided_tensor{ inp .shape } { dtype_str } { init_str } " )
10668 else :
107- input_strs .append (f"tensor{ inp .shape } { dtype_str } " )
69+ input_strs .append (f"tensor{ inp .shape } { dtype_str } { init_str } " )
10870 else :
10971 input_strs .append (str (inp ))
11072
11173 base_str = f"TestCase(mode={ mode_str } , inputs=[{ ', ' .join (input_strs )} ]"
11274 if self .output :
11375 dtype_str = f", dtype={ self .output .dtype } " if self .output .dtype else ""
114- base_str += f", output=tensor{ self .output .shape } { dtype_str } "
76+ init_str = (
77+ f", init={ self .output .init_mode } "
78+ if self .output .init_mode != TensorInitializer .RANDOM
79+ else ""
80+ )
81+ base_str += f", output=tensor{ self .output .shape } { dtype_str } { init_str } "
11582 if self .kwargs :
11683 base_str += f", kwargs={ self .kwargs } "
11784 if self .description :
@@ -252,17 +219,14 @@ def infinicore_operator(self, *inputs, out=None, **kwargs):
252219 """Unified Infinicore operator function"""
253220 pass
254221
255- def create_strided_tensor (self , shape , strides , dtype , device_str ):
222+ def create_strided_tensor (
223+ self , shape , strides , dtype , device , init_mode = TensorInitializer .RANDOM
224+ ):
256225 """Create a non-contiguous tensor with specific strides"""
257- total_size = 1
258- for i in range (len (shape )):
259- total_size += (shape [i ] - 1 ) * abs (strides [i ])
260-
261- base_tensor = torch .rand (total_size , dtype = dtype , device = device_str )
262- strided_tensor = torch .as_strided (base_tensor , shape , strides )
263- return strided_tensor
226+ spec = TensorSpec .from_strided_tensor (shape , strides , dtype , init_mode )
227+ return spec .create_torch_tensor (device , dtype )
264228
265- def prepare_inputs (self , test_case , device_str , dtype_config ):
229+ def prepare_inputs (self , test_case , device , dtype_config ):
266230 """Prepare input data"""
267231 inputs = []
268232
@@ -271,49 +235,26 @@ def prepare_inputs(self, test_case, device_str, dtype_config):
271235 if input_spec .is_scalar :
272236 inputs .append (input_spec .value )
273237 else :
274- shape = input_spec .shape
275-
276- if input_spec .dtype is not None :
277- tensor_dtype = to_torch_dtype (input_spec .dtype )
278- elif (
279- isinstance (dtype_config , dict ) and f"input_{ i } " in dtype_config
280- ):
281- tensor_dtype = to_torch_dtype (dtype_config [f"input_{ i } " ])
282- elif isinstance (dtype_config , (list , tuple )) and i < len (
283- dtype_config
284- ):
285- tensor_dtype = to_torch_dtype (dtype_config [i ])
286- else :
287- tensor_dtype = to_torch_dtype (dtype_config )
288-
289- if input_spec .is_contiguous or input_spec .strides is None :
290- tensor = torch .rand (
291- shape , dtype = tensor_dtype , device = device_str
292- )
293- else :
294- tensor = self .create_strided_tensor (
295- shape , input_spec .strides , tensor_dtype , device_str
296- )
297-
238+ tensor = input_spec .create_torch_tensor (device , dtype_config , i )
298239 inputs .append (tensor )
299240 else :
300241 inputs .append (input_spec )
301242
302243 return inputs , test_case .kwargs
303244
304245 def get_output_dtype (self , test_case , dtype_config , torch_result = None ):
305- """Determine output dtype"""
246+ """Determine output dtype - returns infinicore dtype, not torch dtype """
306247 if test_case .output and test_case .output .dtype is not None :
307- return to_torch_dtype ( test_case .output .dtype )
248+ return test_case .output .dtype
308249 elif isinstance (dtype_config , dict ) and "output" in dtype_config :
309- return to_torch_dtype ( dtype_config ["output" ])
250+ return dtype_config ["output" ]
310251 elif torch_result is not None :
311- return torch_result .dtype
252+ return to_infinicore_dtype ( torch_result .dtype )
312253 else :
313254 if isinstance (dtype_config , (list , tuple )):
314- return to_torch_dtype ( dtype_config [0 ])
255+ return dtype_config [0 ]
315256 else :
316- return to_torch_dtype ( dtype_config )
257+ return dtype_config
317258
318259 def run_test (self , device , test_case , dtype_config , config ):
319260 """Unified test execution flow"""
@@ -350,7 +291,7 @@ def _run_single_test(self, device, test_case, dtype_config, config, mode_name):
350291 """Run a single test with specified operation mode"""
351292 device_str = torch_device_map [device ]
352293
353- inputs , kwargs = self .prepare_inputs (test_case , device_str , dtype_config )
294+ inputs , kwargs = self .prepare_inputs (test_case , device , dtype_config )
354295
355296 infini_inputs = []
356297 for inp in inputs :
@@ -378,8 +319,9 @@ def infini_op():
378319
379320 infini_result = infini_op ()
380321
381- comparison_dtype = to_infinicore_dtype (
382- self .get_output_dtype (test_case , dtype_config , torch_result )
322+ # Get comparison dtype (infinicore dtype)
323+ comparison_dtype = self .get_output_dtype (
324+ test_case , dtype_config , torch_result
383325 )
384326
385327 compare_fn = create_test_comparator (
@@ -408,26 +350,40 @@ def infini_op():
408350 if not test_case .output :
409351 raise ValueError ("IN_PLACE test requires output specification" )
410352
353+ # Get output dtype and create output tensor
411354 output_dtype = self .get_output_dtype (test_case , dtype_config )
412355 output_shape = test_case .output .shape
413356
357+ # Use TensorSpec to create output tensor with specified initialization mode
414358 if test_case .output .is_contiguous or test_case .output .strides is None :
415- torch_output = torch . zeros (
416- output_shape , dtype = output_dtype , device = device_str
359+ output_spec = TensorSpec . from_tensor (
360+ output_shape , output_dtype , init_mode = test_case . output . init_mode
417361 )
418362 else :
419- torch_output = self .create_strided_tensor (
420- output_shape , test_case .output .strides , output_dtype , device_str
363+ output_spec = TensorSpec .from_strided_tensor (
364+ output_shape ,
365+ test_case .output .strides ,
366+ output_dtype ,
367+ init_mode = test_case .output .init_mode ,
421368 )
369+
370+ torch_output = output_spec .create_torch_tensor (device , output_dtype )
371+
372+ # For non-contiguous tensors, we need to ensure zeros initialization
373+ if (
374+ not test_case .output .is_contiguous
375+ and test_case .output .strides is not None
376+ ):
422377 torch_output .zero_ ()
423378
424379 def torch_op_inplace ():
425380 self .torch_operator (* inputs , out = torch_output , ** kwargs )
426381
427382 torch_op_inplace ()
428383
384+ # Create infinicore output tensor
429385 torch_dummy = torch .zeros (
430- output_shape , dtype = output_dtype , device = device_str
386+ output_shape , dtype = to_torch_dtype ( output_dtype ) , device = device_str
431387 )
432388 if (
433389 not test_case .output .is_contiguous
@@ -441,8 +397,8 @@ def infini_op_inplace():
441397
442398 infini_op_inplace ()
443399
444- comparison_dtype = to_infinicore_dtype (
445- self . get_output_dtype ( test_case , dtype_config , torch_output )
400+ comparison_dtype = self . get_output_dtype (
401+ test_case , dtype_config , torch_output
446402 )
447403 compare_fn = create_test_comparator (
448404 config , comparison_dtype , mode_name = f"{ self .operator_name } { mode_name } "
0 commit comments