Skip to content

Commit f35fbb5

Browse files
authored
Qualcomm AI Engine Direct - Adding QNN backend support for the remainder core ATen ops (pytorch#18843)
1 parent eaef2ed commit f35fbb5

6 files changed

Lines changed: 315 additions & 68 deletions

File tree

backends/qualcomm/_passes/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from .decompose_maxpool3d import DecomposeMaxPool3d
2828
from .decompose_minmaxdim import DecomposeMinMaxDim
2929
from .decompose_reciprocal import DecomposeReciprocal
30+
from .decompose_remainder import DecomposeRemainder
3031
from .decompose_roll import DecomposeRoll
3132
from .decompose_silu import DecomposeSilu
3233
from .decompose_threshold import DecomposeThreshold
@@ -80,6 +81,7 @@
8081
DecomposeMaxPool3d,
8182
DecomposeMinMaxDim,
8283
DecomposeReciprocal,
84+
DecomposeRemainder,
8385
DecomposeRoll,
8486
DecomposeSilu,
8587
DecomposeThreshold,
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import torch
8+
from executorch.exir.dialects._ops import ops as exir_ops
9+
from executorch.exir.dialects.edge._ops import EdgeOpOverload
10+
from executorch.exir.pass_base import ExportPass, PassResult
11+
from torchao.quantization.pt2e.utils import get_new_attr_name_with_prefix
12+
13+
from .utils import copy_meta, get_const_node
14+
15+
16+
class DecomposeRemainder(ExportPass):
    """
    Decompose remainder.Scalar and remainder.Tensor using the identity:
    remainder(x, y) = x - floor(x / y) * y
    """

    def __init__(self):
        super().__init__()
        # Both aten and edge-dialect overloads are rewritten.
        self.remainder_targets = {
            torch.ops.aten.remainder.Scalar,
            torch.ops.aten.remainder.Tensor,
            exir_ops.edge.aten.remainder.Scalar,
            exir_ops.edge.aten.remainder.Tensor,
        }

    def call(self, graph_module: torch.fx.GraphModule):
        """Rewrite every remainder node into div -> floor -> mul -> sub."""
        graph = graph_module.graph
        # One const node per distinct scalar divisor, so the same scalar
        # appearing in several remainder ops does not register duplicate
        # buffers on the module.
        scalar_to_const = {}

        for remainder_node in list(graph.nodes):
            if remainder_node.op != "call_function":
                continue
            if remainder_node.target not in self.remainder_targets:
                continue

            dividend = remainder_node.args[0]
            divisor = remainder_node.args[1]
            on_edge_dialect = isinstance(remainder_node.target, EdgeOpOverload)
            meta = remainder_node.meta

            # Pick the decomposition ops from the same dialect as the node.
            if on_edge_dialect:
                div_op = exir_ops.edge.aten.div.Tensor
                floor_op = exir_ops.edge.aten.floor.default
                mul_op = exir_ops.edge.aten.mul.Tensor
                sub_op = exir_ops.edge.aten.sub.Tensor
            else:
                div_op = torch.ops.aten.div.Tensor
                floor_op = torch.ops.aten.floor.default
                mul_op = torch.ops.aten.mul.Tensor
                sub_op = torch.ops.aten.sub.Tensor

            if not isinstance(divisor, torch.fx.Node) and on_edge_dialect:
                # Edge dialect: materialize the scalar divisor as a graph
                # constant (cached per value, see scalar_to_const above).
                if divisor not in scalar_to_const:
                    attr_name = get_new_attr_name_with_prefix(
                        "_remainder_const_"
                    )(graph_module)
                    scalar_to_const[divisor] = get_const_node(
                        graph, graph_module, attr_name, divisor, remainder_node
                    )
                divisor_node = scalar_to_const[divisor]
            else:
                # Tensor divisor, or aten dialect where the scalar can be
                # passed through directly.
                divisor_node = divisor

            with graph.inserting_before(remainder_node):
                quotient = graph.create_node(
                    "call_function", div_op, (dividend, divisor_node)
                )
                quotient.meta = copy_meta(meta)

                floored = graph.create_node(
                    "call_function", floor_op, (quotient,)
                )
                floored.meta = copy_meta(meta)

                product = graph.create_node(
                    "call_function", mul_op, (floored, divisor_node)
                )
                product.meta = copy_meta(meta)

                result = graph.create_node(
                    "call_function", sub_op, (dividend, product)
                )
                result.meta = copy_meta(meta)

            # Redirect all consumers to the decomposed result; the original
            # remainder node is then removed by dead-code elimination.
            for user in remainder_node.users.copy():
                user.replace_input_with(remainder_node, result)

        graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)

backends/qualcomm/_passes/qnn_pass_manager.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
DecomposeMaxPool3d,
3333
DecomposeMinMaxDim,
3434
DecomposeReciprocal,
35+
DecomposeRemainder,
3536
DecomposeRoll,
3637
DecomposeSilu,
3738
DecomposeThreshold,
@@ -106,6 +107,7 @@ def get_capture_program_passes():
106107
(DecomposeLogVariants, True),
107108
(DecomposeMaxPool3d, True),
108109
(DecomposeMinMaxDim, True),
110+
(DecomposeRemainder, True),
109111
(DecomposeTrunc, True),
110112
(ExpandBroadcastTensorShape, True),
111113
(FixedLinearKeepDim, True),
@@ -239,6 +241,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
239241
# Decompose Reciprocal into Div for these 2 backend
240242
# TODO: Skip this pass for CPU backend (Dependency: Backend-aware passes manager)
241243
self.add_pass(DecomposeReciprocal())
244+
self.add_pass(DecomposeRemainder())
242245
self.add_pass(DecomposeLinalgVectorNorm(quantization_capture=True))
243246
self.add_pass(DecomposeLogVariants())
244247
self.add_pass(ReplaceInfValues())

backends/qualcomm/_passes/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ def get_passes_dependency_for_capture_program():
7171
DecomposeLinalgVectorNorm,
7272
DecomposeLogVariants,
7373
DecomposeMaxPool3d,
74+
DecomposeRemainder,
7475
DecomposeTrunc,
7576
ExpandBroadcastTensorShape,
7677
FixedLinearKeepDim,
@@ -101,6 +102,7 @@ def get_passes_dependency_for_capture_program():
101102
DecomposeLinalgVectorNorm: [RemoveRedundancy],
102103
DecomposeLogVariants: [RemoveRedundancy],
103104
DecomposeMaxPool3d: [RemoveRedundancy],
105+
DecomposeRemainder: [RemoveRedundancy],
104106
DecomposeTrunc: [RemoveRedundancy],
105107
ExpandBroadcastTensorShape: [FoldQDQ],
106108
FixedLinearKeepDim: [FoldQDQ],

backends/qualcomm/tests/models.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1889,6 +1889,30 @@ def forward(self, x):
18891889
return x.repeat(1, 2, 3, 4)
18901890

18911891

1892+
class RemainderScalar(torch.nn.Module):
    """Test model exercising aten.remainder.Scalar (remainder by a constant)."""

    def forward(self, x):
        return torch.remainder(x, 3.0)
1898+
1899+
1900+
class RemainderTensor(torch.nn.Module):
    """Test model exercising aten.remainder.Tensor (element-wise tensor divisor)."""

    def forward(self, x, y):
        return torch.remainder(x, y)
1906+
1907+
1908+
class RemainderMultiNode(torch.nn.Module):
    """Test model with both remainder overloads (Scalar and Tensor) in one graph."""

    def forward(self, x, y):
        scalar_rem = torch.remainder(x, 3.0)
        tensor_rem = torch.remainder(x, y)
        return scalar_rem, tensor_rem
1914+
1915+
18921916
class ReWriteObs(torch.nn.Module):
18931917
def __init__(self):
18941918
super().__init__()

0 commit comments

Comments
 (0)