
Commit bd99b2a

Fix InsertIOQDQ KeyError for dequantize encodings (#18622)
1 parent caf086c commit bd99b2a

2 files changed: 93 additions & 4 deletions

backends/qualcomm/_passes/insert_io_qdq.py (11 additions & 4 deletions)
@@ -31,11 +31,16 @@ class InsertIOQDQ(ExportPass):
     """
 
     q_dq_map = {
-        # per tensor
+        # per tensor (quantize -> dequantize)
         exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
         exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor: exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
-        # per channel
+        # per tensor (dequantize -> dequantize, for pre-quantized params)
+        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
+        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor: exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
+        # per channel (quantize -> dequantize)
         exir_ops.edge.quantized_decomposed.quantize_per_channel.default: exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
+        # per channel (dequantize -> dequantize, for pre-quantized params)
+        exir_ops.edge.quantized_decomposed.dequantize_per_channel.default: exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
     }
 
     def __init__(self, edge_program: torch.export.ExportedProgram):
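
This first hunk is the core fix: q_dq_map previously keyed only quantize ops, so an output fed by a node whose recorded encoding is a dequantize op (as with pre-quantized LLM parameters) hit a KeyError. A minimal, self-contained sketch of the before/after lookup, using string stand-ins rather than the real exir_ops targets:

# String stand-ins for the edge ops; the real keys are exir_ops targets.
q_dq_map_before = {
    "quantize_per_tensor.default": "dequantize_per_tensor.tensor",
    "quantize_per_channel.default": "dequantize_per_channel.default",
}
q_dq_map_after = {
    **q_dq_map_before,
    # New: dequantize encodings also resolve to a dequantize op.
    "dequantize_per_tensor.default": "dequantize_per_tensor.tensor",
    "dequantize_per_channel.default": "dequantize_per_channel.default",
}

encoding = "dequantize_per_tensor.default"  # e.g. a pre-quantized parameter
assert encoding not in q_dq_map_before      # old map: lookup raised KeyError
print(q_dq_map_after[encoding])             # fixed map: dequantize_per_tensor.tensor
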
@@ -80,7 +85,7 @@ def _create_node(
             (node, *self._ceate_args(target, quant_attrs)),
         )
         meta_val = node.meta["val"]
-        if target in self.q_dq_map:
+        if target in q_ops:
             inserted_node.meta[QCOM_QUANT_ATTRS] = node.meta.pop(QCOM_QUANT_ATTRS)
             meta_val = meta_val.to(quant_attrs["dtype"])
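
With dequantize ops now present as keys, membership in q_dq_map no longer implies "this target is a quantize op", so _create_node consults q_ops, a set of quantize op targets that the pass imports (its definition is outside this diff), before moving QCOM_QUANT_ATTRS and casting the meta value. A rough sketch of the distinction, again with string stand-ins rather than the real targets:

# q_ops here is a stand-in set; the real pass uses a set of quantize
# edge-op targets.
q_ops = {"quantize_per_tensor.default", "quantize_per_tensor.tensor"}
q_dq_map = {
    "quantize_per_tensor.default": "dequantize_per_tensor.tensor",
    "dequantize_per_tensor.default": "dequantize_per_tensor.tensor",  # new key
}
for target in q_dq_map:
    # "target in q_dq_map" is now true for both rows; only the quantize
    # row should trigger the quant-attrs move and dtype cast.
    print(target, "is quantize op:", target in q_ops)
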

@@ -118,7 +123,9 @@ def _insert_dequant_node(
             user.replace_input_with(node, inserted_node)
 
     def _insert(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
-        for n in graph_module.graph.nodes:
+        # Snapshot nodes: inserting Q/DQ nodes mutates the graph's linked list,
+        # so iterating the live list can revisit newly inserted nodes.
+        for n in list(graph_module.graph.nodes):
             # do nothing when a node is expected to output a quant tensor
             if n.meta.get(QCOM_QUANTIZED_IO):
                 continue
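
The last hunk guards the iteration itself: torch.fx keeps nodes in a doubly linked list, so a node inserted after the cursor is yielded by a live iteration, while snapshotting with list(...) freezes the worklist up front. A standalone demonstration of the pattern (plain FX, not the Qualcomm pass):

# Standalone demonstration of why the snapshot matters.
import operator

import torch.fx


def f(x):
    return x + 1


gm = torch.fx.symbolic_trace(f)
visited = []
for n in list(gm.graph.nodes):  # snapshot: newly inserted nodes are skipped
    visited.append(n.name)
    if n.op == "placeholder":
        with gm.graph.inserting_after(n):
            gm.graph.call_function(operator.neg, (n,))
print(visited)  # only the original nodes, even though the graph grew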

backends/qualcomm/tests/test_passes.py (82 additions & 0 deletions)
@@ -2,11 +2,15 @@
 
 import torch
 from executorch.backends.qualcomm._passes import (
+    AnnotateQuantAttrs,
     ConvertBmmToMatmul,
     ConvertMhaToSha,
+    FoldQDQ,
+    InsertIOQDQ,
     InsertReshapeForReduceOps,
     RemoveRedundancy,
 )
+from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer, QuantDtype
 from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset
 from executorch.backends.qualcomm.tests.models import TopKandIndex
 from executorch.backends.qualcomm.utils.utils import (
@@ -17,9 +21,87 @@
 from executorch.exir import to_edge
 from executorch.exir.debug_handle_utils import DEBUG_HANDLE_KEY
 from executorch.exir.dialects._ops import ops as exir_ops
+from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
 
 
 class TestPasses(unittest.TestCase):
+    def _build_quantized_graph(self):
+        """Build a quantized graph through AnnotateQuantAttrs + FoldQDQ."""
+
+        class AddModule(torch.nn.Module):
+            def forward(self, x):
+                return x + 1
+
+        module = AddModule().eval()
+        sample_input = (torch.randn(1, 4),)
+
+        exported = torch.export.export(module, sample_input, strict=True).module()
+        quantizer = QnnQuantizer()
+        quantizer.set_default_quant_config(quant_dtype=QuantDtype.use_8a8w)
+        prepared = prepare_pt2e(exported, quantizer)
+        prepared(*sample_input)
+        qdq_module = convert_pt2e(prepared)
+
+        edge_program = to_edge(
+            torch.export.export(qdq_module, sample_input, strict=True)
+        )
+        ep = edge_program.exported_program()
+        gm = ep.graph_module
+
+        gm = AnnotateQuantAttrs(ep)(gm).graph_module
+        gm = FoldQDQ(ep)(gm).graph_module
+        return gm, ep
+
+    def test_insert_io_qdq_handles_dequant_encoding(self):
+        """InsertIOQDQ should not KeyError when a node with a dequantize
+        encoding feeds the output node (e.g. pre-quantized LLM parameters)."""
+        gm, ep = self._build_quantized_graph()
+
+        # Wire b__frozen_param0 (which has a dequantize encoding) to output,
+        # simulating the LLM topology from github issue #17732.
+        param_node = None
+        output_node = None
+        for n in gm.graph.nodes:
+            if n.name == "b__frozen_param0":
+                param_node = n
+            if n.op == "output":
+                output_node = n
+
+        self.assertIsNotNone(param_node)
+        old_args = output_node.args[0]
+        output_node.args = (
+            ((old_args,) if not isinstance(old_args, tuple) else old_args)
+            + (param_node,),
+        )
+        gm.graph.lint()
+        gm.recompile()
+
+        pass_instance = InsertIOQDQ(ep)
+        pass_instance._insert(gm)
+
+        dq_nodes = [
+            n
+            for n in gm.graph.nodes
+            if n.op == "call_function"
+            and hasattr(n.target, "__name__")
+            and "dequantize" in n.target.__name__
+            and any(u.op == "output" for u in n.users.keys())
+        ]
+        self.assertGreaterEqual(len(dq_nodes), 1)
+
+    def test_insert_io_qdq_no_revisit(self):
+        """InsertIOQDQ must not revisit newly inserted nodes."""
+        gm, ep = self._build_quantized_graph()
+
+        node_count_before = len(list(gm.graph.nodes))
+        pass_instance = InsertIOQDQ(ep)
+        pass_instance._insert(gm)
+        node_count_after = len(list(gm.graph.nodes))
+
+        # AddModule with one input and one output should insert exactly
+        # one quantize (input) and one dequantize (output) = +2 nodes.
+        self.assertEqual(node_count_after, node_count_before + 2)
+
     def test_insert_reshape_for_argmax(self):
         class ArgmaxModule(torch.nn.Module):
             def forward(self, x):
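
The trickiest part of the regression test is rewiring the FX output node, since output.args is a 1-tuple wrapping the returned value(s). Here is the same rewiring on a plain traced module; the module and node names are illustrative, not taken from the test graph:

# Illustrative standalone version of the test's output rewiring.
import torch
import torch.fx


class M(torch.nn.Module):
    def forward(self, x):
        return x + 1


gm = torch.fx.symbolic_trace(M())
output_node = next(n for n in gm.graph.nodes if n.op == "output")
extra = next(n for n in gm.graph.nodes if n.op == "placeholder")

# output.args is a 1-tuple wrapping the returned value(s); normalize to a
# tuple of values before appending, exactly as the test does.
old_args = output_node.args[0]
output_node.args = (
    ((old_args,) if not isinstance(old_args, tuple) else old_args) + (extra,),
)
gm.graph.lint()
gm.recompile()
print(gm(torch.ones(2)))  # (tensor([2., 2.]), tensor([1., 1.]))

After recompile(), the module returns the extra value alongside its original output, mirroring how the test makes b__frozen_param0 feed the graph output before running InsertIOQDQ.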
