Skip to content

Commit 4d698cb

Browse files
authored
Qualcomm AI Engine Direct - Adding QNN backend support for fill.scalar core ATen op (pytorch#19826)
### Summary Added support for the `fill.scalar` op via a decomposition pass using the `full` op and the identity: ``` fill(input, value) = full(input.shape, value) ``` ### Test plan ``` python backends/qualcomm/tests/test_qnn_delegate.py -k TestQNNQuantizedOperator.test_qnn_backend_fill --model SM8750 --host aisw-vm15-labsd --device 545ee4aa --build_folder build-android python backends/qualcomm/tests/test_qnn_delegate.py -k TestQNNFloatingPointOperator.test_qnn_backend_fill --model SM8750 --host aisw-vm15-labsd --device 545ee4aa --build_folder build-android ```
1 parent ac3003e commit 4d698cb

8 files changed

Lines changed: 91 additions & 1 deletion

File tree

.claude/skills/qualcomm/new_op_development.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ class DecomposeMyOp(ExportPass):
210210
return PassResult(graph_module, True)
211211
```
212212

213-
**Critical rules:** (1) handle both dialects via `EdgeOpOverload` check, (2) `copy_meta` on every new node, (3) lift scalars to tensors in edge dialect with `get_const_node`, (4) cache constants with `const_cache`, (5) for bool-output nodes use `callback=lambda m: {**m, "val": m["val"].to(torch.bool)}` in `create_node`.
213+
**Critical rules:** (1) handle both dialects via `EdgeOpOverload` check, (2) `copy_meta` on every new node, (3) lift scalars to tensors in edge dialect with `get_const_node`, (4) cache constants with `const_cache`, (5) for bool-output nodes use `callback=lambda m: {**m, "val": m["val"].to(torch.bool)}` in `create_node`, (6) **never pass kwargs** (like `dtype`/`device`) to `graph.create_node` for ATen ops — the ATen IR requires kwargs to be empty (`prepare_pt2e` asserts this); instead rely on `copy_meta` which propagates dtype/device via the FakeTensor in `node.meta["val"]`.
214214

215215
### Approach C: Built-in Decomposition Table
216216
**Ref:** `_passes/decompose_triu.py`. Uses `make_fx` + `get_decompositions`. Only works if PyTorch has a registered decomp.

backends/qualcomm/_passes/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from .decompose_col_im import DecomposeColIm
2222
from .decompose_einsum import DecomposeEinsum
2323
from .decompose_expm1 import DecomposeExpM1
24+
from .decompose_fill import DecomposeFill
2425
from .decompose_floor_divide import DecomposeFloorDivide
2526
from .decompose_glu import DecomposeGlu
2627
from .decompose_linalg_vector_norm import DecomposeLinalgVectorNorm
@@ -80,6 +81,7 @@
8081
DecomposeColIm,
8182
DecomposeEinsum,
8283
DecomposeExpM1,
84+
DecomposeFill,
8385
DecomposeFloorDivide,
8486
DecomposeGlu,
8587
DecomposeLinalgVectorNorm,
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import torch
8+
from executorch.exir.dialects._ops import ops as exir_ops
9+
from executorch.exir.dialects.edge._ops import EdgeOpOverload
10+
from executorch.exir.pass_base import ExportPass, PassResult
11+
from executorch.exir.passes import dead_code_elimination_pass
12+
13+
from .utils import copy_meta
14+
15+
16+
class DecomposeFill(ExportPass):
17+
"""
18+
Decompose fill.Scalar into full.default.
19+
fill(input, value) is semantically equivalent to full(input.shape, value).
20+
"""
21+
22+
def __init__(self):
23+
super().__init__()
24+
self.targets = {
25+
torch.ops.aten.fill.Scalar,
26+
torch.ops.aten.fill_.Scalar,
27+
exir_ops.edge.aten.fill.Scalar,
28+
exir_ops.edge.aten.fill_.Scalar,
29+
}
30+
31+
def call(self, graph_module: torch.fx.GraphModule):
32+
graph = graph_module.graph
33+
for node in list(graph.nodes):
34+
if node.op == "call_function" and node.target in self.targets:
35+
fill_node = node
36+
is_edge = isinstance(node.target, EdgeOpOverload)
37+
input_node = node.args[0]
38+
scalar_value = node.args[1]
39+
40+
# Get the shape from the input tensor metadata
41+
shape = list(input_node.meta["val"].shape)
42+
43+
full_op = (
44+
exir_ops.edge.aten.full.default
45+
if is_edge
46+
else torch.ops.aten.full.default
47+
)
48+
49+
with graph.inserting_after(input_node):
50+
full_node = graph.create_node(
51+
"call_function",
52+
full_op,
53+
(shape, scalar_value),
54+
)
55+
full_node.meta = copy_meta(fill_node.meta)
56+
57+
for user in fill_node.users.copy():
58+
user.replace_input_with(fill_node, full_node)
59+
60+
dead_code_elimination_pass(graph_module)
61+
return PassResult(graph_module, True)

backends/qualcomm/_passes/qnn_pass_manager.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
DecomposeColIm,
2727
DecomposeEinsum,
2828
DecomposeExpM1,
29+
DecomposeFill,
2930
DecomposeFloorDivide,
3031
DecomposeGlu,
3132
DecomposeLinalgVectorNorm,
@@ -110,6 +111,7 @@ def get_capture_program_passes():
110111
(DecomposeAny, True),
111112
(DecomposeAtan2, True),
112113
(DecomposeColIm, True),
114+
(DecomposeFill, True),
113115
(DecomposeLogVariants, True),
114116
(DecomposeMaxPool3d, True),
115117
(DecomposeMinMaxDim, True),
@@ -248,6 +250,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
248250
self.add_pass(DecomposeWrapWithAutocast())
249251
self.add_pass(DecomposeEinsum())
250252
self.add_pass(DecomposeExpM1())
253+
self.add_pass(DecomposeFill())
251254
self.add_pass(DecomposeGlu())
252255
# HTP and GPU doesn't support ElementWiseUnary with operation=reciprocal
253256
# Decompose Reciprocal into Div for these 2 backend
@@ -275,6 +278,7 @@ def transform_for_export_pipeline(
275278
self.add_pass(DecomposeTriu())
276279
self.add_pass(DecomposeLinalgVectorNorm(quantization_capture=True))
277280
self.add_pass(DecomposeExpM1())
281+
self.add_pass(DecomposeFill())
278282
# DecomposeFloorDivide does not apply to the annotation pipeline,
279283
# since the CPU QDQ model would reduce accuracy.
280284
# We keep div and floor operations in floating-point to maintain precision.

backends/qualcomm/_passes/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ def get_passes_dependency_for_capture_program():
6969
DecomposeAny,
7070
DecomposeAtan2,
7171
DecomposeColIm,
72+
DecomposeFill,
7273
DecomposeLinalgVectorNorm,
7374
DecomposeLogVariants,
7475
DecomposeMaxPool3d,
@@ -104,6 +105,7 @@ def get_passes_dependency_for_capture_program():
104105
DecomposeAny: [RemoveRedundancy],
105106
DecomposeAtan2: [RemoveRedundancy],
106107
DecomposeColIm: [FoldQDQ],
108+
DecomposeFill: [RemoveRedundancy],
107109
DecomposeLinalgVectorNorm: [RemoveRedundancy],
108110
DecomposeLogVariants: [RemoveRedundancy],
109111
DecomposeMaxPool3d: [RemoveRedundancy],

backends/qualcomm/builders/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -506,6 +506,7 @@ The following PyTorch operators are supported through decomposition or annotatio
506506
| `aten.im2col`, `aten.col2im` | `DecomposeColIm` |
507507
| `aten.einsum` | `DecomposeEinsum` |
508508
| `aten.special_expm1` | `DecomposeExpM1` |
509+
| `aten.fill.Scalar` | `DecomposeFill` |
509510
| `aten.floor_divide` | `DecomposeFloorDivide` |
510511
| `aten.glu` | `DecomposeGlu` |
511512
| `aten.linalg_vector_norm` | `DecomposeLinalgVectorNorm` |

backends/qualcomm/tests/models.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1115,6 +1115,15 @@ def forward(self, x):
11151115
return torch.special.expm1(x)
11161116

11171117

1118+
class Fill(torch.nn.Module):
1119+
def __init__(self, value):
1120+
super().__init__()
1121+
self.value = value
1122+
1123+
def forward(self, x):
1124+
return torch.add(x, torch.fill(x, self.value))
1125+
1126+
11181127
class Flip(torch.nn.Module):
11191128
def __init__(self):
11201129
super().__init__()

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -965,6 +965,11 @@ def test_qnn_backend_fp16a8w_fp16_simple_model(self):
965965
)
966966
self.lower_module_and_test_output(module, sample_input)
967967

968+
def test_qnn_backend_fill(self):
969+
module = Fill(3.14) # noqa: F405
970+
sample_input = (torch.randn(1, 2, 3, 4),)
971+
self.lower_module_and_test_output(module, sample_input)
972+
968973
def test_qnn_backend_flip(self):
969974
sample_input = (torch.randn(3, 4, 5, 6),)
970975
module = Flip() # noqa: F405
@@ -3586,6 +3591,12 @@ def test_qnn_backend_expm1(self):
35863591
module = self.get_qdq_module(module, sample_input)
35873592
self.lower_module_and_test_output(module, sample_input)
35883593

3594+
def test_qnn_backend_fill(self):
3595+
module = Fill(3.14) # noqa: F405
3596+
sample_input = (torch.randn(1, 2, 3, 4),)
3597+
module = self.get_qdq_module(module, sample_input)
3598+
self.lower_module_and_test_output(module, sample_input)
3599+
35893600
def test_qnn_backend_flip(self):
35903601
sample_input = (torch.randn(3, 4, 5, 6),)
35913602
module = Flip() # noqa: F405

0 commit comments

Comments
 (0)