|
39 | 39 | ) |
40 | 40 | from executorch.backends.cortex_m.quantizer.pattern_matcher import PatternMatcher |
41 | 41 |
|
42 | | -from executorch.backends.cortex_m.quantizer.quantizer_reporter import ( |
43 | | - QuantizerReporter, |
44 | | - SUPPORTED_QCONFIGS, |
45 | | - SUPPORTED_QSPECS, |
46 | | -) |
| 42 | +from executorch.backends.cortex_m.quantizer_reporter import QuantizerReporter |
47 | 43 |
|
48 | 44 | from torch._ops import OpOverload |
49 | 45 |
|
@@ -219,20 +215,28 @@ def get_symmetric_quantization_config( |
219 | 215 | bias_quantization_spec = _get_int32_bias_qspec |
220 | 216 |
|
221 | 217 | if is_dynamic: |
222 | | - quantization_config = TOSAQuantizationConfig( |
223 | | - act_quantization_spec, |
224 | | - None, |
225 | | - weight_quantization_spec, |
226 | | - bias_quantization_spec, |
227 | | - ) |
| 218 | + output_activation = None |
228 | 219 | else: |
229 | | - quantization_config = TOSAQuantizationConfig( |
230 | | - act_quantization_spec, |
231 | | - act_quantization_spec, |
232 | | - weight_quantization_spec, |
233 | | - bias_quantization_spec, |
234 | | - ) |
235 | | - return quantization_config |
| 220 | + output_activation = act_quantization_spec |
| 221 | + |
| 222 | + module_name = __name__.rsplit(".", maxsplit=1)[-1] |
| 223 | + label = ( |
| 224 | + f"{module_name}.get_symmetric_quantization_config(" |
| 225 | + f"per_channel={int(is_per_channel)}, " |
| 226 | + f"qat={int(is_qat)}, " |
| 227 | + f"dynamic={int(is_dynamic)}, " |
| 228 | + f"act_range=[{act_qmin}, {act_qmax}], " |
| 229 | + f"weight_range=[{weight_qmin}, {weight_qmax}]" |
| 230 | + ")" |
| 231 | + ) |
| 232 | + |
| 233 | + return TOSAQuantizationConfig( |
| 234 | + act_quantization_spec, |
| 235 | + output_activation, |
| 236 | + weight_quantization_spec, |
| 237 | + bias_quantization_spec, |
| 238 | + label, |
| 239 | + ) |
236 | 240 |
|
237 | 241 |
|
238 | 242 | @functools.lru_cache |
@@ -357,59 +361,32 @@ def get_symmetric_a16w8_quantization_config( |
357 | 361 | is_qat=is_qat, |
358 | 362 | is_dynamic=is_dynamic, |
359 | 363 | ) |
360 | | - # Replace activation quantization spec with 16-bit version |
| 364 | + |
361 | 365 | if is_dynamic: |
362 | | - quantization_config = TOSAQuantizationConfig( |
363 | | - act_quantization_spec, # 16-bit input activations |
364 | | - None, |
365 | | - base_config.weight, # 8-bit weights from base config |
366 | | - base_config.bias, # bias from base config |
367 | | - ) |
| 366 | + output_activation = None |
368 | 367 | else: |
369 | | - quantization_config = TOSAQuantizationConfig( |
370 | | - act_quantization_spec, # 16-bit input activations |
371 | | - act_quantization_spec, # 16-bit output activations |
372 | | - base_config.weight, # 8-bit weights from base config |
373 | | - base_config.bias, # bias from base config |
374 | | - ) |
375 | | - return quantization_config |
376 | | - |
| 368 | + output_activation = act_quantization_spec |
| 369 | + |
| 370 | + module_name = __name__.rsplit(".", maxsplit=1)[-1] |
| 371 | + label = ( |
| 372 | + f"{module_name}.get_symmetric_a16w8_quantization_config(" |
| 373 | + f"per_channel={int(is_per_channel)}, " |
| 374 | + f"qat={int(is_qat)}, " |
| 375 | + f"dynamic={int(is_dynamic)}, " |
| 376 | + f"act_range=[{act_quantization_spec.quant_min}, {act_quantization_spec.quant_max}], " |
| 377 | + f"weight_range=[{weight_qmin}, {weight_qmax}]" |
| 378 | + ")" |
| 379 | + ) |
377 | 380 |
|
378 | | -# Register supported quantization configs and qspecs in the reporter for human-readable reporting |
379 | | -# MLETORCH-1854: Temporary solution, refactor to automatically register these instead |
380 | | -_symmetric_a8w4_config_per_channel = get_symmetric_a8w4_quantization_config() |
381 | | -_symmetric_a8w8_config_per_channel = get_symmetric_quantization_config() |
382 | | -_symmetric_a16w8_config_per_channel = get_symmetric_a16w8_quantization_config() |
383 | | -_symmetric_a8w4_config_per_tensor = get_symmetric_a8w4_quantization_config( |
384 | | - is_per_channel=False |
385 | | -) |
386 | | -_symmetric_a8w8_config_per_tensor = get_symmetric_quantization_config( |
387 | | - is_per_channel=False |
388 | | -) |
389 | | -_symmetric_a16w8_config_per_tensor = get_symmetric_a16w8_quantization_config( |
390 | | - is_per_channel=False |
391 | | -) |
392 | | -SUPPORTED_QCONFIGS.update( |
393 | | - { |
394 | | - _symmetric_a8w8_config_per_channel: f"{__name__}.get_symmetric_quantization_config(is_per_channel=True)", |
395 | | - _symmetric_a16w8_config_per_channel: f"{__name__}.get_symmetric_a16w8_quantization_config(is_per_channel=True)", |
396 | | - _symmetric_a8w4_config_per_channel: f"{__name__}.get_symmetric_a8w4_quantization_config(is_per_channel=True)", |
397 | | - _symmetric_a8w8_config_per_tensor: f"{__name__}.get_symmetric_quantization_config(is_per_channel=False)", |
398 | | - _symmetric_a16w8_config_per_tensor: f"{__name__}.get_symmetric_a16w8_quantization_config(is_per_channel=False)", |
399 | | - _symmetric_a8w4_config_per_tensor: f"{__name__}.get_symmetric_a8w4_quantization_config(is_per_channel=False)", |
400 | | - } |
401 | | -) |
| 381 | + # Replace activation quantization spec with 16-bit version |
| 382 | + return TOSAQuantizationConfig( |
| 383 | + act_quantization_spec, # 16-bit input activations |
| 384 | + output_activation, |
| 385 | + base_config.weight, # 8-bit weights from base config |
| 386 | + base_config.bias, # bias from base config |
| 387 | + label, |
| 388 | + ) |
402 | 389 |
|
403 | | -SUPPORTED_QSPECS.update( |
404 | | - { |
405 | | - _symmetric_a8w4_config_per_channel.get_weight_qspec(): "INT4_PER_CHANNEL_QSPEC", |
406 | | - _symmetric_a8w8_config_per_channel.get_weight_qspec(): "INT8_PER_CHANNEL_QSPEC", |
407 | | - _symmetric_a8w8_config_per_tensor.get_weight_qspec(): "INT8_PER_TENSOR_QSPEC", |
408 | | - _symmetric_a8w4_config_per_tensor.get_weight_qspec(): "INT4_PER_TENSOR_QSPEC", |
409 | | - _symmetric_a8w8_config_per_tensor.get_input_act_qspec(): "INT8_PER_TENSOR_QSPEC", |
410 | | - _symmetric_a16w8_config_per_tensor.get_input_act_qspec(): "INT16_PER_TENSOR_QSPEC", |
411 | | - } |
412 | | -) |
413 | 390 |
|
414 | 391 | NodeFilterType = Callable[[Node], bool] |
415 | 392 | """Type for a Node Filter used by annotators. |
|
0 commit comments