|
| 1 | +""" |
| 2 | +Flexible parameter mapping system that allows operators to define their own mapping rules |
| 3 | +""" |
| 4 | + |
| 5 | +from .base import TestCase, TensorSpec |
| 6 | + |
| 7 | + |
| 8 | +class ParameterMapping: |
| 9 | + """Base class for parameter mapping configurations""" |
| 10 | + |
| 11 | + def __init__(self, operator_name, call_signature, input_rules, output_rules): |
| 12 | + """ |
| 13 | + Args: |
| 14 | + operator_name: Name of the operator |
| 15 | + call_signature: Function signature template, e.g., "matmul(a, b)" or "add(input, other)" |
| 16 | + input_rules: List of rules for mapping test case data to input specifications |
| 17 | + output_rules: Rules for mapping test case data to output specification |
| 18 | + """ |
| 19 | + self.operator_name = operator_name |
| 20 | + self.call_signature = call_signature |
| 21 | + self.input_rules = input_rules |
| 22 | + self.output_rules = output_rules |
| 23 | + |
| 24 | + def map_test_case(self, test_case_data): |
| 25 | + """Map test case data to TestCase object using defined rules""" |
| 26 | + # Normalize test case data to handle different input formats |
| 27 | + normalized_data = self._normalize_test_case_data(test_case_data) |
| 28 | + |
| 29 | + inputs = [] |
| 30 | + |
| 31 | + # Process input rules |
| 32 | + for rule in self.input_rules: |
| 33 | + input_spec = self._apply_rule(rule, normalized_data) |
| 34 | + if input_spec is not None: |
| 35 | + inputs.append(input_spec) |
| 36 | + |
| 37 | + # Process output rules |
| 38 | + output_spec = self._apply_rule(self.output_rules, normalized_data) |
| 39 | + |
| 40 | + return TestCase(inputs=inputs, output=output_spec) |
| 41 | + |
| 42 | + def _normalize_test_case_data(self, test_case_data): |
| 43 | + """Normalize test case data to handle different input formats""" |
| 44 | + if not isinstance(test_case_data, (list, tuple)): |
| 45 | + return (test_case_data,) |
| 46 | + |
| 47 | + # If the first element is a tuple (shape), and there's only one element |
| 48 | + # e.g., ((13, 4)) → this should be shape (13, 4) |
| 49 | + if len(test_case_data) == 1 and isinstance(test_case_data[0], (list, tuple)): |
| 50 | + shape = test_case_data[0] |
| 51 | + return (shape, None, None, None) |
| 52 | + # If it's a tuple of integers (single shape), e.g., (13, 4) |
| 53 | + elif all(isinstance(x, int) for x in test_case_data): |
| 54 | + return (test_case_data, None, None, None) |
| 55 | + else: |
| 56 | + return test_case_data |
| 57 | + |
| 58 | + def _apply_rule(self, rule, test_case_data): |
| 59 | + """Apply a single mapping rule to test case data""" |
| 60 | + if rule is None: |
| 61 | + return None |
| 62 | + |
| 63 | + if callable(rule): |
| 64 | + # Rule is a function |
| 65 | + return rule(test_case_data) |
| 66 | + elif isinstance(rule, dict): |
| 67 | + # Rule is a dictionary with shape and stride specifications |
| 68 | + shape_rule = rule.get("shape") |
| 69 | + stride_rule = rule.get("stride") |
| 70 | + |
| 71 | + if shape_rule is None: |
| 72 | + return None |
| 73 | + |
| 74 | + # Get shape from test case data |
| 75 | + if callable(shape_rule): |
| 76 | + shape = shape_rule(test_case_data) |
| 77 | + else: |
| 78 | + shape = self._get_data_by_index(test_case_data, shape_rule) |
| 79 | + |
| 80 | + if shape is None: |
| 81 | + return None |
| 82 | + |
| 83 | + # Get stride from test case data |
| 84 | + stride = None |
| 85 | + if stride_rule is not None: |
| 86 | + if callable(stride_rule): |
| 87 | + stride = stride_rule(test_case_data) |
| 88 | + else: |
| 89 | + stride = self._get_data_by_index(test_case_data, stride_rule) |
| 90 | + |
| 91 | + # Only create strided tensor if stride is provided and valid |
| 92 | + if self._is_valid_strides(stride): |
| 93 | + return TensorSpec.from_strided_tensor(shape, stride) |
| 94 | + else: |
| 95 | + return TensorSpec.from_tensor(shape) |
| 96 | + else: |
| 97 | + raise ValueError(f"Invalid rule format: {rule}") |
| 98 | + |
| 99 | + def _get_data_by_index(self, test_case_data, index): |
| 100 | + """Safely get data from test case by index, return None if index out of range or value is None""" |
| 101 | + if isinstance(index, int): |
| 102 | + if 0 <= index < len(test_case_data): |
| 103 | + value = test_case_data[index] |
| 104 | + # Return None if the value is explicitly None |
| 105 | + return value if value is not None else None |
| 106 | + else: |
| 107 | + return None |
| 108 | + else: |
| 109 | + # If index is not an integer, assume it's a fixed value |
| 110 | + return index |
| 111 | + |
| 112 | + def _is_valid_strides(self, strides): |
| 113 | + """Check if strides are valid (not None and have proper format)""" |
| 114 | + if strides is None: |
| 115 | + return False |
| 116 | + if isinstance(strides, (list, tuple)) and len(strides) > 0: |
| 117 | + # Check if all elements are integers (or can be converted to valid strides) |
| 118 | + return all(isinstance(s, int) for s in strides) |
| 119 | + return False |
| 120 | + |
| 121 | + |
| 122 | +def create_parameter_mapping( |
| 123 | + operator_name, call_signature, input_configs, output_config |
| 124 | +): |
| 125 | + """ |
| 126 | + Create a parameter mapping from configuration |
| 127 | +
|
| 128 | + Args: |
| 129 | + operator_name: Name of the operator |
| 130 | + call_signature: Function call signature |
| 131 | + input_configs: List of input configurations |
| 132 | + output_config: Output configuration |
| 133 | +
|
| 134 | + Example for matmul: |
| 135 | + input_configs = [ |
| 136 | + {'shape': 0, 'stride': 3}, # a: shape from index 0, stride from index 3 |
| 137 | + {'shape': 1, 'stride': 4} # b: shape from index 1, stride from index 4 |
| 138 | + ] |
| 139 | + output_config = {'shape': 2, 'stride': 5} # output: shape from index 2, stride from index 5 |
| 140 | +
|
| 141 | + Example for add: |
| 142 | + input_configs = [ |
| 143 | + {'shape': 0, 'stride': 1}, # input: shape from index 0, stride from index 1 |
| 144 | + {'shape': 0, 'stride': 2} # other: shape from index 0, stride from index 2 |
| 145 | + ] |
| 146 | + output_config = {'shape': 0, 'stride': 3} # output: shape from index 0, stride from index 3 |
| 147 | + """ |
| 148 | + return ParameterMapping(operator_name, call_signature, input_configs, output_config) |
| 149 | + |
| 150 | + |
| 151 | +def create_test_cases(test_case_data, parameter_mapping): |
| 152 | + """ |
| 153 | + Create test cases from data using specified parameter mapping |
| 154 | +
|
| 155 | + Args: |
| 156 | + test_case_data: List of test case specifications |
| 157 | + parameter_mapping: ParameterMapping instance or configuration tuple |
| 158 | +
|
| 159 | + Returns: |
| 160 | + List of TestCase objects |
| 161 | + """ |
| 162 | + if isinstance(parameter_mapping, (list, tuple)): |
| 163 | + # Unpack configuration: (operator_name, call_signature, input_configs, output_config) |
| 164 | + if len(parameter_mapping) == 4: |
| 165 | + operator_name, call_signature, input_configs, output_config = ( |
| 166 | + parameter_mapping |
| 167 | + ) |
| 168 | + parameter_mapping = create_parameter_mapping( |
| 169 | + operator_name, call_signature, input_configs, output_config |
| 170 | + ) |
| 171 | + else: |
| 172 | + raise ValueError("Invalid parameter mapping configuration format") |
| 173 | + |
| 174 | + test_cases = [] |
| 175 | + for i, data in enumerate(test_case_data): |
| 176 | + if isinstance(data, TestCase): |
| 177 | + test_cases.append(data) |
| 178 | + else: |
| 179 | + try: |
| 180 | + test_case = parameter_mapping.map_test_case(data) |
| 181 | + test_cases.append(test_case) |
| 182 | + except Exception as e: |
| 183 | + print(f"Warning: Failed to map test case {i} data {data}: {e}") |
| 184 | + # Fallback: try to create TestCase directly |
| 185 | + if isinstance(data, (list, tuple)): |
| 186 | + test_cases.append(TestCase(*data)) |
| 187 | + else: |
| 188 | + test_cases.append(TestCase(data)) |
| 189 | + |
| 190 | + return test_cases |
0 commit comments