Skip to content

Commit dd00d42

Browse files
authored
Arm backend: Fix nested control-flow partition checks (#19697)
- Updates so that the outer cond graph is picked up.
- Updates to nested quantization.
- Removes need for increased threshold.

Signed-off-by: Saoirse Stewart <saoirse.stewart@arm.com>
1 parent 2c9c9dd commit dd00d42

12 files changed

Lines changed: 193 additions & 121 deletions

File tree

backends/arm/_passes/arm_pass_utils.py

Lines changed: 1 addition & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import operator
1010
import traceback
1111
from inspect import isclass
12-
from typing import cast, List, Optional, Sequence, Tuple
12+
from typing import cast, Optional, Sequence
1313

1414
import torch
1515
import torch.fx
@@ -19,10 +19,6 @@
1919
from executorch.exir import ExportedProgram
2020
from executorch.exir.dialects._ops import ops as exir_ops
2121
from executorch.exir.dialects.edge._ops import EdgeOpOverload
22-
from executorch.exir.graph_module import (
23-
_get_control_flow_submodules,
24-
get_control_flow_submodules,
25-
)
2622
from executorch.exir.pass_base import NodeMetadata
2723

2824
from torch._export.utils import (
@@ -36,7 +32,6 @@
3632
from torch._ops import OpOverload
3733
from torch._subclasses.fake_tensor import FakeTensor
3834
from torch.export.graph_signature import InputKind
39-
from torch.fx import GraphModule, Node
4035

4136

4237
def is_submodule_node(node: torch.fx.Node):
@@ -364,48 +359,6 @@ def set_node_arg(node: torch.fx.Node, i: int | str, value):
364359
raise RuntimeError("Invalid type")
365360

366361

367-
def is_nested_control_flow_graph(graph_module: GraphModule) -> bool:
368-
"""Returns True if graph_module is a nested control-flow graph."""
369-
370-
# Find all top-level control-flow submodules
371-
top_cf = get_control_flow_submodules(graph_module)
372-
# For each submodule, see if it itself has control-flow inside
373-
for _, submod, _ in top_cf:
374-
if get_control_flow_submodules(submod):
375-
return True
376-
return False
377-
378-
379-
def get_cond_while_submodules_nested(
380-
graph_module: GraphModule,
381-
apply_quantization: bool = False,
382-
) -> List[Tuple[str, GraphModule, Node]]:
383-
"""Recursively find cond/while_loop submodules in an GraphModule.
384-
385-
In nested control flow graphs, FX records the submodule functions
386-
(true/false or cond/body) in reverse order compared to top-level graphs. We
387-
must swap the indices when nested so that cond (first) and body/true_fn
388-
(second) are consistently identified across all nesting levels.
389-
390-
"""
391-
392-
# Determine arg indices based on nesting and whether only cond branch is needed
393-
nested = is_nested_control_flow_graph(graph_module)
394-
# cond: [true_fn, false_fn] or swapped if nested
395-
cond_indices = [2, 1] if nested else [1, 2]
396-
# while_loop: [cond_fn, body_fn] or swapped if nested
397-
while_indices = [1, 0] if nested else [0, 1]
398-
if apply_quantization:
399-
# only keep the cond_fn for while_loop (first index) when quantizing.
400-
while_indices = [while_indices[0]]
401-
mapping = {
402-
torch.ops.higher_order.cond: cond_indices,
403-
torch.ops.higher_order.while_loop: while_indices,
404-
}
405-
# collect cond/while submodules (using mapping indices)
406-
return _get_control_flow_submodules(graph_module, mapping)
407-
408-
409362
def to_2tuple(value):
410363
"""Normalizes scalars, and 1-element sequences to a tuple of length 2."""
411364
if isinstance(value, int):

backends/arm/_passes/control_flow_const_inline.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,10 @@
77

88
import torch
99
from executorch.backends.arm._passes.arm_pass import ArmPass
10-
from executorch.backends.arm._passes.arm_pass_utils import (
11-
get_cond_while_submodules_nested,
12-
is_submodule_node,
13-
)
10+
from executorch.backends.arm._passes.arm_pass_utils import is_submodule_node
1411
from executorch.backends.transforms.utils import is_get_attr_node
1512
from executorch.exir.dialects._ops import ops as exir_ops
13+
from executorch.exir.graph_module import get_cond_while_submodules
1614
from executorch.exir.pass_base import ExportPass, PassResult
1715
from torch.fx import GraphModule
1816

@@ -37,7 +35,7 @@ class ControlFlowConstInlinePass(ArmPass):
3735

3836
def _convert_getattr(self, graph_module):
3937
modified = False
40-
for _, submodule, _ in get_cond_while_submodules_nested(graph_module):
38+
for _, submodule, _ in get_cond_while_submodules(graph_module):
4139
for submodule_node in submodule.graph.nodes:
4240
if submodule_node.target in self._targeted_ops:
4341
self._convert_getattr(submodule)

backends/arm/_passes/insert_rescales_pass.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -509,7 +509,13 @@ def _rescale_submodule_inputs(
509509
input_node = input_nodes[qargs_index]
510510
if len(input_node.users) == 0:
511511
continue
512-
if len(out_qparams_map := input_node.meta.get("output_qparams", {})) != 1:
512+
out_qparams_map = input_node.meta.get("output_qparams", {})
513+
if len(out_qparams_map) == 0:
514+
# Nested control-flow submodules may also expose frozen captured
515+
# values as placeholders. Those are not control-flow boundary
516+
# inputs, so there is no qparam pair to bridge with a RESCALE.
517+
continue
518+
if len(out_qparams_map) != 1:
513519
raise ValueError(
514520
f"Expected submodule input {input_node} to have exactly one output qparam, got {out_qparams_map}"
515521
)

backends/arm/_passes/scalars_to_attribute_pass.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,9 @@
88

99
import torch
1010
from executorch.backends.arm._passes import ArmPass
11-
from executorch.backends.arm._passes.arm_pass_utils import (
12-
get_cond_while_submodules_nested,
13-
get_first_fake_tensor,
14-
)
11+
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
1512
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
13+
from executorch.exir.graph_module import get_cond_while_submodules
1614
from executorch.exir.pass_base import ExportPass, PassResult
1715
from torch.fx import GraphModule, Node
1816
from torchao.quantization.pt2e.utils import get_new_attr_name_with_prefix
@@ -98,7 +96,7 @@ def handle_control_nodes(self, graph_module: GraphModule) -> None:
9896
"""Apply scalar argument conversion on subgraphs of control-flow
9997
nodes.
10098
"""
101-
for _, submodule, _ in get_cond_while_submodules_nested(graph_module):
99+
for _, submodule, _ in get_cond_while_submodules(graph_module):
102100
for submodule_node in submodule.graph.nodes:
103101
self._convert_scalar_args(submodule, submodule_node)
104102

backends/arm/operator_support/control_flow_support.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,13 @@
1919
from torch.fx.passes.operator_support import OperatorSupportBase
2020

2121

22+
def _owning_graph_module(node: fx.Node) -> fx.GraphModule:
23+
graph_module = getattr(node.graph, "owning_module", None)
24+
if not isinstance(graph_module, fx.GraphModule):
25+
raise RuntimeError(f"Could not resolve owning GraphModule for node {node}")
26+
return graph_module
27+
28+
2229
def _fully_partitioned(submodule: fx.GraphModule) -> bool:
2330
"""Check that all nested control-flow ops within this submodule are also
2431
fully partitioned.
@@ -27,8 +34,8 @@ def _fully_partitioned(submodule: fx.GraphModule) -> bool:
2734

2835
for submodule_node in submodule.graph.nodes:
2936
if submodule_node.target in ControlFlowOpSupported._targeted_ops:
30-
if _submodules_fully_partitioned(submodule_node, submodule):
31-
return True
37+
if not _submodules_fully_partitioned(submodule_node, submodule):
38+
return False
3239

3340
if submodule_node.op != "call_function":
3441
continue
@@ -56,13 +63,18 @@ def _fully_partitioned(submodule: fx.GraphModule) -> bool:
5663
return True
5764

5865

59-
def _submodules_fully_partitioned(node: fx.Node, graph_module: fx.GraphModule) -> bool:
66+
def _submodules_fully_partitioned(
67+
node: fx.Node, graph_module: fx.GraphModule | None = None
68+
) -> bool:
6069
"""Returns whether the submodule arguments to a cond node were fully
6170
partitioned.
6271
6372
Updates "val" meta of the submodules if they are.
6473
6574
"""
75+
if graph_module is None:
76+
graph_module = _owning_graph_module(node)
77+
6678
match node.target:
6779
case torch.ops.higher_order.cond:
6880
submodule_args = node.args[1:3]
@@ -129,9 +141,7 @@ def is_node_supported(
129141
node, f"Submodule had unsupported user {user}"
130142
)
131143
return False
132-
if not _submodules_fully_partitioned(
133-
user, self.exported_program.graph_module
134-
):
144+
if not _submodules_fully_partitioned(user):
135145
self.reporter.report_reject(
136146
node, "One submodule was not fully partitioned"
137147
)
@@ -174,9 +184,7 @@ def is_node_supported(
174184
)
175185
return False
176186

177-
if not _submodules_fully_partitioned(
178-
node, self.exported_program.graph_module
179-
):
187+
if not _submodules_fully_partitioned(node):
180188
self.reporter.report_reject(
181189
node, "Submodule was not fully partitioned."
182190
)

backends/arm/operators/op_cond_if.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,11 @@
1717
validate_num_inputs,
1818
validate_valid_dtype,
1919
)
20-
from executorch.backends.arm.tosa.mapping import TosaArg # type: ignore
20+
from executorch.backends.arm.tosa.mapping import ( # type: ignore
21+
TOSA_CONTROL_FLOW_REGION_NAME_META,
22+
TOSA_TENSOR_NAME_META,
23+
TosaArg,
24+
)
2125
from torch.fx import Node
2226

2327

@@ -38,7 +42,12 @@ def define_node(
3842
validate_cf_extension(self.target, self.tosa_spec)
3943

4044
attr = ts.TosaSerializerAttribute()
41-
if_graph, else_graph = (cast(Node, arg).target for arg in node.args[1:3])
45+
if_graph, else_graph = (
46+
cast(Node, arg).meta.get(
47+
TOSA_CONTROL_FLOW_REGION_NAME_META, str(cast(Node, arg).target)
48+
)
49+
for arg in node.args[1:3]
50+
)
4251
attr.CondIfAttribute(if_graph, else_graph)
4352

4453
self._serialize_operator(
@@ -47,7 +56,11 @@ def define_node(
4756
ts.Op.COND_IF,
4857
[
4958
inputs[0].name,
50-
*(subgraph_input.name for subgraph_input in inputs[-1].special),
59+
*(
60+
subgraph_input.name
61+
+ subgraph_input.meta.get(TOSA_TENSOR_NAME_META, "")
62+
for subgraph_input in inputs[-1].special
63+
),
5164
],
5265
output.multiple_output_names,
5366
attr,

backends/arm/operators/op_while.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,14 @@
1515
validate_cf_extension,
1616
validate_num_inputs,
1717
)
18-
from executorch.backends.arm.tosa.mapping import map_dtype, TosaArg
18+
from executorch.backends.arm.tosa.mapping import (
19+
map_dtype,
20+
TOSA_CONTROL_FLOW_REGION_NAME_META,
21+
TOSA_TENSOR_NAME_META,
22+
TosaArg,
23+
)
1924
from executorch.backends.arm.tosa.utils import normalize_symint
25+
2026
from torch.fx import Node
2127

2228

@@ -46,7 +52,12 @@ def define_node(
4652
)
4753

4854
attr = ts.TosaSerializerAttribute()
49-
cond_graph, body_graph = (str(cast(Node, arg).target) for arg in node.args[:2])
55+
cond_graph, body_graph = (
56+
cast(Node, arg).meta.get(
57+
TOSA_CONTROL_FLOW_REGION_NAME_META, str(cast(Node, arg).target)
58+
)
59+
for arg in node.args[:2]
60+
)
5061
attr.WhileLoopAttribute(cond_graph, body_graph)
5162

5263
input_names: list[str] = []
@@ -55,7 +66,9 @@ def define_node(
5566
raise ValueError(
5667
f"{self.target}: Unsupported carried input type {type(loop_input)}."
5768
)
58-
input_names.append(loop_input.name)
69+
input_names.append(
70+
loop_input.name + loop_input.meta.get(TOSA_TENSOR_NAME_META, "")
71+
)
5972

6073
num_inputs = len(input_names)
6174
num_outputs = len(output.multiple_output_names)

0 commit comments

Comments
 (0)