Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions backends/cadence/aot/quantizer/patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

# pyre-strict

import operator
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import List, Tuple, Union
Expand Down Expand Up @@ -493,7 +494,20 @@ def get_anchors(
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
max_pool_node = fused_partition[0].nodes[-1]

# Input and output share quantization parameters since max is order-preserving
# Since max_pool2d_with_indices returns a tuple, the output observer must be
# placed on getitem[0] rather than the tuple-returning op. Otherwise
# prepare_pt2e silently skips it.
# Expect exactly one user: getitem[0] extracting the values tensor. If indices
# are also used or the structure is unexpected, bail out.
users = list(max_pool_node.users)
if (
len(users) != 1
or users[0].target is not operator.getitem
or users[0].args[1] != 0
):
return PartitionAnchors(empty=True), max_pool_node
getitem_0 = users[0]

return (
PartitionAnchors(
inputs=[(max_pool_node, 0)],
Expand All @@ -505,7 +519,7 @@ def get_anchors(
],
output=[
(
max_pool_node,
getitem_0,
SharedQuantizationSpec((max_pool_node.args[0], max_pool_node)),
)
],
Expand Down
32 changes: 22 additions & 10 deletions backends/cadence/aot/tests/test_quantizer_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
# pyre-strict

import inspect
import operator
import unittest
from typing import Callable

Expand Down Expand Up @@ -483,12 +484,18 @@ def _build_addmm_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
self.assertEqual(len(addmm_nodes), 1, "Should find exactly one addmm node")
return gm, addmm_nodes[0]

def _build_max_pool2d_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
"""Build a simple graph with a max_pool2d_with_indices operation."""
def _build_max_pool2d_graph(
self,
) -> tuple[torch.fx.GraphModule, torch.fx.Node, torch.fx.Node]:
"""Build a graph with max_pool2d_with_indices followed by getitem[0].

Returns:
A tuple of (graph_module, getitem_node, max_pool_node).
The getitem_node is where the output annotation is placed.
The max_pool_node is where the input annotation is placed.
"""
builder = GraphBuilder()
# Input shape: (batch, channels, height, width)
x = builder.placeholder("x", torch.randn(1, 3, 8, 8))
# max_pool2d_with_indices args: (input, kernel_size, stride, padding, dilation, ceil_mode)
max_pool = builder.call_operator(
op=torch.ops.aten.max_pool2d_with_indices.default,
args=(x, [2, 2], [2, 2], [0, 0], [1, 1], False),
Expand All @@ -503,19 +510,24 @@ def _build_max_pool2d_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
}
),
)
builder.output([max_pool])
getitem = builder.call_operator(
op=operator.getitem,
args=(max_pool, 0),
)
builder.output([getitem])
gm = builder.get_graph_module()

max_pool_nodes = gm.graph.find_nodes(
op="call_function",
target=torch.ops.aten.max_pool2d_with_indices.default,
)
self.assertEqual(
len(max_pool_nodes),
1,
"Should find exactly one max_pool2d_with_indices node",
self.assertEqual(len(max_pool_nodes), 1)
getitem_nodes = gm.graph.find_nodes(
op="call_function",
target=operator.getitem,
)
return gm, max_pool_nodes[0]
self.assertEqual(len(getitem_nodes), 1)
return gm, getitem_nodes[0], max_pool_nodes[0]

def _build_add_relu_graph(
self,
Expand Down
Loading