Skip to content

Commit 08c7ef6

Browse files
abeakkasfacebook-github-bot
authored andcommitted
Fix quantized max_pool2d output observer for tuple-returning ops (#19259)
Summary: prepare_pt2e skips the output observer for max_pool2d_with_indices because its output is a tuple, not a single tensor. This caused the quantized_max_pool2d fusion to silently fail. Fix: annotate getitem[0] (single tensor) instead of the tuple-returning op. Also bail out if indices are consumed or the graph structure is unexpected. Reviewed By: mcremon-meta Differential Revision: D103436172
1 parent a7e44bf commit 08c7ef6

1 file changed

Lines changed: 12 additions & 2 deletions

File tree

backends/cadence/aot/quantizer/patterns.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
# pyre-strict
88

9+
import operator
910
from abc import ABC, abstractmethod
1011
from dataclasses import dataclass, field
1112
from typing import List, Tuple, Union
@@ -493,7 +494,16 @@ def get_anchors(
493494
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
494495
max_pool_node = fused_partition[0].nodes[-1]
495496

496-
# Input and output share quantization parameters since max is order-preserving
497+
# Since max_pool2d_with_indices returns a tuple, the output observer must be
498+
# placed on getitem[0] rather than the tuple-returning op. Otherwise
499+
# prepare_pt2e silently skips it.
500+
# Expect exactly one user: getitem[0] extracting the values tensor. If indices
501+
# are also used or the structure is unexpected, bail out.
502+
users = list(max_pool_node.users)
503+
if len(users) != 1 or users[0].target is not operator.getitem:
504+
return PartitionAnchors(empty=True), max_pool_node
505+
getitem_0 = users[0]
506+
497507
return (
498508
PartitionAnchors(
499509
inputs=[(max_pool_node, 0)],
@@ -505,7 +515,7 @@ def get_anchors(
505515
],
506516
output=[
507517
(
508-
max_pool_node,
518+
getitem_0,
509519
SharedQuantizationSpec((max_pool_node.args[0], max_pool_node)),
510520
)
511521
],

0 commit comments

Comments
 (0)