Skip to content

Commit 0e196c3

Browse files
authored
Merge branch 'main' into dynamic_unbound_kv_cache
2 parents 21200ba + 4741f3a commit 0e196c3

64 files changed

Lines changed: 2175 additions & 552 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

backends/arm/_passes/arm_pass_manager.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -481,16 +481,14 @@ def _tosa_pipeline(
481481
ConvertFullLikeToFullPass(),
482482
MatchArgDtypePass(),
483483
UnsqueezeScalarPlaceholdersPass(exported_program),
484-
# TODO: Move DecomposeNotEqualPass to before or after this block of
485-
# passes. Ticket: MLETORCH-1540
486-
DecomposeNotEqualPass(),
487484
MatchArgRanksPass(exported_program),
488485
]
489486
)
490487

491488
# Node transformation passes (post scalar-removal)
492489
self.add_passes(
493490
[
491+
DecomposeNotEqualPass(),
494492
NormalizeIndexPutNoneIndicesPass(),
495493
NormalizeIndexPutBoolIndexTensorPass(),
496494
RewriteIndexPutPass(),

backends/arm/_passes/arm_pass_utils.py

Lines changed: 1 addition & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import operator
1010
import traceback
1111
from inspect import isclass
12-
from typing import cast, List, Optional, Sequence, Tuple
12+
from typing import cast, Optional, Sequence
1313

1414
import torch
1515
import torch.fx
@@ -19,10 +19,6 @@
1919
from executorch.exir import ExportedProgram
2020
from executorch.exir.dialects._ops import ops as exir_ops
2121
from executorch.exir.dialects.edge._ops import EdgeOpOverload
22-
from executorch.exir.graph_module import (
23-
_get_control_flow_submodules,
24-
get_control_flow_submodules,
25-
)
2622
from executorch.exir.pass_base import NodeMetadata
2723

2824
from torch._export.utils import (
@@ -36,7 +32,6 @@
3632
from torch._ops import OpOverload
3733
from torch._subclasses.fake_tensor import FakeTensor
3834
from torch.export.graph_signature import InputKind
39-
from torch.fx import GraphModule, Node
4035

4136

4237
def is_submodule_node(node: torch.fx.Node):
@@ -364,48 +359,6 @@ def set_node_arg(node: torch.fx.Node, i: int | str, value):
364359
raise RuntimeError("Invalid type")
365360

366361

367-
def is_nested_control_flow_graph(graph_module: GraphModule) -> bool:
368-
"""Returns True if graph_module is a nested control-flow graph."""
369-
370-
# Find all top-level control-flow submodules
371-
top_cf = get_control_flow_submodules(graph_module)
372-
# For each submodule, see if it itself has control-flow inside
373-
for _, submod, _ in top_cf:
374-
if get_control_flow_submodules(submod):
375-
return True
376-
return False
377-
378-
379-
def get_cond_while_submodules_nested(
380-
graph_module: GraphModule,
381-
apply_quantization: bool = False,
382-
) -> List[Tuple[str, GraphModule, Node]]:
383-
"""Recursively find cond/while_loop submodules in an GraphModule.
384-
385-
In nested control flow graphs, FX records the submodule functions
386-
(true/false or cond/body) in reverse order compared to top-level graphs. We
387-
must swap the indices when nested so that cond (first) and body/true_fn
388-
(second) are consistently identified across all nesting levels.
389-
390-
"""
391-
392-
# Determine arg indices based on nesting and whether only cond branch is needed
393-
nested = is_nested_control_flow_graph(graph_module)
394-
# cond: [true_fn, false_fn] or swapped if nested
395-
cond_indices = [2, 1] if nested else [1, 2]
396-
# while_loop: [cond_fn, body_fn] or swapped if nested
397-
while_indices = [1, 0] if nested else [0, 1]
398-
if apply_quantization:
399-
# only keep the cond_fn for while_loop (first index) when quantizing.
400-
while_indices = [while_indices[0]]
401-
mapping = {
402-
torch.ops.higher_order.cond: cond_indices,
403-
torch.ops.higher_order.while_loop: while_indices,
404-
}
405-
# collect cond/while submodules (using mapping indices)
406-
return _get_control_flow_submodules(graph_module, mapping)
407-
408-
409362
def to_2tuple(value):
410363
"""Normalizes scalars, and 1-element sequences to a tuple of length 2."""
411364
if isinstance(value, int):

backends/arm/_passes/control_flow_const_inline.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,10 @@
77

88
import torch
99
from executorch.backends.arm._passes.arm_pass import ArmPass
10-
from executorch.backends.arm._passes.arm_pass_utils import (
11-
get_cond_while_submodules_nested,
12-
is_submodule_node,
13-
)
10+
from executorch.backends.arm._passes.arm_pass_utils import is_submodule_node
1411
from executorch.backends.transforms.utils import is_get_attr_node
1512
from executorch.exir.dialects._ops import ops as exir_ops
13+
from executorch.exir.graph_module import get_cond_while_submodules
1614
from executorch.exir.pass_base import ExportPass, PassResult
1715
from torch.fx import GraphModule
1816

@@ -37,7 +35,7 @@ class ControlFlowConstInlinePass(ArmPass):
3735

3836
def _convert_getattr(self, graph_module):
3937
modified = False
40-
for _, submodule, _ in get_cond_while_submodules_nested(graph_module):
38+
for _, submodule, _ in get_cond_while_submodules(graph_module):
4139
for submodule_node in submodule.graph.nodes:
4240
if submodule_node.target in self._targeted_ops:
4341
self._convert_getattr(submodule)

backends/arm/_passes/insert_const_shapes.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@ class InsertConstShapesPass(ArmPass):
2626
exir_ops.edge.aten.repeat.default,
2727
}
2828

29+
def __init__(self) -> None:
30+
super().__init__()
31+
self._const_shape_cache: dict[tuple[int, ...], Any] = {}
32+
2933
@staticmethod
3034
def _is_shape_arg(arg: Any) -> bool:
3135
"""Return True when `arg` looks like a literal shape list/tuple."""
@@ -46,13 +50,17 @@ def call_operator(self, op, args, kwargs, meta, updated: Optional[bool] = False)
4650
# Insert a const node for the shape argument
4751
if op == exir_ops.edge.aten.view_copy.default:
4852
arg = meta.data["val"].shape
49-
const_node = super().call_shape_operator(
50-
exir_ops.backend.tosa.CONST_SHAPE.default,
51-
(arg,),
52-
{},
53-
meta,
54-
True,
55-
)
53+
shape = tuple(arg)
54+
const_node = self._const_shape_cache.get(shape)
55+
if const_node is None:
56+
const_node = super().call_shape_operator(
57+
exir_ops.backend.tosa.CONST_SHAPE.default,
58+
(arg,),
59+
{},
60+
meta,
61+
True,
62+
)
63+
self._const_shape_cache[shape] = const_node
5664
new_args.append(const_node)
5765
updated = True
5866
else:

backends/arm/_passes/insert_rescales_pass.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -509,7 +509,13 @@ def _rescale_submodule_inputs(
509509
input_node = input_nodes[qargs_index]
510510
if len(input_node.users) == 0:
511511
continue
512-
if len(out_qparams_map := input_node.meta.get("output_qparams", {})) != 1:
512+
out_qparams_map = input_node.meta.get("output_qparams", {})
513+
if len(out_qparams_map) == 0:
514+
# Nested control-flow submodules may also expose frozen captured
515+
# values as placeholders. Those are not control-flow boundary
516+
# inputs, so there is no qparam pair to bridge with a RESCALE.
517+
continue
518+
if len(out_qparams_map) != 1:
513519
raise ValueError(
514520
f"Expected submodule input {input_node} to have exactly one output qparam, got {out_qparams_map}"
515521
)

backends/arm/_passes/match_arg_ranks_pass.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def __init__(self, exported_program: ExportedProgram, *args, **kwargs) -> None:
5757
exir_ops.edge.aten.ge.Tensor,
5858
exir_ops.edge.aten.lt.Tensor,
5959
exir_ops.edge.aten.le.Tensor,
60+
exir_ops.edge.aten.ne.Tensor,
6061
exir_ops.edge.aten.pow.Tensor_Tensor,
6162
exir_ops.edge.aten.remainder.Tensor,
6263
exir_ops.edge.aten.where.self,

backends/arm/_passes/scalars_to_attribute_pass.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,9 @@
88

99
import torch
1010
from executorch.backends.arm._passes import ArmPass
11-
from executorch.backends.arm._passes.arm_pass_utils import (
12-
get_cond_while_submodules_nested,
13-
get_first_fake_tensor,
14-
)
11+
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
1512
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
13+
from executorch.exir.graph_module import get_cond_while_submodules
1614
from executorch.exir.pass_base import ExportPass, PassResult
1715
from torch.fx import GraphModule, Node
1816
from torchao.quantization.pt2e.utils import get_new_attr_name_with_prefix
@@ -98,7 +96,7 @@ def handle_control_nodes(self, graph_module: GraphModule) -> None:
9896
"""Apply scalar argument conversion on subgraphs of control-flow
9997
nodes.
10098
"""
101-
for _, submodule, _ in get_cond_while_submodules_nested(graph_module):
99+
for _, submodule, _ in get_cond_while_submodules(graph_module):
102100
for submodule_node in submodule.graph.nodes:
103101
self._convert_scalar_args(submodule, submodule_node)
104102

backends/arm/operator_support/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ runtime.python_library(
66
deps = [
77
"//executorch/backends/arm:constants",
88
"//executorch/backends/arm/_passes:passes",
9+
"//executorch/backends/arm/tosa:resize_utils",
910
"//executorch/backends/arm/tosa:tosa",
1011
"//executorch/backends/transforms:remove_getitem_op",
1112
"//executorch/backends/xnnpack/_passes:xnnpack_passes",

backends/arm/operator_support/control_flow_support.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,13 @@
1919
from torch.fx.passes.operator_support import OperatorSupportBase
2020

2121

22+
def _owning_graph_module(node: fx.Node) -> fx.GraphModule:
23+
graph_module = getattr(node.graph, "owning_module", None)
24+
if not isinstance(graph_module, fx.GraphModule):
25+
raise RuntimeError(f"Could not resolve owning GraphModule for node {node}")
26+
return graph_module
27+
28+
2229
def _fully_partitioned(submodule: fx.GraphModule) -> bool:
2330
"""Check that all nested control-flow ops within this submodule are also
2431
fully partitioned.
@@ -27,8 +34,8 @@ def _fully_partitioned(submodule: fx.GraphModule) -> bool:
2734

2835
for submodule_node in submodule.graph.nodes:
2936
if submodule_node.target in ControlFlowOpSupported._targeted_ops:
30-
if _submodules_fully_partitioned(submodule_node, submodule):
31-
return True
37+
if not _submodules_fully_partitioned(submodule_node, submodule):
38+
return False
3239

3340
if submodule_node.op != "call_function":
3441
continue
@@ -56,13 +63,18 @@ def _fully_partitioned(submodule: fx.GraphModule) -> bool:
5663
return True
5764

5865

59-
def _submodules_fully_partitioned(node: fx.Node, graph_module: fx.GraphModule) -> bool:
66+
def _submodules_fully_partitioned(
67+
node: fx.Node, graph_module: fx.GraphModule | None = None
68+
) -> bool:
6069
"""Returns whether the submodule arguments to a cond node were fully
6170
partitioned.
6271
6372
Updates "val" meta of the submodules if they are.
6473
6574
"""
75+
if graph_module is None:
76+
graph_module = _owning_graph_module(node)
77+
6678
match node.target:
6779
case torch.ops.higher_order.cond:
6880
submodule_args = node.args[1:3]
@@ -129,9 +141,7 @@ def is_node_supported(
129141
node, f"Submodule had unsupported user {user}"
130142
)
131143
return False
132-
if not _submodules_fully_partitioned(
133-
user, self.exported_program.graph_module
134-
):
144+
if not _submodules_fully_partitioned(user):
135145
self.reporter.report_reject(
136146
node, "One submodule was not fully partitioned"
137147
)
@@ -174,9 +184,7 @@ def is_node_supported(
174184
)
175185
return False
176186

177-
if not _submodules_fully_partitioned(
178-
node, self.exported_program.graph_module
179-
):
187+
if not _submodules_fully_partitioned(node):
180188
self.reporter.report_reject(
181189
node, "Submodule was not fully partitioned."
182190
)

backends/arm/operator_support/upsample_support.py

Lines changed: 52 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -13,19 +13,65 @@
1313
SupportedTOSAOperatorCheck,
1414
)
1515
from executorch.backends.arm.tosa import TosaSpecification
16+
from executorch.backends.arm.tosa.resize_utils import get_tosa_resize_validation_error
1617
from executorch.exir.dialects._ops import ops as exir_ops
1718

1819

20+
def _is_upsample_node_tosa_supported(
21+
support_check: SupportedTOSAOperatorCheck,
22+
node: fx.Node,
23+
tosa_spec: TosaSpecification,
24+
*,
25+
align_corners: bool,
26+
) -> bool:
27+
input_node = ensure_type(fx.Node, node.args[0])
28+
input_size_yx = get_first_fake_tensor(input_node).shape[2:]
29+
output_size_yx = get_first_fake_tensor(node).shape[2:]
30+
31+
try:
32+
scale_y_n, scale_y_d, offset_y, border_y = (
33+
RewriteUpsamplePass.get_resize_parameters_1d(
34+
input_size_yx[0], output_size_yx[0], align_corners
35+
)
36+
)
37+
scale_x_n, scale_x_d, offset_x, border_x = (
38+
RewriteUpsamplePass.get_resize_parameters_1d(
39+
input_size_yx[1], output_size_yx[1], align_corners
40+
)
41+
)
42+
except RuntimeError as err:
43+
support_check.reporter.report_reject(node, str(err))
44+
return False
45+
46+
# Validate the exact TOSA RESIZE parameters that RewriteUpsamplePass will
47+
# emit so support checks and fake-op validation reject the same cases.
48+
validation_error = get_tosa_resize_validation_error(
49+
input_hw=input_size_yx,
50+
output_hw=output_size_yx,
51+
scale=[scale_y_n, scale_y_d, scale_x_n, scale_x_d],
52+
offset=[offset_y, offset_x],
53+
border=[border_y, border_x],
54+
tosa_spec=tosa_spec,
55+
)
56+
if validation_error is not None:
57+
support_check.reporter.report_reject(node, validation_error)
58+
return False
59+
60+
return True
61+
62+
1963
@register_tosa_support_check
2064
class UpsampleNearest2dSupported(SupportedTOSAOperatorCheck):
2165
"""Provide the explicit TOSA support gate for nearest upsample."""
2266

2367
targets = [exir_ops.edge.aten.upsample_nearest2d.vec]
2468

2569
def is_node_tosa_supported(
26-
self, _node: fx.Node, _tosa_spec: TosaSpecification
70+
self, node: fx.Node, tosa_spec: TosaSpecification
2771
) -> bool: # type: ignore[override, misc]
28-
return True
72+
return _is_upsample_node_tosa_supported(
73+
self, node, tosa_spec, align_corners=False
74+
)
2975

3076

3177
@register_tosa_support_check
@@ -37,33 +83,9 @@ class UpsampleBilinear2dSupported(SupportedTOSAOperatorCheck):
3783
targets = [exir_ops.edge.aten.upsample_bilinear2d.vec]
3884

3985
def is_node_tosa_supported(
40-
self, node: fx.Node, _tosa_spec: TosaSpecification
86+
self, node: fx.Node, tosa_spec: TosaSpecification
4187
) -> bool: # type: ignore[override, misc]
42-
input_node = ensure_type(fx.Node, node.args[0])
4388
align_corners = ensure_type(bool, node.args[2])
44-
input_size_yx = get_first_fake_tensor(input_node).shape[2:]
45-
output_size_yx = get_first_fake_tensor(node).shape[2:]
46-
47-
try:
48-
scale_y_n, scale_y_d, _, _ = RewriteUpsamplePass.get_resize_parameters_1d(
49-
input_size_yx[0], output_size_yx[0], align_corners
50-
)
51-
scale_x_n, scale_x_d, _, _ = RewriteUpsamplePass.get_resize_parameters_1d(
52-
input_size_yx[1], output_size_yx[1], align_corners
53-
)
54-
except RuntimeError as err:
55-
self.reporter.report_reject(node, str(err))
56-
return False
57-
58-
# get_resize_parameters_1d() returns the TOSA RESIZE scale fraction for
59-
# each spatial dimension. For align_corners=False, this is the effective
60-
# output_size / input_size ratio, so the 1/16 boundary is checked
61-
# directly in the same representation that RESIZE lowering will use.
62-
if scale_y_d >= 16 * scale_y_n or scale_x_d >= 16 * scale_x_n:
63-
self.reporter.report_reject(
64-
node,
65-
"Bilinear RESIZE downscale must be strictly greater than 1/16",
66-
)
67-
return False
68-
69-
return True
89+
return _is_upsample_node_tosa_supported(
90+
self, node, tosa_spec, align_corners=align_corners
91+
)

0 commit comments

Comments
 (0)