|
4 | 4 | from .datatypes import to_infinicore_dtype, to_torch_dtype |
5 | 5 |
|
6 | 6 |
|
7 | | -def create_infinicore_tensor(torch_tensor, device_str): |
8 | | - """Create infinicore tensor from PyTorch tensor""" |
9 | | - infini_device = infinicore.device(device_str, 0) |
10 | | - |
11 | | - return infinicore.from_blob( |
12 | | - torch_tensor.data_ptr(), |
13 | | - list(torch_tensor.shape), |
14 | | - dtype=to_infinicore_dtype(torch_tensor.dtype), |
15 | | - device=infini_device, |
16 | | - ) |
17 | | - |
18 | | - |
19 | 7 | def synchronize_device(torch_device): |
20 | 8 | """Device synchronization""" |
21 | 9 | if torch_device == "cuda": |
@@ -149,44 +137,95 @@ def get_tolerance(tolerance_map, tensor_dtype, default_atol=0, default_rtol=1e-3 |
149 | 137 | return tolerance["atol"], tolerance["rtol"] |
150 | 138 |
|
151 | 139 |
|
152 | | -def compare_results( |
153 | | - infini_result, torch_result, dtype, config, device_str, tolerance_map=None |
154 | | -): |
| 140 | +def create_infinicore_tensor(torch_tensor, device_str): |
| 141 | + """Create infinicore tensor from PyTorch tensor""" |
| 142 | + infini_device = infinicore.device(device_str, 0) |
| 143 | + |
| 144 | + return infinicore.from_blob( |
| 145 | + torch_tensor.data_ptr(), |
| 146 | + list(torch_tensor.shape), |
| 147 | + dtype=to_infinicore_dtype(torch_tensor.dtype), |
| 148 | + device=infini_device, |
| 149 | + ) |
| 150 | + |
| 151 | + |
| 152 | +def convert_infinicore_to_torch(infini_result, torch_reference): |
155 | 153 | """ |
156 | | - Compare infinicore result with PyTorch reference result |
| 154 | + Convert infinicore tensor to PyTorch tensor for comparison |
157 | 155 |
|
158 | 156 | Args: |
159 | 157 | infini_result: infinicore tensor result |
160 | | - torch_result: PyTorch tensor reference result |
| 158 | + torch_reference: PyTorch tensor reference (for shape and device) |
161 | 159 | dtype: infinicore data type |
162 | | - config: test config |
163 | 160 | device_str: torch device string |
164 | | - device: device enum |
165 | | - tolerance_map: optional tolerance map (defaults to config's tolerance_map) |
166 | 161 |
|
167 | 162 | Returns: |
168 | | - bool: True if results match within tolerance |
| 163 | + torch.Tensor: PyTorch tensor with infinicore data |
169 | 164 | """ |
170 | | - # Convert infinicore result to PyTorch tensor for comparison |
171 | 165 | torch_result_from_infini = torch.zeros( |
172 | | - torch_result.shape, dtype=to_torch_dtype(dtype), device=device_str |
| 166 | + torch_reference.shape, |
| 167 | + dtype=to_torch_dtype(infini_result.dtype), |
| 168 | + device=infini_result.device.type, |
| 169 | + ) |
| 170 | + temp_tensor = create_infinicore_tensor( |
| 171 | + torch_result_from_infini, infini_result.device.type |
173 | 172 | ) |
174 | | - temp_tensor = create_infinicore_tensor(torch_result_from_infini, device_str) |
175 | 173 | temp_tensor.copy_(infini_result) |
| 174 | + return torch_result_from_infini |
176 | 175 |
|
177 | | - # Retrieve tolerance - use provided map or config's map |
178 | | - if tolerance_map is None: |
179 | | - tolerance_map = config.tolerance_map |
180 | | - atol, rtol = get_tolerance(tolerance_map, dtype) |
| 176 | + |
| 177 | +def compare_results( |
| 178 | + infini_result, torch_result, atol=1e-5, rtol=1e-5, debug_mode=False |
| 179 | +): |
| 180 | + """ |
| 181 | + Generic function to compare infinicore result with PyTorch reference result |
| 182 | +
|
| 183 | + Args: |
| 184 | + infini_result: infinicore tensor result |
| 185 | + torch_result: PyTorch tensor reference result |
| 186 | + atol: absolute tolerance |
| 187 | + rtol: relative tolerance |
| 188 | + debug_mode: whether to enable debug output |
| 189 | +
|
| 190 | + Returns: |
| 191 | + bool: True if results match within tolerance |
| 192 | + """ |
| 193 | + # Convert infinicore result to PyTorch tensor for comparison |
| 194 | + torch_result_from_infini = convert_infinicore_to_torch(infini_result, torch_result) |
181 | 195 |
|
182 | 196 | # Debug mode: detailed comparison |
183 | | - if config.debug: |
| 197 | + if debug_mode: |
184 | 198 | debug(torch_result_from_infini, torch_result, atol=atol, rtol=rtol) |
185 | 199 |
|
186 | 200 | # Check if results match within tolerance |
187 | 201 | return torch.allclose(torch_result_from_infini, torch_result, atol=atol, rtol=rtol) |
188 | 202 |
|
189 | 203 |
|
| 204 | +def create_test_comparator(config, dtype, tolerance_map=None): |
| 205 | + """ |
| 206 | + Create a test-specific comparison function that handles test configuration |
| 207 | +
|
| 208 | + Args: |
| 209 | + config: test configuration |
| 210 | + dtype: infinicore data type |
| 211 | + tolerance_map: optional tolerance map (defaults to config's tolerance_map) |
| 212 | +
|
| 213 | + Returns: |
| 214 | + callable: function that takes (infini_result, torch_result) and returns bool |
| 215 | + """ |
| 216 | + if tolerance_map is None: |
| 217 | + tolerance_map = config.tolerance_map |
| 218 | + |
| 219 | + atol, rtol = get_tolerance(tolerance_map, dtype) |
| 220 | + |
| 221 | + def compare_test_results(infini_result, torch_result): |
| 222 | + return compare_results( |
| 223 | + infini_result, torch_result, atol=atol, rtol=rtol, debug_mode=config.debug |
| 224 | + ) |
| 225 | + |
| 226 | + return compare_test_results |
| 227 | + |
| 228 | + |
190 | 229 | def rearrange_tensor(tensor, new_strides): |
191 | 230 | """ |
192 | 231 | Given a PyTorch tensor and a list of new strides, return a new PyTorch tensor with the given strides. |
|
0 commit comments