
Commit b29e7a8

Marco Giordano authored and facebook-github-bot committed
Adapt sym quantizer to ET (#18870)
Summary: This diff adds a symmetric 8-bit quantizer to ET.

Reviewed By: hsharma35
Differential Revision: D91777784
1 parent 9207001 commit b29e7a8

2 files changed

Lines changed: 69 additions & 8 deletions
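
The new wgt_qspec_sym8s_127 spec introduced below is a per-tensor symmetric int8 spec clamped to [-127, 127]. As a rough illustration of the numerics it implies (not part of this diff; the helper name is invented for the example):

    import torch

    def symmetric_quantize_int8_127(w: torch.Tensor) -> tuple[torch.Tensor, float]:
        # Symmetric quantization: zero_point is fixed at 0, and restricting the
        # range to [-127, 127] keeps the grid symmetric (the -128 code is unused).
        scale = max(w.abs().max().item() / 127.0, 1e-12)  # guard against all-zero weights
        q = torch.clamp(torch.round(w / scale), -127, 127).to(torch.int8)
        return q, scale

    w = torch.randn(8, 4, 3)             # e.g. a conv1d weight
    q, scale = symmetric_quantize_int8_127(w)
    w_hat = q.to(torch.float32) * scale  # dequantized approximation of w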


backends/cadence/aot/quantizer/quantizer.py

Lines changed: 15 additions & 6 deletions
@@ -97,6 +97,15 @@
     observer_or_fake_quant_ctr=MinMaxObserver,
 )
 
+wgt_qspec_sym8s_127 = QuantizationSpec(
+    dtype=torch.int8,
+    quant_min=-127,
+    quant_max=127,
+    qscheme=torch.per_tensor_symmetric,
+    is_dynamic=False,
+    observer_or_fake_quant_ctr=MinMaxObserver,
+)
+
 bias_qspec: Optional[QuantizationSpec] = None
 
 qconfig_A8W8 = QuantizationConfig(

@@ -161,11 +170,11 @@
     None,
 )
 
-qconfig_A32W8sym = QuantizationConfig(
+qconfig_A32W8sym_127 = QuantizationConfig(
     input_activation=None,
     output_activation=None,
-    weight=wgt_qspec_sym8s,
-    bias=wgt_qspec_sym8s,
+    weight=wgt_qspec_sym8s_127,
+    bias=wgt_qspec_sym8s_127,
 )

@@ -417,13 +426,13 @@ class CadenceW8A32MixedQuantizer(CadenceQuantizer):
     def __init__(self) -> None:
         quantizers = []
         quantizers.append(
-            CadenceAtenQuantizer(MixedW8A32LinearPattern(), qconfig_A32W8sym)
+            CadenceAtenQuantizer(MixedW8A32LinearPattern(), qconfig_A32W8sym_127)
         )
         quantizers.append(
-            CadenceAtenQuantizer(MixedW8A32ConvPattern(), qconfig_A32W8sym)
+            CadenceAtenQuantizer(MixedW8A32ConvPattern(), qconfig_A32W8sym_127)
         )
         quantizers.append(
-            CadenceAtenQuantizer(MixedW8A32GruPattern(), qconfig_A32W8sym)
+            CadenceAtenQuantizer(MixedW8A32GruPattern(), qconfig_A32W8sym_127)
         )
         super().__init__(quantizers)
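
For orientation, a rough sketch of how CadenceW8A32MixedQuantizer (now built on qconfig_A32W8sym_127) could be plugged into a PT2E-style flow. The export and prepare entry points shown are assumptions about the surrounding workflow, not something this diff introduces, and whether a given layer actually gets annotated depends on the Mixed W8A32 patterns' shape constraints:

    import torch
    from executorch.backends.cadence.aot.quantizer.quantizer import (
        CadenceW8A32MixedQuantizer,
    )
    # Assumed PT2E entry points; newer PyTorch releases expose these under
    # torchao.quantization.pt2e.quantize_pt2e instead.
    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

    model = torch.nn.Sequential(torch.nn.Linear(16, 16)).eval()  # toy module
    example_inputs = (torch.randn(1, 16),)

    # Capture the model, annotate it with the quantizer, calibrate, convert.
    captured = torch.export.export_for_training(model, example_inputs).module()
    prepared = prepare_pt2e(captured, CadenceW8A32MixedQuantizer())
    prepared(*example_inputs)           # calibration pass drives the MinMaxObserver
    converted = convert_pt2e(prepared)  # materializes quantize/dequantize ops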

backends/cadence/aot/tests/test_quantizer_ops.py

Lines changed: 54 additions & 2 deletions
@@ -28,6 +28,7 @@
     CadenceWithLayerNormQuantizer,
     CadenceWithSoftmaxQuantizer,
     qconfig_A16,
+    qconfig_A32W8sym_127,
     qconfig_A8W8,
     qconfig_A8W8sym,
 )

@@ -57,7 +58,6 @@
 # These should be explicitly justified when added.
 EXCLUDED_FROM_ANNOTATION_TESTING: set[type[CadenceQuantizer]] = {
     CadenceNopQuantizer,  # No-op quantizer, doesn't annotate anything
-    CadenceW8A32MixedQuantizer,  # TODO: T247438158 Add test coverage
     CadenceRmsNormNopQuantizer,  # No-op quantizer, doesn't annotate anything, preserves rms_norm from decomposition
 }

@@ -248,7 +248,7 @@
 # This ensures TESTED_QUANTIZER_CLASSES stays in sync with actual tests.
 TESTED_QUANTIZER_CLASSES: set[type[CadenceQuantizer]] = {
     type(case[2]) for case in QUANTIZER_ANNOTATION_TEST_CASES
-}
+} | {CadenceW8A32MixedQuantizer}  # tested via dedicated methods
 
 
 class QuantizerAnnotationTest(unittest.TestCase):

@@ -607,6 +607,38 @@ def _build_conv2d_relu_graph(
 
         return gm, relu_nodes[0], conv2d_nodes[0]
 
+    def _build_w8a32_conv1d_graph(
+        self,
+    ) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
+        """Build a graph with conv1d(input, weight, bias) for the W8A32 pattern.
+
+        The MixedW8A32ConvPattern requires exactly 3 args (input, weight, bias),
+        no kwargs, weight shape with channels multiple of 4 and kernel_size == 3,
+        and input length equal to the kernel size (3).
+        """
+        builder = GraphBuilder()
+        # Input shape: (batch, in_channels, length) = (1, 4, 3); channels and length match the constraints.
+        x = builder.placeholder("x", torch.randn(1, 4, 3))
+        # Weight shape: (out_channels, in_channels, kernel_size) = (8, 4, 3).
+        weight = builder.placeholder("weight", torch.randn(8, 4, 3))
+        bias = builder.placeholder("bias", torch.randn(8))
+        conv1d = builder.call_operator(
+            op=torch.ops.aten.conv1d.default,
+            args=(x, weight, bias),
+            meta=NodeMetadata(
+                {"source_fn_stack": [("conv1d", torch.ops.aten.conv1d.default)]}
+            ),
+        )
+        builder.output([conv1d])
+        gm = builder.get_graph_module()
+
+        conv1d_nodes = gm.graph.find_nodes(
+            op="call_function",
+            target=torch.ops.aten.conv1d.default,
+        )
+        self.assertEqual(len(conv1d_nodes), 1, "Should find exactly one conv1d node")
+        return gm, conv1d_nodes[0]
+
     def _build_conv1d_relu_graph(
         self,
     ) -> tuple[torch.fx.GraphModule, torch.fx.Node, torch.fx.Node]:

@@ -742,6 +774,26 @@ def test_all_quantizers_have_annotation_tests(self) -> None:
             f"{untested_names}. Please add test cases or explicitly exclude them."
         )
 
+    def test_w8a32_mixed_conv1d_annotation(self) -> None:
+        """W8A32 conv1d: weight + bias are int8 sym, input/output stay fp32."""
+        gm, conv_node = self._build_w8a32_conv1d_graph()
+        CadenceW8A32MixedQuantizer().annotate(gm)
+
+        annotation: QuantizationAnnotation = conv_node.meta[Q_ANNOTATION_KEY]
+        self.assertTrue(annotation._annotated)
+        self.assertIsNone(annotation.output_qspec)
+
+        weight_node, bias_node = conv_node.args[1], conv_node.args[2]
+        self.assertEqual(
+            set(annotation.input_qspec_map.keys()), {weight_node, bias_node}
+        )
+        self.assertEqual(
+            annotation.input_qspec_map[weight_node], qconfig_A32W8sym_127.weight
+        )
+        self.assertEqual(
+            annotation.input_qspec_map[bias_node], qconfig_A32W8sym_127.bias
+        )
+
 
 class QuantizerOpsPreserveTest(unittest.TestCase):
     def test_mixed_w8a32_ops_to_preserve(self) -> None:
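
The assertions in the new test_w8a32_mixed_conv1d_annotation above encode the W8A32 contract: only the weight and bias receive a qspec, while output_qspec stays None because activations remain fp32. A tiny illustrative sketch of what that means operationally for a linear op (the helper name and scale handling are invented for the example):

    import torch

    def w8a32_linear(
        x_fp32: torch.Tensor, w_int8: torch.Tensor, w_scale: float, bias_fp32: torch.Tensor
    ) -> torch.Tensor:
        # Weights live on a symmetric int8 grid in [-127, 127] (zero_point = 0);
        # activations stay fp32, which is why no activation qspec is annotated.
        w_fp32 = w_int8.to(torch.float32) * w_scale  # dequantize; a real kernel would fuse this
        return torch.nn.functional.linear(x_fp32, w_fp32, bias_fp32)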
