2121from executorch .backends .arm .common .annotation_meta import ArmAnnotationInfo
2222from executorch .backends .arm .constants import DISALLOW_TFA_META_KEY
2323from executorch .backends .arm .quantizer .quantization_config import QuantizationConfig
24+ from executorch .backends .cortex_m .quantizer_reporter import QuantizerReporterUser
2425from torch .fx import Node
2526
2627from torchao .quantization .pt2e .quantizer import (
@@ -160,25 +161,6 @@ def _get_int32_per_channel_bias_qspec(node):
160161 )
161162
162163
163- class _QuantizerReporterUserMixin :
164- def __init__ (self ):
165- self .reporter = None
166-
167- def register_reporter (self , reporter ) -> None :
168- self .reporter = reporter
169-
170- def report_reject (self , pattern : list [Node ], reason : str ) -> None :
171- if self .reporter is not None :
172- self .reporter .report_reject (self , pattern , reason )
173-
174- def report_accept (self , pattern : list [Node ]) -> None :
175- if self .reporter is not None :
176- self .reporter .report_accept (self , pattern )
177-
178- def get_quantizer_info (self ):
179- raise NotImplementedError ("Quantizer must implement get_quantizer_info method." )
180-
181-
182164class PatternCheck :
183165 """Base class for pattern checks.
184166
@@ -248,7 +230,7 @@ def find_nodes(self, model: torch.fx.GraphModule) -> Iterator[Node]:
248230 pass
249231
250232
251- class PatternQuantizer (Quantizer , _QuantizerReporterUserMixin ):
233+ class PatternQuantizer (Quantizer , QuantizerReporterUser ):
252234 """Quantizes a graph according to an OperatorConfig.
253235
254236 Args:
@@ -265,7 +247,7 @@ def __init__(
265247 pattern_matcher : "PatternMatcher" ,
266248 ) -> None :
267249 super ().__init__ ()
268- _QuantizerReporterUserMixin .__init__ (self )
250+ QuantizerReporterUser .__init__ (self )
269251 self .quantization_config : QuantizationConfig | None = quantization_config
270252 self .node_finder : "NodeFinder" = node_finder
271253 self .pattern_matcher : "PatternMatcher" = pattern_matcher
@@ -397,7 +379,7 @@ def validate(self, model: torch.fx.GraphModule) -> bool: # type: ignore[overrid
397379 return True
398380
399381
400- class SharedQspecQuantizer (Quantizer , _QuantizerReporterUserMixin ):
382+ class SharedQspecQuantizer (Quantizer , QuantizerReporterUser ):
401383 """Assures that specific ops share quantization parameters on all
402384 inputs/outputs.
403385 """
@@ -495,7 +477,7 @@ class SharedQspecQuantizer(Quantizer, _QuantizerReporterUserMixin):
495477
496478 def __init__ (self , targets : Optional [list [Callable [..., object ]]] = None ) -> None :
497479 super ().__init__ ()
498- _QuantizerReporterUserMixin .__init__ (self )
480+ QuantizerReporterUser .__init__ (self )
499481 if targets is None :
500482 self .targets = self .SHARED_QSPEC_OPS_DEFAULT
501483 self .support_config_path = (
0 commit comments