66import logging
77
88import torch
9- from executorch .backends .cortex_m .quantizer .quantization_configs import (
10- INT8_ACTIVATION_PER_CHANNEL_QSPEC ,
11- INT8_WEIGHT_PER_TENSOR_QSPEC ,
12- )
139from executorch .backends .cortex_m .quantizer .quantizer import mark_node_as_annotated
1410from executorch .backends .cortex_m .quantizer_reporter import (
1511 logger as quantizer_logger ,
12+ qspec_repr ,
1613 QuantizerInfo ,
1714 QuantizerReport ,
1815 QuantizerReporter ,
1916 QuantizerReporterUser ,
2017)
2118from torch .export import export
19+ from torchao .quantization .pt2e import MinMaxObserver , PerChannelMinMaxObserver
20+ from torchao .quantization .pt2e .quantizer import (
21+ DerivedQuantizationSpec ,
22+ QuantizationSpec ,
23+ SharedQuantizationSpec ,
24+ )
25+
# Local quantization specs used by the report-formatting tests below.
#
# Symmetric per-tensor int8 spec: the quant range is deliberately clipped to
# [-127, 127] (excluding -128) as is conventional for symmetric schemes.
INT8_WEIGHT_PER_TENSOR_QSPEC = QuantizationSpec(
    dtype=torch.int8,
    qscheme=torch.per_tensor_symmetric,
    quant_min=-127,
    quant_max=127,
    observer_or_fake_quant_ctr=MinMaxObserver,
)

# Affine per-channel int8 spec over channel axis 0; no explicit quant_min/max,
# so qspec_repr should render it without a range.
# NOTE(review): the names read "weight per-tensor" / "activation per-channel",
# the reverse of the usual convention — they match the expected report strings
# in this file, but confirm against the production configs they replaced.
INT8_ACTIVATION_PER_CHANNEL_QSPEC = QuantizationSpec(
    dtype=torch.int8,
    qscheme=torch.per_channel_affine,
    ch_axis=0,
    observer_or_fake_quant_ctr=PerChannelMinMaxObserver,
)
2239
2340
2441class _TwoOpModule (torch .nn .Module ):
@@ -43,6 +60,74 @@ def get_quantizer_info(self) -> QuantizerInfo:
4360 )
4461
4562
def test_qspec_repr_quantization_spec_with_range():
    """When both quant_min and quant_max are set, the repr shows the range."""
    spec = QuantizationSpec(torch.int8, MinMaxObserver, quant_min=-42, quant_max=123)
    expected = "QuantizationSpec(dtype=INT8, range=(-42,123))"
    assert qspec_repr(spec) == expected
71+
72+
def test_qspec_repr_quantization_spec_without_range():
    """With no quant_min/quant_max, the repr contains only the dtype."""
    spec = QuantizationSpec(torch.int16, MinMaxObserver)
    assert qspec_repr(spec) == "QuantizationSpec(dtype=INT16)"
79+
80+
def test_qspec_repr_quantization_spec_partial_range():
    """A range is printed even when only one bound is set; the other is None."""
    spec = QuantizationSpec(torch.int16, MinMaxObserver, quant_min=-100)
    expected = "QuantizationSpec(dtype=INT16, range=(-100,None))"
    assert qspec_repr(spec) == expected
88+
89+
def test_qspec_repr_shared_quantization_spec():
    """SharedQuantizationSpec repr embeds the edge-or-node it refers to."""
    graph_module = _export_two_op_graph_module()
    add_nodes = [
        n
        for n in graph_module.graph.nodes
        if n.target == torch.ops.aten.add.Tensor
    ]
    shared = SharedQuantizationSpec(add_nodes[0])

    assert qspec_repr(shared) == f"SharedQuantizationSpec(edge_or_node={add_nodes[0]})"
100+
101+
def test_qspec_repr_derived_quantization_spec():
    """DerivedQuantizationSpec repr shows its source edges and dtype."""
    graph_module = _export_two_op_graph_module()
    nodes = list(graph_module.graph.nodes)
    x_node = next(n for n in nodes if n.name == "x")
    y_node = next(n for n in nodes if n.name == "y")
    add_node = next(n for n in nodes if n.target == torch.ops.aten.add.Tensor)

    derived_from = [(x_node, add_node), (y_node, add_node)]

    # The qparams function is never evaluated by qspec_repr; any callable with
    # the right shape of return value will do.
    def _fixed_qparams(_):
        return torch.tensor([1.0]), torch.tensor([0], dtype=torch.int32)

    spec = DerivedQuantizationSpec(
        derived_from=derived_from,
        derive_qparams_fn=_fixed_qparams,
        dtype=torch.int32,
    )

    expected = (
        f"DerivedQuantizationSpec(derived_from={derived_from}, dtype={spec.dtype})"
    )
    assert qspec_repr(spec) == expected
125+
126+
def test_qspec_repr_none():
    """A missing qspec is rendered as the literal string "None"."""
    result = qspec_repr(None)
    assert result == "None"
129+
130+
46131def test_warning_log_level (caplog ):
47132 graph_module = _export_two_op_graph_module ()
48133
@@ -128,11 +213,11 @@ def test_debug_log_level(caplog):
128213 Rejected due to previous annotation: 0
129214 Rejected nodes: 0
130215
131- NODE NAME INPUT QSPEC MAP OUTPUT QSPEC MAP
132- -- ----------- ------------------------------- ---------------------------------
133- ╒ add x: INT8_WEIGHT_PER_TENSOR_QSPEC NO_QSPEC
134- | y: NO_QSPEC
135- ╘ relu INT8_ACTIVATION_PER_CHANNEL_QSPEC
216+ NODE NAME INPUT QSPEC MAP OUTPUT QSPEC MAP
217+ -- ----------- ------------------------------------------------- ----------------------------
218+ ╒ add x: QuantizationSpec(dtype=INT8, range=(-127,127)) None
219+ | y: None
220+ ╘ relu QuantizationSpec(dtype=INT8)
136221----------------------------------------------------------------------------------------------------
137222DummyQuantizer using dummy nodes
138223Annotating with dummy.config
0 commit comments