
Commit 3be4546

Arm backend: Update while_loop in control-flow quant folding (#19109)
Update the quantization annotation of while_loop so that additional inputs, such as a threshold, may carry quantization parameters that differ from those of the loop-carried inputs.
- Add node.meta["additional_inputs"] so a threshold input can be annotated with its own quantization parameters.
- Update quant folding to merge input qparams when they differ.

Signed-off-by: Saoirse Stewart <saoirse.stewart@arm.com>
1 parent 94d2881 commit 3be4546
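For context, a minimal sketch of the kind of module this affects, assuming torch's experimental while_loop higher-order op (the import path may vary across versions; module and tensor names are illustrative):

import torch
from torch._higher_order_ops.while_loop import while_loop


class CountUpToThreshold(torch.nn.Module):
    # `threshold` is captured by the closures rather than loop-carried, so
    # export lifts it to an additional input of the while_loop op; its
    # quantization scale can differ greatly from the counter's.
    def forward(self, x: torch.Tensor, threshold: torch.Tensor) -> torch.Tensor:
        def cond_fn(x):
            return x.sum() < threshold.sum()

        def body_fn(x):
            return (x + 1.0,)

        (out,) = while_loop(cond_fn, body_fn, (x,))
        return out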

4 files changed: 90 additions & 59 deletions


backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py

Lines changed: 63 additions & 35 deletions
@@ -40,6 +40,19 @@ def _get_special_dtype(qspec: QuantArgs) -> TosaSpecialDtype | None:
     return None


+def _merge_qparams(qspec_1: QuantArgs, qspec_2: QuantArgs) -> QuantArgs:
+    """Merge two QuantArgs when inputs are quantized differently.
+
+    Requires same dtype; picks the first's parameters by default.
+
+    """
+    if qspec_1.dtype != qspec_2.dtype:
+        raise RuntimeError(
+            f"Cannot merge qparams of different dtypes: {qspec_1.dtype} vs {qspec_2.dtype}"
+        )
+    return qspec_1
+
+
 def get_input_qparams(node: Node) -> dict[int, QuantArgs]:
     """Get the input quantization parameters from a node, set by the
     'FoldAndAnnotateQParamsPass'.
@@ -121,57 +134,72 @@ def __init__(
         super().__init__(*args, **kwargs)
         self.exported_program = exported_program

-    def fold_and_annotate_arg(
-        self, graph_module: GraphModule, node: Node, arg_list: list[Node], i: int
-    ) -> None:
-        input_qparams = None
-        nodes_to_remove = set()
+    def _extract_input_params(
+        self, arg_list: list[Node]
+    ) -> tuple[Optional[QuantArgs], set[Node]]:
+        input_qparams: Optional[QuantArgs] = None
+        nodes_to_remove: set[Node] = set()
         for arg in arg_list:
             if not isinstance(arg, Node):
-                return
-
-            arg_quant_params = None
+                return None, set()
+            arg_quant: Optional[QuantArgs] = None
             if arg.target in DQ_OPS:
                 args = arg.args
                 scales = args[1]
                 if (
-                    isinstance(args[1], Node)
+                    isinstance(scales, Node)
                     and self.exported_program is not None
-                    and is_param_node(self.exported_program, args[1])
+                    and is_param_node(self.exported_program, scales)
                 ):
-                    scales = get_param_tensor(self.exported_program, args[1])
+                    scales = get_param_tensor(self.exported_program, scales)
                 zps = args[2]
                 if (
-                    isinstance(args[2], Node)
+                    isinstance(zps, Node)
                     and self.exported_program is not None
-                    and is_param_node(self.exported_program, args[2])
+                    and is_param_node(self.exported_program, zps)
                 ):
-                    zps = get_param_tensor(self.exported_program, args[2])
-                arg_quant_params = QuantArgs.from_operator(
+                    zps = get_param_tensor(self.exported_program, zps)
+                arg_quant = QuantArgs.from_operator(
                     arg.target, (args[0], scales, zps, *args[3:])
                 )
-                # add arg to nodes_to_remove to fold the dq-node
                 nodes_to_remove.add(arg)
-            if input_qparams is not None and input_qparams != arg_quant_params:
-                # Two args are quantized differently
-                raise RuntimeError("Input qparams do not match")
-            input_qparams = arg_quant_params
-        if input_qparams is not None:
-            node.meta["input_qparams"][i] = input_qparams
-            for n in nodes_to_remove:
-                if n.target not in DQ_OPS:
-                    raise RuntimeError(
-                        f"Expected one of {DQ_OPS} dq_op, got {n.target}"
-                    )
+            if arg_quant is not None:
+                if input_qparams is None:
+                    input_qparams = arg_quant
+                elif input_qparams != arg_quant:
+                    input_qparams = _merge_qparams(input_qparams, arg_quant)
+        return input_qparams, nodes_to_remove
+
+    def _annotate_input_params(
+        self,
+        graph_module: GraphModule,
+        node: Node,
+        index: int,
+        input_qparams: QuantArgs,
+        nodes_to_remove: set[Node],
+    ) -> None:
+        node.meta["input_qparams"][index] = input_qparams
+
+        for dq in nodes_to_remove:
+            if dq.target not in DQ_OPS:
+                raise RuntimeError(f"Expected one of {DQ_OPS} dq_op, got {dq.target}")
+            node.replace_input_with(dq, cast(Node, dq.args[0]))
+            if not dq.users:
+                graph_module.graph.erase_node(dq)
+
+        special = _get_special_dtype(input_qparams)
+        if special:
+            node.all_input_nodes[index].meta[TosaSpecialDtype.meta_key()] = special

-                node.replace_input_with(n, cast(Node, n.args[0]))
-                if len(n.users) == 0:
-                    graph_module.graph.erase_node(n)
-            special_dtype = _get_special_dtype(input_qparams)
-            if special_dtype:
-                node.all_input_nodes[i].meta[
-                    TosaSpecialDtype.meta_key()
-                ] = special_dtype
+    def fold_and_annotate_arg(
+        self, graph_module: GraphModule, node: Node, arg_list: list[Node], i: int
+    ) -> None:
+        input_qparams, nodes_to_remove = self._extract_input_params(arg_list)
+        if input_qparams is None:
+            return
+        self._annotate_input_params(
+            graph_module, node, i, input_qparams, nodes_to_remove
+        )

     def _handle_control_flow_node(self, node: Node, graph_module: GraphModule):
         """Fold outmost quant nodes inside submodule.

backends/arm/_passes/normalize_while_initial_args_pass.py

Lines changed: 3 additions & 1 deletion
@@ -1,4 +1,4 @@
-# Copyright 2025 Arm Limited and/or its affiliates.
+# Copyright 2025-2026 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -82,6 +82,8 @@ def _normalize_node(self, graph_module: GraphModule, node: Node) -> bool:
         new_carried = tuple(carried_inputs + additional_inputs)
         node.update_arg(2, new_carried)
         node.update_arg(3, ())
+        # Annotate the node so a later pass can key captured vs loop-carried args
+        node.meta["additional_inputs"] = additional_inputs

         body_module_name = str(cast(Node, node.args[1]).target)
         body_module = cast(GraphModule, graph_module.get_submodule(body_module_name))  # type: ignore
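With node.meta["additional_inputs"] in place, a later pass can recover which of the normalized arguments were originally captured rather than loop-carried. A sketch under the assumptions above (the helper name is illustrative, not part of the commit):

from torch.fx import Node


def split_while_args(node: Node) -> tuple[tuple, tuple]:
    # After normalization, args[2] holds carried + additional inputs
    # concatenated and args[3] is empty; the meta entry recovers the split.
    all_inputs = node.args[2]
    additional = node.meta.get("additional_inputs", ())
    n_carried = len(all_inputs) - len(additional)
    return all_inputs[:n_carried], all_inputs[n_carried:]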

backends/arm/quantizer/quantization_annotator.py

Lines changed: 24 additions & 20 deletions
@@ -890,29 +890,33 @@ def any_or_hardtanh_min_zero(n: Node):
         submodule_args_pos = -1 if node.target == torch.ops.higher_order.cond else -2
         submodule_args = node.args[submodule_args_pos]
         output_qspec = output_act_qspec
-        if len(submodule_args) > 0:  # type: ignore[arg-type]
-            # The way the TOSA backend handles quantized inputs, arrays of input tensors (such as the input to a
-            # conditional graph) need shared quantization.
-            shared_qspec = SharedQuantizationSpec(
-                (cast(list[Node], submodule_args)[0], node)
-            )
-            quant_properties.quant_inputs = [
-                _QuantProperty(
-                    submodule_args_pos,
-                    [
-                        input_act_qspec,
-                        *([shared_qspec] * (len(submodule_args) - 1)),  # type: ignore[arg-type]
-                    ],
+        # Annotate each control-flow tensor independently using the default input qspec
+        if submodule_args:
+            if node.meta.get("additional_inputs", None):
+                qspecs = [input_act_qspec] * len(cast(Sequence[Node], submodule_args))  # type: ignore[arg-type]
+                quant_properties.quant_inputs = [
+                    _QuantProperty(submodule_args_pos, qspecs)
+                ]
+            else:
+                shared_qspec = SharedQuantizationSpec(
+                    (cast(list[Node], submodule_args)[0], node)
                 )
-            ]
-            if node.target == torch.ops.higher_order.while_loop:
-                # The output of the while loop body can either re-enter the body, or exit the while loop.
-                # Therefore, A and B in the diagram below need to share the same quantization parameters.
-                # A -> while ( RESCALE -> ... RESCALE -> ) -> B
-                output_qspec = shared_qspec
+                quant_properties.quant_inputs = [
+                    _QuantProperty(
+                        submodule_args_pos,
+                        [
+                            input_act_qspec,
+                            *([shared_qspec] * (len(submodule_args) - 1)),  # type: ignore[arg-type]
+                        ],
+                    )
+                ]
+                if node.target == torch.ops.higher_order.while_loop:
+                    # The output of the while loop body can either re-enter the body, or exit the while loop.
+                    # Therefore, A and B in the diagram below need to share the same quantization parameters.
+                    # A -> while ( RESCALE -> ... RESCALE -> ) -> B
+                    output_qspec = shared_qspec

         quant_properties.quant_output = _QuantProperty(0, output_qspec)
-
     else:
         return None
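The branch logic above reduces to a choice between per-tensor and shared input specs; a simplified sketch (the qspec arguments are placeholders for the quantizer's real spec objects):

def choose_input_qspecs(submodule_args, node, input_act_qspec, make_shared_qspec):
    # while_loop with additional inputs (e.g. a threshold): annotate each
    # tensor independently so the threshold keeps its own scale/zero-point.
    if node.meta.get("additional_inputs", None):
        return [input_act_qspec] * len(submodule_args)
    # Otherwise all inputs share the first argument's quantization parameters.
    shared = make_shared_qspec(submodule_args[0], node)
    return [input_act_qspec] + [shared] * (len(submodule_args) - 1)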

backends/arm/test/ops/test_while.py

Lines changed: 0 additions & 3 deletions
@@ -210,9 +210,6 @@ def test_while_loop_tosa_FP(case: Callable[[], Tuple[torch.nn.Module, Tuple]]):
 @common.parametrize(
     "case",
     test_cases,
-    xfails={
-        "large_threshold": "MLETORCH-1808 - Handle different scales for different parameters"
-    },
 )
 def test_while_loop_tosa_INT(case: Callable[[], Tuple[torch.nn.Module, Tuple]]):
     module, example_inputs = case()
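With the xfail removed, the large_threshold case is expected to pass. Assuming a standard ExecuTorch checkout, the INT test can be run with something along the lines of:

python -m pytest backends/arm/test/ops/test_while.py -k test_while_loop_tosa_INT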
