21 changes: 15 additions & 6 deletions backends/cadence/aot/quantizer/quantizer.py
@@ -97,6 +97,15 @@
     observer_or_fake_quant_ctr=MinMaxObserver,
 )
 
+wgt_qspec_sym8s_127 = QuantizationSpec(
+    dtype=torch.int8,
+    quant_min=-127,
+    quant_max=127,
+    qscheme=torch.per_tensor_symmetric,
+    is_dynamic=False,
+    observer_or_fake_quant_ctr=MinMaxObserver,
+)
+
 bias_qspec: Optional[QuantizationSpec] = None
 
 qconfig_A8W8 = QuantizationConfig(
@@ -161,11 +170,11 @@
     None,
 )
 
-qconfig_A32W8sym = QuantizationConfig(
+qconfig_A32W8sym_127 = QuantizationConfig(
     input_activation=None,
     output_activation=None,
-    weight=wgt_qspec_sym8s,
-    bias=wgt_qspec_sym8s,
+    weight=wgt_qspec_sym8s_127,
+    bias=wgt_qspec_sym8s_127,
 )
 
 
Expand Down Expand Up @@ -417,13 +426,13 @@ class CadenceW8A32MixedQuantizer(CadenceQuantizer):
def __init__(self) -> None:
quantizers = []
quantizers.append(
CadenceAtenQuantizer(MixedW8A32LinearPattern(), qconfig_A32W8sym)
CadenceAtenQuantizer(MixedW8A32LinearPattern(), qconfig_A32W8sym_127)
)
quantizers.append(
CadenceAtenQuantizer(MixedW8A32ConvPattern(), qconfig_A32W8sym)
CadenceAtenQuantizer(MixedW8A32ConvPattern(), qconfig_A32W8sym_127)
)
quantizers.append(
CadenceAtenQuantizer(MixedW8A32GruPattern(), qconfig_A32W8sym)
CadenceAtenQuantizer(MixedW8A32GruPattern(), qconfig_A32W8sym_127)
)
super().__init__(quantizers)

Expand Down
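Note: the new wgt_qspec_sym8s_127 spec restricts the symmetric int8 range to [-127, 127], which forces zero_point = 0 and keeps the quantized weights away from the asymmetric code -128. The snippet below is a back-of-the-envelope sketch of what that implies for a weight tensor; it mirrors the spec above rather than the MinMaxObserver implementation, and the tensor shape is only an example.

import torch

# Sketch only: per-tensor symmetric int8 quantization clamped to [-127, 127].
# With a min/max observer, scale is roughly max|w| / 127 and zero_point is 0.
w = torch.randn(8, 4, 3)  # example weight tensor, not taken from the patch
scale = w.abs().max() / 127.0
q = torch.clamp(torch.round(w / scale), -127, 127).to(torch.int8)
w_dq = q.float() * scale  # dequantized weights; the code -128 never appears
print(scale.item(), int(q.min()), int(q.max()))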
56 changes: 54 additions & 2 deletions backends/cadence/aot/tests/test_quantizer_ops.py
@@ -28,6 +28,7 @@
     CadenceWithLayerNormQuantizer,
     CadenceWithSoftmaxQuantizer,
     qconfig_A16,
+    qconfig_A32W8sym_127,
     qconfig_A8W8,
     qconfig_A8W8sym,
 )
@@ -57,7 +58,6 @@
 # These should be explicitly justified when added.
 EXCLUDED_FROM_ANNOTATION_TESTING: set[type[CadenceQuantizer]] = {
     CadenceNopQuantizer,  # No-op quantizer, doesn't annotate anything
-    CadenceW8A32MixedQuantizer,  # TODO: T247438158 Add test coverage
     CadenceRmsNormNopQuantizer,  # No-op quantizer, doesn't annotate anything, preserves rms_norm from decomposition
 }
 
@@ -248,7 +248,7 @@
 # This ensures TESTED_QUANTIZER_CLASSES stays in sync with actual tests.
 TESTED_QUANTIZER_CLASSES: set[type[CadenceQuantizer]] = {
     type(case[2]) for case in QUANTIZER_ANNOTATION_TEST_CASES
-}
+} | {CadenceW8A32MixedQuantizer}  # tested via dedicated methods
 
 
 class QuantizerAnnotationTest(unittest.TestCase):
Expand Down Expand Up @@ -607,6 +607,38 @@ def _build_conv2d_relu_graph(

return gm, relu_nodes[0], conv2d_nodes[0]

def _build_w8a32_conv1d_graph(
self,
) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
"""Build a graph with conv1d(input, weight, bias) for the W8A32 pattern.

The MixedW8A32ConvPattern requires exactly 3 args (input, weight, bias),
no kwargs, weight shape with channels multiple of 4 and kernel_size==3,
and input length equal to the kernel size (3).
"""
builder = GraphBuilder()
# Input shape: (batch, in_channels, length) = (1, 4, 3) — channels and length match constraints.
x = builder.placeholder("x", torch.randn(1, 4, 3))
# Weight shape: (out_channels, in_channels, kernel_size) = (8, 4, 3).
weight = builder.placeholder("weight", torch.randn(8, 4, 3))
bias = builder.placeholder("bias", torch.randn(8))
conv1d = builder.call_operator(
op=torch.ops.aten.conv1d.default,
args=(x, weight, bias),
meta=NodeMetadata(
{"source_fn_stack": [("conv1d", torch.ops.aten.conv1d.default)]}
),
)
builder.output([conv1d])
gm = builder.get_graph_module()

conv1d_nodes = gm.graph.find_nodes(
op="call_function",
target=torch.ops.aten.conv1d.default,
)
self.assertEqual(len(conv1d_nodes), 1, "Should find exactly one conv1d node")
return gm, conv1d_nodes[0]

def _build_conv1d_relu_graph(
self,
) -> tuple[torch.fx.GraphModule, torch.fx.Node, torch.fx.Node]:
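Aside: the eager-mode computation that _build_w8a32_conv1d_graph models is a single conv1d(input, weight, bias) call. The stand-alone snippet below (not part of the patch) uses the same shapes as the helper and satisfies the constraints listed in its docstring.

import torch

# Illustrative only: same shapes as the test helper above.
x = torch.randn(1, 4, 3)       # (batch, in_channels, length); length == kernel_size == 3
weight = torch.randn(8, 4, 3)  # (out_channels, in_channels, kernel_size); in_channels % 4 == 0
bias = torch.randn(8)
out = torch.ops.aten.conv1d.default(x, weight, bias)
print(out.shape)  # torch.Size([1, 8, 1]) with stride 1 and no padding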
Expand Down Expand Up @@ -742,6 +774,26 @@ def test_all_quantizers_have_annotation_tests(self) -> None:
f"{untested_names}. Please add test cases or explicitly exclude them."
)

def test_w8a32_mixed_conv1d_annotation(self) -> None:
"""W8A32 conv1d: weight + bias are int8 sym, input/output stay fp32."""
gm, conv_node = self._build_w8a32_conv1d_graph()
CadenceW8A32MixedQuantizer().annotate(gm)

annotation: QuantizationAnnotation = conv_node.meta[Q_ANNOTATION_KEY]
self.assertTrue(annotation._annotated)
self.assertIsNone(annotation.output_qspec)

weight_node, bias_node = conv_node.args[1], conv_node.args[2]
self.assertEqual(
set(annotation.input_qspec_map.keys()), {weight_node, bias_node}
)
self.assertEqual(
annotation.input_qspec_map[weight_node], qconfig_A32W8sym_127.weight
)
self.assertEqual(
annotation.input_qspec_map[bias_node], qconfig_A32W8sym_127.bias
)


class QuantizerOpsPreserveTest(unittest.TestCase):
def test_mixed_w8a32_ops_to_preserve(self) -> None:
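Beyond these annotation-level checks, the quantizer would normally be driven through the PT2E flow. The sketch below is an assumption about usage: it relies on the generic torch.ao.quantization.quantize_pt2e entry points and an assumed import path for the quantizer. The real Cadence AoT pipeline may wrap these steps differently, and whether this particular module is actually matched depends on the pattern's shape and source-op constraints.

import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

# Assumed import path; adjust to your ExecuTorch checkout.
from executorch.backends.cadence.aot.quantizer.quantizer import (
    CadenceW8A32MixedQuantizer,
)


class TinyConv(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # Shapes chosen to line up with the W8A32 conv constraints above.
        self.conv = torch.nn.Conv1d(4, 8, kernel_size=3)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(x)


example_inputs = (torch.randn(1, 4, 3),)
gm = torch.export.export_for_training(TinyConv(), example_inputs).module()
prepared = prepare_pt2e(gm, CadenceW8A32MixedQuantizer())
prepared(*example_inputs)  # calibration pass
converted = convert_pt2e(prepared)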