Skip to content

Commit 1317bf1

Browse files
abeakkasfacebook-github-bot
authored andcommitted
Fix quantized max_pool2d output observer for tuple-returning ops (#19259)
Summary: prepare_pt2e skips the output observer for max_pool2d_with_indices because its output is a tuple, not a single tensor. This caused the quantized_max_pool2d fusion to silently fail. Fix: annotate getitem[0] (single tensor) instead of the tuple-returning op. Also bail out if indices are consumed or the graph structure is unexpected. Reviewed By: mcremon-meta Differential Revision: D103436172
1 parent 8464b47 commit 1317bf1

2 files changed

Lines changed: 38 additions & 12 deletions

File tree

backends/cadence/aot/quantizer/patterns.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
# pyre-strict
88

9+
import operator
910
from abc import ABC, abstractmethod
1011
from dataclasses import dataclass, field
1112
from typing import List, Tuple, Union
@@ -493,7 +494,20 @@ def get_anchors(
493494
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
494495
max_pool_node = fused_partition[0].nodes[-1]
495496

496-
# Input and output share quantization parameters since max is order-preserving
497+
# Since max_pool2d_with_indices returns a tuple, the output observer must be
498+
# placed on getitem[0] rather than the tuple-returning op. Otherwise
499+
# prepare_pt2e silently skips it.
500+
# Expect exactly one user: getitem[0] extracting the values tensor. If indices
501+
# are also used or the structure is unexpected, bail out.
502+
users = list(max_pool_node.users)
503+
if (
504+
len(users) != 1
505+
or users[0].target is not operator.getitem
506+
or users[0].args[1] != 0
507+
):
508+
return PartitionAnchors(empty=True), max_pool_node
509+
getitem_0 = users[0]
510+
497511
return (
498512
PartitionAnchors(
499513
inputs=[(max_pool_node, 0)],
@@ -505,7 +519,7 @@ def get_anchors(
505519
],
506520
output=[
507521
(
508-
max_pool_node,
522+
getitem_0,
509523
SharedQuantizationSpec((max_pool_node.args[0], max_pool_node)),
510524
)
511525
],

backends/cadence/aot/tests/test_quantizer_ops.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
# pyre-strict
88

99
import inspect
10+
import operator
1011
import unittest
1112
from typing import Callable
1213

@@ -483,12 +484,18 @@ def _build_addmm_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
483484
self.assertEqual(len(addmm_nodes), 1, "Should find exactly one addmm node")
484485
return gm, addmm_nodes[0]
485486

486-
def _build_max_pool2d_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
487-
"""Build a simple graph with a max_pool2d_with_indices operation."""
487+
def _build_max_pool2d_graph(
488+
self,
489+
) -> tuple[torch.fx.GraphModule, torch.fx.Node, torch.fx.Node]:
490+
"""Build a graph with max_pool2d_with_indices followed by getitem[0].
491+
492+
Returns:
493+
A tuple of (graph_module, getitem_node, max_pool_node).
494+
The getitem_node is where the output annotation is placed.
495+
The max_pool_node is where the input annotation is placed.
496+
"""
488497
builder = GraphBuilder()
489-
# Input shape: (batch, channels, height, width)
490498
x = builder.placeholder("x", torch.randn(1, 3, 8, 8))
491-
# max_pool2d_with_indices args: (input, kernel_size, stride, padding, dilation, ceil_mode)
492499
max_pool = builder.call_operator(
493500
op=torch.ops.aten.max_pool2d_with_indices.default,
494501
args=(x, [2, 2], [2, 2], [0, 0], [1, 1], False),
@@ -503,19 +510,24 @@ def _build_max_pool2d_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
503510
}
504511
),
505512
)
506-
builder.output([max_pool])
513+
getitem = builder.call_operator(
514+
op=operator.getitem,
515+
args=(max_pool, 0),
516+
)
517+
builder.output([getitem])
507518
gm = builder.get_graph_module()
508519

509520
max_pool_nodes = gm.graph.find_nodes(
510521
op="call_function",
511522
target=torch.ops.aten.max_pool2d_with_indices.default,
512523
)
513-
self.assertEqual(
514-
len(max_pool_nodes),
515-
1,
516-
"Should find exactly one max_pool2d_with_indices node",
524+
self.assertEqual(len(max_pool_nodes), 1)
525+
getitem_nodes = gm.graph.find_nodes(
526+
op="call_function",
527+
target=operator.getitem,
517528
)
518-
return gm, max_pool_nodes[0]
529+
self.assertEqual(len(getitem_nodes), 1)
530+
return gm, getitem_nodes[0], max_pool_nodes[0]
519531

520532
def _build_add_relu_graph(
521533
self,

0 commit comments

Comments
 (0)