Skip to content

Commit d69567f

Browse files
abeakkas authored and facebook-github-bot committed
Fix quantized max_pool2d output observer for tuple-returning ops
Summary: prepare_pt2e skips the output observer for max_pool2d_with_indices because its output is a tuple, not a single tensor. This caused the quantized_max_pool2d fusion to silently fail. Fix: annotate getitem[0] (single tensor) instead of the tuple-returning op. Also bail out if indices are consumed or the graph structure is unexpected. Differential Revision: D103436172
1 parent a7e44bf commit d69567f

1 file changed

Lines changed: 26 additions & 4 deletions

File tree

backends/cadence/aot/quantizer/patterns.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@
66

77
# pyre-strict
88

9+
import operator
910
from abc import ABC, abstractmethod
1011
from dataclasses import dataclass, field
11-
from typing import List, Tuple, Union
12+
from typing import List, Optional, Tuple, Union
1213

1314
import torch
1415
from executorch.backends.cadence.aot.quantizer.utils import get_bias_qparams
@@ -489,11 +490,32 @@ def partition_types(self) -> List[OpOverload]:
489490

490491
def get_anchors(
491492
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
492-
) -> Tuple[PartitionAnchors, fx.Node]:
493+
) -> Optional[Tuple[PartitionAnchors, fx.Node]]:
493494
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
494495
max_pool_node = fused_partition[0].nodes[-1]
495496

496-
# Input and output share quantization parameters since max is order-preserving
497+
# Find getitem[0] (values) node. Since max_pool2d_with_indices returns a tuple,
498+
# the output observer must be placed on getitem[0] (a single tensor) rather than
499+
# the tuple-returning op itself - otherwise prepare_pt2e silently skips the
500+
# output observer.
501+
getitem_0 = None
502+
for user in max_pool_node.users:
503+
if user.target is not operator.getitem:
504+
# Unexpected consumer of tuple output - skip quantization
505+
return None
506+
if user.args[1] == 0:
507+
if getitem_0 is not None:
508+
# Multiple getitem[0] nodes - unexpected graph structure
509+
return None
510+
getitem_0 = user
511+
elif user.args[1] == 1 and len(user.users) > 0:
512+
# Indices are consumed downstream - can't quantize this op
513+
return None
514+
515+
if getitem_0 is None:
516+
# No getitem[0] found - values output is unused
517+
return None
518+
497519
return (
498520
PartitionAnchors(
499521
inputs=[(max_pool_node, 0)],
@@ -505,7 +527,7 @@ def get_anchors(
505527
],
506528
output=[
507529
(
508-
max_pool_node,
530+
getitem_0,
509531
SharedQuantizationSpec((max_pool_node.args[0], max_pool_node)),
510532
)
511533
],

0 commit comments

Comments (0)