@@ -28,6 +28,7 @@
28 | 28 |     CadenceWithLayerNormQuantizer,
29 | 29 |     CadenceWithSoftmaxQuantizer,
30 | 30 |     qconfig_A16,
| 31 | +     qconfig_A32W8sym_127,
31 | 32 |     qconfig_A8W8,
32 | 33 |     qconfig_A8W8sym,
33 | 34 | )
@@ -57,7 +58,6 @@
57 | 58 | # These should be explicitly justified when added.
58 | 59 | EXCLUDED_FROM_ANNOTATION_TESTING: set[type[CadenceQuantizer]] = {
59 | 60 |     CadenceNopQuantizer,  # No-op quantizer, doesn't annotate anything
60 | | -     CadenceW8A32MixedQuantizer,  # TODO: T247438158 Add test coverage
61 | 61 |     CadenceRmsNormNopQuantizer,  # No-op quantizer, doesn't annotate anything; preserves rms_norm from decomposition
62 | 62 | }
63 | 63 |
@@ -248,7 +248,7 @@
248 | 248 | # This ensures TESTED_QUANTIZER_CLASSES stays in sync with actual tests.
249 | 249 | TESTED_QUANTIZER_CLASSES: set[type[CadenceQuantizer]] = {
250 | 250 |     type(case[2]) for case in QUANTIZER_ANNOTATION_TEST_CASES
251 | | -}
| 251 | +} | {CadenceW8A32MixedQuantizer}  # tested via dedicated methods
252 | 252 |
253 | 253 |
254 | 254 | class QuantizerAnnotationTest(unittest.TestCase):
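The union on new line 251 is what keeps the sync guard honest: CadenceW8A32MixedQuantizer leaves EXCLUDED_FROM_ANNOTATION_TESTING but still counts as covered even though it never appears in QUANTIZER_ANNOTATION_TEST_CASES. A minimal sketch of how the guard presumably consumes these sets, where `all_quantizer_classes()` is a hypothetical stand-in for however `test_all_quantizers_have_annotation_tests` enumerates quantizers:

```python
# Sketch only; all_quantizer_classes() is a hypothetical stand-in, and the
# real assertion in test_all_quantizers_have_annotation_tests may differ.
untested = (
    set(all_quantizer_classes())
    - TESTED_QUANTIZER_CLASSES
    - EXCLUDED_FROM_ANNOTATION_TESTING
)
assert not untested, f"Please add test cases or explicitly exclude: {untested}"
```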
@@ -607,6 +607,38 @@ def _build_conv2d_relu_graph( |
607 | 607 |
608 | 608 |         return gm, relu_nodes[0], conv2d_nodes[0]
609 | 609 |
| 610 | +     def _build_w8a32_conv1d_graph(
| 611 | +         self,
| 612 | +     ) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
| 613 | +         """Build a graph with conv1d(input, weight, bias) for the W8A32 pattern.
| 614 | +
| 615 | +         The MixedW8A32ConvPattern requires exactly three args (input, weight,
| 616 | +         bias), no kwargs, a weight whose channel count is a multiple of 4 and
| 617 | +         whose kernel_size is 3, and an input whose length equals the kernel size (3).
| 618 | +         """
| 619 | +         builder = GraphBuilder()
| 620 | +         # Input shape: (batch, in_channels, length) = (1, 4, 3); channels and length satisfy the pattern constraints.
| 621 | +         x = builder.placeholder("x", torch.randn(1, 4, 3))
| 622 | +         # Weight shape: (out_channels, in_channels, kernel_size) = (8, 4, 3).
| 623 | +         weight = builder.placeholder("weight", torch.randn(8, 4, 3))
| 624 | +         bias = builder.placeholder("bias", torch.randn(8))
| 625 | +         conv1d = builder.call_operator(
| 626 | +             op=torch.ops.aten.conv1d.default,
| 627 | +             args=(x, weight, bias),
| 628 | +             meta=NodeMetadata(
| 629 | +                 {"source_fn_stack": [("conv1d", torch.ops.aten.conv1d.default)]}
| 630 | +             ),
| 631 | +         )
| 632 | +         builder.output([conv1d])
| 633 | +         gm = builder.get_graph_module()
| 634 | +
| 635 | +         conv1d_nodes = gm.graph.find_nodes(
| 636 | +             op="call_function",
| 637 | +             target=torch.ops.aten.conv1d.default,
| 638 | +         )
| 639 | +         self.assertEqual(len(conv1d_nodes), 1, "Should find exactly one conv1d node")
| 640 | +         return gm, conv1d_nodes[0]
| 641 | + |
610 | 642 |     def _build_conv1d_relu_graph(
611 | 643 |         self,
612 | 644 |     ) -> tuple[torch.fx.GraphModule, torch.fx.Node, torch.fx.Node]:
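As a standalone sanity check on the shapes the builder picks (not part of the diff): with the input length equal to the kernel size, the convolution emits exactly one output position, consistent with the constraint described in the docstring above.

```python
import torch
import torch.nn.functional as F

# Same shapes as _build_w8a32_conv1d_graph: input (1, 4, 3), weight (8, 4, 3).
x, w, b = torch.randn(1, 4, 3), torch.randn(8, 4, 3), torch.randn(8)
out = F.conv1d(x, w, b)
print(out.shape)  # torch.Size([1, 8, 1]): length-3 input with kernel 3 -> one output step
```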
@@ -742,6 +774,26 @@ def test_all_quantizers_have_annotation_tests(self) -> None: |
742 | 774 | f"{untested_names}. Please add test cases or explicitly exclude them." |
743 | 775 | ) |
744 | 776 |
| 777 | +     def test_w8a32_mixed_conv1d_annotation(self) -> None:
| 778 | +         """W8A32 conv1d: weight and bias get the A32W8sym_127 qspecs; input/output stay fp32."""
| 779 | +         gm, conv_node = self._build_w8a32_conv1d_graph()
| 780 | +         CadenceW8A32MixedQuantizer().annotate(gm)
| 781 | +
| 782 | +         annotation: QuantizationAnnotation = conv_node.meta[Q_ANNOTATION_KEY]
| 783 | +         self.assertTrue(annotation._annotated)
| 784 | +         self.assertIsNone(annotation.output_qspec)
| 785 | +
| 786 | +         weight_node, bias_node = conv_node.args[1], conv_node.args[2]
| 787 | +         self.assertEqual(
| 788 | +             set(annotation.input_qspec_map.keys()), {weight_node, bias_node}
| 789 | +         )
| 790 | +         self.assertEqual(
| 791 | +             annotation.input_qspec_map[weight_node], qconfig_A32W8sym_127.weight
| 792 | +         )
| 793 | +         self.assertEqual(
| 794 | +             annotation.input_qspec_map[bias_node], qconfig_A32W8sym_127.bias
| 795 | +         )
| 796 | + |
745 | 797 |
746 | 798 | class QuantizerOpsPreserveTest(unittest.TestCase):
747 | 799 |     def test_mixed_w8a32_ops_to_preserve(self) -> None:
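For readers tracing the new qconfig_A32W8sym_127 import: the name suggests fp32 (A32) activations with int8 symmetric weights, and the _127 suffix suggests the narrow [-127, 127] range that drops -128. A hypothetical reconstruction of the weight spec, assuming it follows the torch.ao QuantizationSpec pattern used by the other qconfigs; the actual definition lives in the Cadence quantizer module and may differ:

```python
import torch
from torch.ao.quantization.observer import MinMaxObserver
from torch.ao.quantization.quantizer import QuantizationSpec

# Hypothetical sketch of qconfig_A32W8sym_127.weight, NOT the actual source:
# int8 symmetric per-tensor weights clipped to [-127, 127].
weight_qspec = QuantizationSpec(
    dtype=torch.int8,
    quant_min=-127,  # -128 excluded to keep the range symmetric
    quant_max=127,
    qscheme=torch.per_tensor_symmetric,
    observer_or_fake_quant_ctr=MinMaxObserver,
)
```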