Skip to content

Commit f976e63

Browse files
authored
Fix constant_pad_nd->cat lowering dtype for quantized graphs
Differential Revision: D107545428 Pull Request resolved: pytorch#20039
1 parent a9c89f3 commit f976e63

2 files changed

Lines changed: 27 additions & 2 deletions

File tree

backends/cadence/aot/replace_ops.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -632,6 +632,7 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
632632
value = 0 if len(node.args) == 2 else node.args[2]
633633

634634
arg_shape = input_node.meta["val"].shape
635+
dtype = input_node.meta["val"].dtype
635636

636637
# Convert orig_padding to a list for manipulation
637638
# pyre-ignore[6]: Argument type
@@ -663,7 +664,7 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
663664
left_padding_shape,
664665
value,
665666
),
666-
kwargs={"dtype": torch.float32},
667+
kwargs={"dtype": dtype},
667668
)
668669
left_padding_node.meta = node.meta
669670
cat_tensors.append(left_padding_node)
@@ -683,7 +684,7 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
683684
right_padding_shape,
684685
value,
685686
),
686-
kwargs={"dtype": torch.float32},
687+
kwargs={"dtype": dtype},
687688
)
688689
right_padding_node.meta = node.meta
689690
cat_tensors.append(right_padding_node)

backends/cadence/aot/tests/test_replace_ops_passes.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -839,6 +839,30 @@ def test_replace_pad_with_cat(self, shape: Tuple[int], padding: Tuple[int]) -> N
839839
0,
840840
)
841841

842+
@torch.no_grad()
843+
def test_replace_pad_with_cat_preserves_dtype(self) -> None:
844+
# The padding constant tensors must match the input dtype, otherwise the
845+
# resulting cat mixes dtypes and fails edge dialect dtype verification
846+
# (e.g. for quantized int8 graphs).
847+
x = torch.randint(-128, 127, (1, 2, 3), dtype=torch.int8)
848+
original_gm = single_op_builder(
849+
placeholders=(x,),
850+
op=exir_ops.edge.aten.constant_pad_nd.default,
851+
args=(x, [1, 1]),
852+
)
853+
854+
p = ReplacePadWithCatPass()
855+
result = cast(PassResult, p(original_gm))
856+
self.assertTrue(result.modified)
857+
graph_after_passes = result.graph_module
858+
859+
full_nodes = graph_after_passes.graph.find_nodes(
860+
op="call_function", target=exir_ops.edge.aten.full.default
861+
)
862+
self.assertEqual(len(full_nodes), 2)
863+
for full_node in full_nodes:
864+
self.assertEqual(full_node.kwargs["dtype"], torch.int8)
865+
842866
@torch.no_grad()
843867
def test_replace_repeat_with_cat(self) -> None:
844868
x = torch.randn([3, 5])

0 commit comments

Comments
 (0)