Skip to content

Commit 30ac917

Browse files
committed
issue/497 - generalized test framework based on add
1 parent 7b79837 commit 30ac917

5 files changed

Lines changed: 325 additions & 104 deletions

File tree

test/infinicore/framework/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
11
from .base import TensorSpec, TestConfig, TestRunner, TestCase, BaseOperatorTest
2+
from .parameter_mapping import (
3+
ParameterMapping,
4+
create_test_cases,
5+
)
26
from .utils import (
37
create_infinicore_tensor,
48
create_strided_infinicore_tensor,
@@ -21,6 +25,8 @@
2125
"TestRunner",
2226
"TestCase",
2327
"BaseOperatorTest",
28+
"ParameterMapping",
29+
"create_test_cases",
2430
"create_infinicore_tensor",
2531
"create_strided_infinicore_tensor",
2632
"compare_results",

test/infinicore/framework/base.py

Lines changed: 19 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -62,106 +62,31 @@ def from_strided_tensor(cls, shape, strides, dtype=None):
6262
class TestCase:
6363
"""Enhanced test case supporting flexible input/output specifications"""
6464

65-
def __init__(self, *args, **kwargs):
65+
def __init__(self, inputs, output=None, **kwargs):
6666
"""
67-
Flexible constructor supporting multiple input styles:
68-
69-
Style 1: Traditional tuple format (for backward compatibility)
70-
TestCase((2, 3), (3, 4), (2, 4), None, None, None)
71-
72-
Style 2: Explicit specification
73-
TestCase(
74-
inputs=[TensorSpec.from_tensor((2, 3)), TensorSpec.from_tensor((3, 4))],
75-
output=TensorSpec.from_tensor((2, 4))
76-
)
77-
78-
Style 3: Mixed format with description
79-
TestCase((2, 3), (3, 4), output=(2, 4), description="Basic matmul")
67+
简化构造函数
68+
Args:
69+
inputs: List[TensorSpec] 或简单的形状元组
70+
output: TensorSpec 或形状元组
8071
"""
81-
if args and isinstance(args[0], (list, tuple)) and len(args) >= 2:
82-
# Traditional tuple format: (a_shape, b_shape, result_shape, a_stride, b_stride, c_stride)
83-
if len(args) >= 3:
84-
self._init_from_tuples(*args, **kwargs)
72+
# 标准化 inputs
73+
self.inputs = []
74+
for inp in inputs:
75+
if isinstance(inp, (list, tuple)):
76+
self.inputs.append(TensorSpec.from_tensor(inp))
77+
elif isinstance(inp, TensorSpec):
78+
self.inputs.append(inp)
8579
else:
86-
self._init_from_mixed(args, kwargs)
87-
elif "inputs" in kwargs:
88-
# Explicit specification format
89-
self._init_from_explicit(kwargs)
90-
else:
91-
# Mixed format
92-
self._init_from_mixed(args, kwargs)
80+
self.inputs.append(inp)
9381

94-
def _init_from_tuples(
95-
self,
96-
a_shape,
97-
b_shape,
98-
result_shape=None,
99-
a_stride=None,
100-
b_stride=None,
101-
c_stride=None,
102-
**kwargs,
103-
):
104-
"""Initialize from traditional tuple format"""
105-
inputs = []
106-
107-
# First input
108-
if a_stride is not None:
109-
inputs.append(TensorSpec.from_strided_tensor(a_shape, a_stride))
110-
else:
111-
inputs.append(TensorSpec.from_tensor(a_shape))
112-
113-
# Second input
114-
if b_stride is not None:
115-
inputs.append(TensorSpec.from_strided_tensor(b_shape, b_stride))
82+
# 标准化 output
83+
if isinstance(output, (list, tuple)):
84+
self.output = TensorSpec.from_tensor(output)
11685
else:
117-
inputs.append(TensorSpec.from_tensor(b_shape))
118-
119-
# Output (if provided)
120-
output = None
121-
if result_shape is not None:
122-
if c_stride is not None:
123-
output = TensorSpec.from_strided_tensor(result_shape, c_stride)
124-
else:
125-
output = TensorSpec.from_tensor(result_shape)
126-
127-
self.inputs = inputs
128-
self.output = output
129-
self.kwargs = {k: v for k, v in kwargs.items() if k not in ["description"]}
130-
self.description = kwargs.get("description", "")
86+
self.output = output
13187

132-
def _init_from_mixed(self, args, kwargs):
133-
"""Initialize from mixed positional and keyword arguments"""
134-
inputs = []
135-
for arg in args:
136-
if isinstance(arg, (list, tuple)):
137-
# Shape tuple
138-
inputs.append(TensorSpec.from_tensor(arg))
139-
elif isinstance(arg, TensorSpec):
140-
# Already a TensorSpec
141-
inputs.append(arg)
142-
else:
143-
# Scalar or other value
144-
inputs.append(arg)
145-
146-
self.inputs = inputs
147-
self.output = kwargs.get("output")
148-
if isinstance(self.output, (list, tuple)):
149-
self.output = TensorSpec.from_tensor(self.output)
150-
self.kwargs = {
151-
k: v for k, v in kwargs.items() if k not in ["output", "description"]
152-
}
153-
self.description = kwargs.get("description", "")
154-
155-
def _init_from_explicit(self, kwargs):
156-
"""Initialize from explicit specification"""
157-
self.inputs = kwargs.get("inputs", [])
158-
self.output = kwargs.get("output")
159-
self.kwargs = {
160-
k: v
161-
for k, v in kwargs.items()
162-
if k not in ["inputs", "output", "description"]
163-
}
164-
self.description = kwargs.get("description", "")
88+
self.kwargs = kwargs
89+
self.description = kwargs.pop("description", "")
16590

16691
def __str__(self):
16792
input_strs = []
Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
"""
2+
Flexible parameter mapping system that allows operators to define their own mapping rules
3+
"""
4+
5+
from .base import TestCase, TensorSpec
6+
7+
8+
class ParameterMapping:
9+
"""Base class for parameter mapping configurations"""
10+
11+
def __init__(self, operator_name, call_signature, input_rules, output_rules):
12+
"""
13+
Args:
14+
operator_name: Name of the operator
15+
call_signature: Function signature template, e.g., "matmul(a, b)" or "add(input, other)"
16+
input_rules: List of rules for mapping test case data to input specifications
17+
output_rules: Rules for mapping test case data to output specification
18+
"""
19+
self.operator_name = operator_name
20+
self.call_signature = call_signature
21+
self.input_rules = input_rules
22+
self.output_rules = output_rules
23+
24+
def map_test_case(self, test_case_data):
25+
"""Map test case data to TestCase object using defined rules"""
26+
# Normalize test case data to handle different input formats
27+
normalized_data = self._normalize_test_case_data(test_case_data)
28+
29+
inputs = []
30+
31+
# Process input rules
32+
for rule in self.input_rules:
33+
input_spec = self._apply_rule(rule, normalized_data)
34+
if input_spec is not None:
35+
inputs.append(input_spec)
36+
37+
# Process output rules
38+
output_spec = self._apply_rule(self.output_rules, normalized_data)
39+
40+
return TestCase(inputs=inputs, output=output_spec)
41+
42+
def _normalize_test_case_data(self, test_case_data):
43+
"""Normalize test case data to handle different input formats"""
44+
if not isinstance(test_case_data, (list, tuple)):
45+
return (test_case_data,)
46+
47+
# If the first element is a tuple (shape), and there's only one element
48+
# e.g., ((13, 4)) → this should be shape (13, 4)
49+
if len(test_case_data) == 1 and isinstance(test_case_data[0], (list, tuple)):
50+
shape = test_case_data[0]
51+
return (shape, None, None, None)
52+
# If it's a tuple of integers (single shape), e.g., (13, 4)
53+
elif all(isinstance(x, int) for x in test_case_data):
54+
return (test_case_data, None, None, None)
55+
else:
56+
return test_case_data
57+
58+
def _apply_rule(self, rule, test_case_data):
59+
"""Apply a single mapping rule to test case data"""
60+
if rule is None:
61+
return None
62+
63+
if callable(rule):
64+
# Rule is a function
65+
return rule(test_case_data)
66+
elif isinstance(rule, dict):
67+
# Rule is a dictionary with shape and stride specifications
68+
shape_rule = rule.get("shape")
69+
stride_rule = rule.get("stride")
70+
71+
if shape_rule is None:
72+
return None
73+
74+
# Get shape from test case data
75+
if callable(shape_rule):
76+
shape = shape_rule(test_case_data)
77+
else:
78+
shape = self._get_data_by_index(test_case_data, shape_rule)
79+
80+
if shape is None:
81+
return None
82+
83+
# Get stride from test case data
84+
stride = None
85+
if stride_rule is not None:
86+
if callable(stride_rule):
87+
stride = stride_rule(test_case_data)
88+
else:
89+
stride = self._get_data_by_index(test_case_data, stride_rule)
90+
91+
# Only create strided tensor if stride is provided and valid
92+
if self._is_valid_strides(stride):
93+
return TensorSpec.from_strided_tensor(shape, stride)
94+
else:
95+
return TensorSpec.from_tensor(shape)
96+
else:
97+
raise ValueError(f"Invalid rule format: {rule}")
98+
99+
def _get_data_by_index(self, test_case_data, index):
100+
"""Safely get data from test case by index, return None if index out of range or value is None"""
101+
if isinstance(index, int):
102+
if 0 <= index < len(test_case_data):
103+
value = test_case_data[index]
104+
# Return None if the value is explicitly None
105+
return value if value is not None else None
106+
else:
107+
return None
108+
else:
109+
# If index is not an integer, assume it's a fixed value
110+
return index
111+
112+
def _is_valid_strides(self, strides):
113+
"""Check if strides are valid (not None and have proper format)"""
114+
if strides is None:
115+
return False
116+
if isinstance(strides, (list, tuple)) and len(strides) > 0:
117+
# Check if all elements are integers (or can be converted to valid strides)
118+
return all(isinstance(s, int) for s in strides)
119+
return False
120+
121+
122+
def create_parameter_mapping(
123+
operator_name, call_signature, input_configs, output_config
124+
):
125+
"""
126+
Create a parameter mapping from configuration
127+
128+
Args:
129+
operator_name: Name of the operator
130+
call_signature: Function call signature
131+
input_configs: List of input configurations
132+
output_config: Output configuration
133+
134+
Example for matmul:
135+
input_configs = [
136+
{'shape': 0, 'stride': 3}, # a: shape from index 0, stride from index 3
137+
{'shape': 1, 'stride': 4} # b: shape from index 1, stride from index 4
138+
]
139+
output_config = {'shape': 2, 'stride': 5} # output: shape from index 2, stride from index 5
140+
141+
Example for add:
142+
input_configs = [
143+
{'shape': 0, 'stride': 1}, # input: shape from index 0, stride from index 1
144+
{'shape': 0, 'stride': 2} # other: shape from index 0, stride from index 2
145+
]
146+
output_config = {'shape': 0, 'stride': 3} # output: shape from index 0, stride from index 3
147+
"""
148+
return ParameterMapping(operator_name, call_signature, input_configs, output_config)
149+
150+
151+
def create_test_cases(test_case_data, parameter_mapping):
152+
"""
153+
Create test cases from data using specified parameter mapping
154+
155+
Args:
156+
test_case_data: List of test case specifications
157+
parameter_mapping: ParameterMapping instance or configuration tuple
158+
159+
Returns:
160+
List of TestCase objects
161+
"""
162+
if isinstance(parameter_mapping, (list, tuple)):
163+
# Unpack configuration: (operator_name, call_signature, input_configs, output_config)
164+
if len(parameter_mapping) == 4:
165+
operator_name, call_signature, input_configs, output_config = (
166+
parameter_mapping
167+
)
168+
parameter_mapping = create_parameter_mapping(
169+
operator_name, call_signature, input_configs, output_config
170+
)
171+
else:
172+
raise ValueError("Invalid parameter mapping configuration format")
173+
174+
test_cases = []
175+
for i, data in enumerate(test_case_data):
176+
if isinstance(data, TestCase):
177+
test_cases.append(data)
178+
else:
179+
try:
180+
test_case = parameter_mapping.map_test_case(data)
181+
test_cases.append(test_case)
182+
except Exception as e:
183+
print(f"Warning: Failed to map test case {i} data {data}: {e}")
184+
# Fallback: try to create TestCase directly
185+
if isinstance(data, (list, tuple)):
186+
test_cases.append(TestCase(*data))
187+
else:
188+
test_cases.append(TestCase(data))
189+
190+
return test_cases

0 commit comments

Comments
 (0)