Skip to content

Commit 77e63ee

Browse files
authored
Qualcomm AI Engine Direct - Addition of new APIs for QNN custom op package and quantization annotation (#19094)
1 parent d0b7934 commit 77e63ee

12 files changed

Lines changed: 1891 additions & 188 deletions

File tree

backends/qualcomm/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ backends/qualcomm
4242
| ├── wrappers # Wrapper of QNN data structures for ease of use.
4343
| └── python # Python interface for using QNN libraries.
4444
├── builders # Code for lowering each operator (AoT Part).
45+
├── custom_op # APIs for using custom ops with QNN backend
4546
├── partition # QNN Partitioner (AoT Part).
4647
├── _passes # Various private passes helping lower models to QNN backend (AoT Part).
4748
├── python # Places to put pybind artifacts for accessing QNN APIs, structures, etc (AoT Part).
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import logging
8+
from dataclasses import dataclass
9+
from typing import Callable, Dict, Optional, Union
10+
11+
import torch
12+
from executorch.backends.qualcomm.quantizer.rules import _is_float_tensor
13+
from torchao.quantization.pt2e.quantizer import (
14+
QuantizationAnnotation,
15+
QuantizationSpec,
16+
SharedQuantizationSpec,
17+
)
18+
from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY
19+
20+
logger = logging.getLogger(__name__)
21+
22+
23+
@dataclass
class IOQuantConfig:
    """
    Per-op quantization configuration for a custom op's inputs and outputs.

    Attributes:
        input_quant_specs: Maps input index to its QuantizationSpec.
            Only indices present in the dict are annotated. If None, no inputs
            are annotated.
        output_quant_specs: Maps output index to its QuantizationSpec.
            For single-output ops annotation is done on the op node. For
            multi-output ops, each index corresponds to a downstream getitem
            user. If None, no outputs are annotated.
    """

    # index -> spec; indices absent from the dict are left un-annotated
    input_quant_specs: Optional[
        Dict[int, Union[QuantizationSpec, SharedQuantizationSpec]]
    ] = None
    # index -> spec; for multi-output ops the index selects the getitem user
    output_quant_specs: Optional[
        Dict[int, Union[QuantizationSpec, SharedQuantizationSpec]]
    ] = None
44+
45+
46+
class CustomOpsQuantAnnotator:
47+
"""
48+
Holds op IOQuantConfigs and builds a single annotation function
49+
compatible with make_quantizer(custom_annotations=...).
50+
"""
51+
52+
def __init__(self):
53+
self._registry: Dict = {} # {op_target: IOQuantConfig}
54+
55+
def register_annotation(
56+
self,
57+
op_target,
58+
io_quant_config: IOQuantConfig,
59+
) -> "CustomOpsQuantAnnotator":
60+
"""
61+
Register quantization config for custom op.
62+
63+
Args:
64+
op_target: The torch op target (e.g. torch.ops.my_ops.custom_op.default).
65+
io_quant_config: IOQuantConfig specifying how to quantize inputs and outputs.
66+
67+
Returns self for method chaining.
68+
"""
69+
self._registry[op_target] = io_quant_config
70+
return self
71+
72+
def build_annotation_fn(self) -> Callable[[torch.fx.GraphModule], None]:
73+
"""
74+
Build and return an annotation function for all registered ops.
75+
76+
The returned function has signature (gm: GraphModule) -> None and
77+
can be passed directly to make_quantizer(custom_annotations=(fn,)).
78+
"""
79+
registry = dict(self._registry)
80+
81+
def annotate_custom_ops(gm: torch.fx.GraphModule) -> None:
82+
for node in gm.graph.nodes:
83+
if node.target not in registry:
84+
continue
85+
86+
cfg = registry[node.target]
87+
input_qspec_map = {}
88+
if cfg.input_quant_specs is not None:
89+
for arg_idx, spec in cfg.input_quant_specs.items():
90+
if arg_idx >= len(node.args):
91+
raise ValueError(
92+
f"IOQuantConfig error for '{node.name}' ({node.target}): "
93+
f"input_quant_specs index {arg_idx} is out of range "
94+
f"(op has {len(node.args)} args)"
95+
)
96+
if not _is_float_tensor(node.args[arg_idx]):
97+
logger.debug(
98+
f"Skipping quantization of input {arg_idx} for "
99+
f"'{node.name}' ({node.target}): expected a float tensor."
100+
)
101+
continue
102+
logger.debug(
103+
f"Annotating input {arg_idx} of '{node.name}' ({node.target}) "
104+
f"with {spec}"
105+
)
106+
input_qspec_map[node.args[arg_idx]] = spec
107+
108+
if not cfg.output_quant_specs or len(cfg.output_quant_specs) <= 1:
109+
# Single output — annotate on the op node
110+
output_spec = (
111+
cfg.output_quant_specs.get(0)
112+
if cfg.output_quant_specs
113+
else None
114+
)
115+
node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
116+
input_qspec_map=input_qspec_map,
117+
output_qspec=output_spec,
118+
_annotated=True,
119+
)
120+
else:
121+
# Tuple output — push quantization down to getitem users
122+
node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
123+
input_qspec_map=input_qspec_map,
124+
output_qspec=None,
125+
_annotated=True,
126+
)
127+
for user in node.users:
128+
output_idx = user.args[1]
129+
spec = cfg.output_quant_specs.get(output_idx)
130+
131+
if spec is not None:
132+
user.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
133+
output_qspec=spec,
134+
_annotated=True,
135+
)
136+
137+
return annotate_custom_ops
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from typing import List, Optional
8+
9+
try:
10+
from qti.aisw.op_package_generator.generator import QnnOpPackageGenerator
11+
except ImportError as e:
12+
raise ImportError(
13+
"Failed to import QnnOpPackageGenerator. "
14+
"Please run 'source $QNN_SDK_ROOT/bin/envsetup.sh' to set up the QNN SDK environment."
15+
) from e
16+
17+
from executorch.backends.qualcomm.serialization.qc_schema import (
18+
QnnExecuTorchOpPackageInfo,
19+
QnnExecuTorchOpPackageOptions,
20+
QnnExecuTorchOpPackagePlatform,
21+
QnnExecuTorchOpPackageTarget,
22+
)
23+
24+
25+
class QnnCustomOpPackageBuilder:
    """
    Parses a QNN XML op package config and manages registration of
    target/platform/implementation for use with ExecuTorch.

    Validates that all keys in torch_op_name_map are present in the parsed
    package before any implementations are registered.
    """

    def __init__(
        self,
        xml_path: str,
        torch_op_name_map,
        interface_provider: Optional[str] = None,
    ):
        """
        Args:
            xml_path: Path to the QNN XML OpDef config file.
            torch_op_name_map: Maps QNN op type names to their corresponding
                PyTorch op targets.
                e.g. {"ExampleCustomOp": torch.ops.my_ops.custom_op.default}
            interface_provider: Interface provider symbol name. Defaults to
                "{PackageName}InterfaceProvider" if not specified.

        Raises:
            ValueError: If the XML config yields no op package, or if any
                key in torch_op_name_map is not found in the parsed op
                package.
        """
        op_package_generator = QnnOpPackageGenerator()
        op_package_generator.parse_config([xml_path])

        # ROBUSTNESS: an empty parse previously surfaced as a bare
        # IndexError on [0]; raise a descriptive error instead.
        if not op_package_generator.package_infos:
            raise ValueError(
                f"No op package was parsed from '{xml_path}'. "
                "Please check the XML OpDef config file."
            )

        # NOTE(review): only the first parsed package is used; additional
        # packages in the same XML are ignored.
        pkg_info = op_package_generator.package_infos[0]
        self.op_package_name = pkg_info.name
        self.interface_provider = (
            interface_provider
            if interface_provider
            else pkg_info.name + "InterfaceProvider"
        )
        self.torch_op_name_map = torch_op_name_map
        self._collection: List[QnnExecuTorchOpPackageInfo] = []
        self.operator_names = {op.type_name for op in pkg_info.operators}

        # Every requested QNN op must exist in the parsed package.
        missing_ops = set(self.torch_op_name_map) - self.operator_names
        if missing_ops:
            raise ValueError(f"Ops missing from OpPackage: {missing_ops}")

    def register_implementation(
        self,
        target: QnnExecuTorchOpPackageTarget,
        platform: QnnExecuTorchOpPackagePlatform,
        op_package_path: str,
    ) -> "QnnCustomOpPackageBuilder":
        """
        Register one (target, platform, path) combination.
        Creates one QnnExecuTorchOpPackageInfo per op in torch_op_name_map.
        Returns self for method chaining.

        Args:
            target: QnnExecuTorchOpPackageTarget
            platform: QnnExecuTorchOpPackagePlatform
            op_package_path: Path to the implementation for the target/platform.
        """
        for qnn_op_type_name, torch_name in self.torch_op_name_map.items():
            self._collection.append(
                QnnExecuTorchOpPackageInfo(
                    op_package_name=self.op_package_name,
                    op_package_path=op_package_path,
                    interface_provider=self.interface_provider,
                    target=target,
                    custom_op_name=str(torch_name),
                    qnn_op_type_name=qnn_op_type_name,
                    platform=platform,
                )
            )
        return self

    def get_op_package_options(self) -> QnnExecuTorchOpPackageOptions:
        """
        Build and return QnnExecuTorchOpPackageOptions from all registered implementations.
        Call after all register_implementation() calls are complete.
        """
        options = QnnExecuTorchOpPackageOptions()
        # Copy so later register_implementation() calls do not mutate
        # options already handed out.
        options.op_package_infos = list(self._collection)
        return options

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9204,7 +9204,7 @@ def test_cli_with_input_list_assignment(self):
92049204
golden_output = ep.module()(sample_input, sample_input2)
92059205
self._assert_outputs_equal(golden_output, device_output)
92069206

9207-
def test_custom_op(self):
9207+
def test_custom_op_1(self):
92089208
if not self.required_envs([self.op_package_dir]):
92099209
self.skipTest("missing required envs")
92109210
cmds = [
@@ -9240,6 +9240,42 @@ def test_custom_op(self):
92409240
msg = json.loads(conn.recv())
92419241
self.assertTrue(msg["is_close"])
92429242

9243+
def test_custom_op_2(self):
    """Run the multi-output custom-op example end to end on device and
    verify the reported closeness result."""
    # The example requires a QNN op package directory to build against.
    if not self.required_envs([self.op_package_dir]):
        self.skipTest("missing required envs")

    script = f"{self.executorch_root}/examples/qualcomm/custom_op/custom_ops_2.py"
    cmds = ["python", script]
    cmds += ["--artifact", self.artifact_dir]
    cmds += ["--build_folder", self.build_folder]
    cmds += ["--device", self.device]
    cmds += ["--model", self.model]
    cmds += ["--target", self.target]
    cmds += ["--ip", self.ip]
    cmds += ["--port", str(self.port)]
    cmds += ["--op_package_dir", self.op_package_dir]
    cmds += ["--build_op_package"]
    if self.host:
        cmds += ["--host", self.host]
    if self.enable_x86_64:
        cmds += ["--enable_x86_64"]

    # The example script connects back over (ip, port) and reports its
    # result as JSON; block until the run finishes, then check it.
    p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
    with Listener((self.ip, self.port)) as listener:
        conn = listener.accept()
        p.communicate()
        msg = json.loads(conn.recv())
        self.assertTrue(msg["is_close"])
9278+
92439279
def test_debugger_generate_optrace(self):
92449280
cmds = [
92459281
"python",

docs/source/backends-qualcomm.md

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -108,14 +108,18 @@ i.e., the directory containing `QNN_README.txt`.
108108

109109
### Setup environment variables
110110

111-
We set `LD_LIBRARY_PATH` to make sure the dynamic linker can find QNN libraries.
111+
Source the QNN SDK environment setup script to configure paths and environment variables:
112112

113-
Further, we set `PYTHONPATH` because it's easier to develop and import ExecuTorch
114-
Python APIs.
113+
```bash
114+
source $QNN_SDK_ROOT/bin/envsetup.sh
115+
```
116+
117+
This sets up `LD_LIBRARY_PATH` and other required variables for the QNN SDK tools and libraries.
118+
119+
Additionally, set `PYTHONPATH` for ExecuTorch Python APIs:
115120

116121
```bash
117-
export LD_LIBRARY_PATH=$QNN_SDK_ROOT/lib/x86_64-linux-clang/:$LD_LIBRARY_PATH
118-
export PYTHONPATH=$EXECUTORCH_ROOT/..
122+
export PYTHONPATH=$EXECUTORCH_ROOT/..:$PYTHONPATH
119123
```
120124

121125
## Build
@@ -615,14 +619,13 @@ This matrix directly corresponds to the implementations in: [executorch/backends
615619

616620
### Custom Ops Support
617621

618-
You can extend QNN backend support for your own operators.
619-
Follow the [tutorial](https://github.com/pytorch/executorch/tree/f32cdc3de6f7176d70a80228f1a60bcd45d93437/examples/qualcomm/custom_op#custom-operator-support):
622+
The QNN backend supports custom PyTorch operators with the op package mechanism.
623+
See the [custom op tutorial](https://github.com/pytorch/executorch/tree/main/examples/qualcomm/custom_op) for the full end-to-end flow. It covers:
620624

621-
It covers:
622-
- Writing new NodeVisitor for your op
623-
- Registering via @register_node_visitor
624-
- Creating and linking libQnnOp*.so for the delegate
625-
- Testing and verifying custom kernels on HTP
625+
- Defining a custom PyTorch op (single-output and multi-output)
626+
- Writing and building a QNN op package (XML and Op Implementation)
627+
- Registering the op package with ExecuTorch via `QnnCustomOpPackageBuilder`
628+
- Annotating custom ops for quantization via `CustomOpsQuantAnnotator` / `IOQuantConfig`
626629

627630
## FAQ
628631

0 commit comments

Comments
 (0)