diff --git a/backends/cadence/aot/quantizer/patterns.py b/backends/cadence/aot/quantizer/patterns.py index 07aad18e36a..54c01227d07 100644 --- a/backends/cadence/aot/quantizer/patterns.py +++ b/backends/cadence/aot/quantizer/patterns.py @@ -6,6 +6,7 @@ # pyre-strict +import operator from abc import ABC, abstractmethod from dataclasses import dataclass, field from typing import List, Tuple, Union @@ -493,7 +494,20 @@ def get_anchors( # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge... max_pool_node = fused_partition[0].nodes[-1] - # Input and output share quantization parameters since max is order-preserving + # Since max_pool2d_with_indices returns a tuple, the output observer must be + # placed on getitem[0] rather than the tuple-returning op. Otherwise + # prepare_pt2e silently skips it. + # Expect exactly one user: getitem[0] extracting the values tensor. If indices + # are also used or the structure is unexpected, bail out. + users = list(max_pool_node.users) + if ( + len(users) != 1 + or users[0].target is not operator.getitem + or users[0].args[1] != 0 + ): + return PartitionAnchors(empty=True), max_pool_node + getitem_0 = users[0] + return ( PartitionAnchors( inputs=[(max_pool_node, 0)], @@ -505,7 +519,7 @@ def get_anchors( ], output=[ ( - max_pool_node, + getitem_0, SharedQuantizationSpec((max_pool_node.args[0], max_pool_node)), ) ], diff --git a/backends/cadence/aot/tests/test_quantizer_ops.py b/backends/cadence/aot/tests/test_quantizer_ops.py index dde26f06b7b..f5598a8bd4f 100644 --- a/backends/cadence/aot/tests/test_quantizer_ops.py +++ b/backends/cadence/aot/tests/test_quantizer_ops.py @@ -7,6 +7,7 @@ # pyre-strict import inspect +import operator import unittest from typing import Callable @@ -483,12 +484,18 @@ def _build_addmm_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]: self.assertEqual(len(addmm_nodes), 1, "Should find exactly one addmm node") return gm, addmm_nodes[0] - def _build_max_pool2d_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]: - """Build a simple graph with a max_pool2d_with_indices operation.""" + def _build_max_pool2d_graph( + self, + ) -> tuple[torch.fx.GraphModule, torch.fx.Node, torch.fx.Node]: + """Build a graph with max_pool2d_with_indices followed by getitem[0]. + + Returns: + A tuple of (graph_module, getitem_node, max_pool_node). + The getitem_node is where the output annotation is placed. + The max_pool_node is where the input annotation is placed. + """ builder = GraphBuilder() - # Input shape: (batch, channels, height, width) x = builder.placeholder("x", torch.randn(1, 3, 8, 8)) - # max_pool2d_with_indices args: (input, kernel_size, stride, padding, dilation, ceil_mode) max_pool = builder.call_operator( op=torch.ops.aten.max_pool2d_with_indices.default, args=(x, [2, 2], [2, 2], [0, 0], [1, 1], False), @@ -503,19 +510,24 @@ def _build_max_pool2d_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]: } ), ) - builder.output([max_pool]) + getitem = builder.call_operator( + op=operator.getitem, + args=(max_pool, 0), + ) + builder.output([getitem]) gm = builder.get_graph_module() max_pool_nodes = gm.graph.find_nodes( op="call_function", target=torch.ops.aten.max_pool2d_with_indices.default, ) - self.assertEqual( - len(max_pool_nodes), - 1, - "Should find exactly one max_pool2d_with_indices node", + self.assertEqual(len(max_pool_nodes), 1) + getitem_nodes = gm.graph.find_nodes( + op="call_function", + target=operator.getitem, ) - return gm, max_pool_nodes[0] + self.assertEqual(len(getitem_nodes), 1) + return gm, getitem_nodes[0], max_pool_nodes[0] def _build_add_relu_graph( self,