Skip to content

Commit d63ffbd

Browse files
committed
Update on "[Executorch][LLM] Use caching allocator for runner"
We observed that on iOS it improves perf by 6% because SDPA op does temp allocations. No significant difference on android though. Differential Revision: [D86120038](https://our.internmc.facebook.com/intern/diff/D86120038/) [ghstack-poisoned]
2 parents 4e0b339 + f06f5ba commit d63ffbd

137 files changed

Lines changed: 5548 additions & 1576 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

backends/arm/_passes/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,10 +62,12 @@
6262
from .decompose_maxpool2d_with_dilation_pass import DecomposeMaxPool2dPass # noqa
6363
from .decompose_meandim_pass import DecomposeMeanDimPass # noqa
6464
from .decompose_ne_pass import DecomposeNotEqualPass # noqa
65+
from .decompose_quant_nodes import DecomposeQuantNodesPass # noqa
6566
from .decompose_remainder_pass import DecomposeRemainderPass # noqa
6667
from .decompose_round_pass import DecomposeRoundPass # noqa
6768
from .decompose_sdpa_pass import DecomposeScaledDotProductAttentionPass # noqa
6869
from .decompose_select import DecomposeSelectPass # noqa
70+
from .decompose_select_scatter_pass import DecomposeSelectScatterPass # noqa
6971
from .decompose_sign_pass import DecomposeSignPass # noqa
7072
from .decompose_silu_pass import DecomposeSiluPass # noqa
7173
from .decompose_sinh_pass import DecomposeSinhPass # noqa
@@ -115,5 +117,7 @@
115117
from .to_tosa_memory_format_pass import ToTosaMemoryFormatPass # noqa
116118
from .unsqueeze_before_repeat_pass import UnsqueezeBeforeRepeatPass # noqa
117119
from .unsqueeze_scalar_placeholders_pass import UnsqueezeScalarPlaceholdersPass # noqa
118-
from .replace_inf_values_pass import ReplaceInfValuesPass # noqa # usort: skip
120+
from .replace_inf_and_limit_values_pass import ( # noqa # usort: skip
121+
ReplaceInfAndLimitValuesPass,
122+
)
119123
from .arm_pass_manager import ArmPassManager # noqa # usort: skip

backends/arm/_passes/arm_pass_manager.py

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,12 @@
6565
DecomposeMaxPool2dPass,
6666
DecomposeMeanDimPass,
6767
DecomposeNotEqualPass,
68+
DecomposeQuantNodesPass,
6869
DecomposeRemainderPass,
6970
DecomposeRoundPass,
7071
DecomposeScaledDotProductAttentionPass,
7172
DecomposeSelectPass,
73+
DecomposeSelectScatterPass,
7274
DecomposeSignPass,
7375
DecomposeSiluPass,
7476
DecomposeSinhPass,
@@ -97,7 +99,7 @@
9799
RemoveGetItemPass,
98100
RemoveGraphAssertsPass,
99101
RemoveNoopPass,
100-
ReplaceInfValuesPass,
102+
ReplaceInfAndLimitValuesPass,
101103
ReplaceScalarWithTensorByProfilePass,
102104
RewriteConv2dPass,
103105
RewriteMatmulPass,
@@ -111,6 +113,7 @@
111113

112114
from executorch.backends.arm._passes.arm_pass import ArmPass
113115
from executorch.backends.arm.tosa.specification import (
116+
tosa_spec_in_set,
114117
TosaLoweringContext,
115118
TosaSpecification,
116119
)
@@ -172,22 +175,18 @@ def _tosa_pipeline(
172175
self.add_passes(
173176
[
174177
FuseQuantizedActivationPass(),
175-
RemoveGetItemPass(),
176178
ConvertToClampPass(),
177179
DecomposeInt32ClampPass(),
178180
DecomposeGroupNormPass(),
179181
DecomposeLayerNormPass(),
180-
DecomposeBatchNormNoStatsPass(),
181182
DecomposeVarPass(),
182183
DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec),
183184
AnnotateDecomposedMatmulPass(),
184185
ConvertELUParamsPass(),
185-
ConvertSplitToSlicePass(),
186-
QuantizeClampArgumentsPass(),
187186
]
188187
)
189188

190-
# Fold Q/DQ nodes, insert INT8/INT32 rescales.
189+
# Fold Q/DQ nodes, insert INT8/INT32 rescales, decompose quantization nodes.
191190
self.add_passes(
192191
[
193192
FoldAndAnnotateQParamsPass(exported_program),
@@ -198,12 +197,17 @@ def _tosa_pipeline(
198197
DecomposeLinearPass(),
199198
InsertRescaleInt32Pass(),
200199
InsertControlFlowRescalesPass(),
200+
DecomposeQuantNodesPass(),
201201
]
202202
)
203203

204204
# Node transformation passes (post q/dq folding)
205205
self.add_passes(
206206
[
207+
ConvertSplitToSlicePass(),
208+
QuantizeClampArgumentsPass(),
209+
RemoveGetItemPass(),
210+
DecomposeBatchNormNoStatsPass(),
207211
DecomposeLogitPass(),
208212
DecomposeMaskedFillPass(),
209213
DecomposeRoundPass(),
@@ -240,7 +244,6 @@ def _tosa_pipeline(
240244
# passes. Ticket: MLETORCH-1540
241245
DecomposeNotEqualPass(),
242246
MatchArgRanksPass(exported_program),
243-
FuseConstantArgsPass(exported_program),
244247
]
245248
)
246249

@@ -262,6 +265,7 @@ def _tosa_pipeline(
262265
DecomposeAvgPool2dPass(),
263266
DecorateFp32toInt32CastingPass(),
264267
ComputeConstantOpsAOTPass(exported_program),
268+
FuseConstantArgsPass(exported_program),
265269
ConvertExpandCopyToRepeatPass(),
266270
UnsqueezeBeforeRepeatPass(),
267271
DecomposeCumsumPass(exported_program),
@@ -306,23 +310,28 @@ def transform_to_backend_pipeline(
306310
self, exported_program: ExportedProgram, graph_module: GraphModule
307311
):
308312
"""Apply passes before transforming program to backend"""
309-
if self.tosa_spec in (
310-
TosaSpecification.create_from_string("TOSA-1.0+FP"),
311-
TosaSpecification.create_from_string("TOSA-1.0+INT"),
313+
314+
if not tosa_spec_in_set(
315+
self.tosa_spec,
316+
{
317+
TosaSpecification.create_from_string("TOSA-1.0+FP"),
318+
TosaSpecification.create_from_string("TOSA-1.0+INT"),
319+
},
312320
):
313-
return self._tosa_pipeline(exported_program, graph_module)
314-
else:
315-
raise NotImplementedError(
316-
f"No pass pipeline implemented for {self.tosa_spec}"
321+
raise RuntimeError(
322+
f"No pass pipeline found for TOSA specification: {self.tosa_spec}"
317323
)
318324

325+
return self._tosa_pipeline(exported_program, graph_module)
326+
319327
def transform_for_annotation_pipeline(self, graph_module: GraphModule):
320328
# Preprocessing passes
321329
self.add_pass(RemoveGraphAssertsPass())
322330

323331
# Transformation passes (pre scalar -> tensor)
324332
self.add_passes(
325333
[
334+
DecomposeSelectScatterPass(),
326335
ConvertInt64ConstOpsToInt32Pass(),
327336
ConvertInt64OutputOpsToInt32Pass(),
328337
InsertInt32CastsAfterInt64PlaceholdersPass(),
@@ -376,7 +385,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
376385
# Postprocessing passes
377386
self.add_passes(
378387
[
379-
ReplaceInfValuesPass(),
388+
ReplaceInfAndLimitValuesPass(),
380389
DecomposeMaskedFillPass() if not self.tosa_spec.is_U55_subset else None,
381390
]
382391
)

backends/arm/_passes/convert_elu_params.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,9 @@ def call(self, graph_module: torch.fx.GraphModule):
3838
if not is_quantized:
3939
continue
4040
with graph.inserting_after(node):
41-
replace_node = create_node(graph, exir_ops.edge.aten.elu.default)
41+
replace_node = create_node(
42+
graph, exir_ops.edge.aten.elu.default, from_node=node
43+
)
4244
old_args = list(node.args)
4345

4446
alpha = old_args[1] if len(old_args) > 1 else 1.0

backends/arm/_passes/convert_minmax_pass.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,10 @@
77

88
import torch
99
from executorch.backends.arm._passes.arm_pass import ArmPass
10-
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
10+
from executorch.backends.arm._passes.arm_pass_utils import (
11+
create_node,
12+
get_first_fake_tensor,
13+
)
1114
from executorch.backends.arm._passes.convert_squeezes_to_view import (
1215
ConvertSqueezesToViewPass,
1316
)
@@ -131,15 +134,21 @@ def call(self, graph_module: torch.fx.GraphModule):
131134

132135
for dim in dims:
133136
args = (input_node, dim, True)
134-
input_node = graph_module.graph.create_node(
135-
"call_function", op, args, node.kwargs
137+
input_node = create_node(
138+
graph=graph_module.graph,
139+
op_target=op,
140+
args=args,
141+
kwargs={},
142+
from_node=node,
136143
)
137144

138145
if not keepdims:
139-
input_node = graph_module.graph.create_node(
140-
"call_function",
141-
squeeze_op,
142-
(input_node, dims),
146+
input_node = create_node(
147+
graph=graph_module.graph,
148+
op_target=squeeze_op,
149+
args=(input_node, dims),
150+
kwargs={},
151+
from_node=node,
143152
)
144153

145154
replace_node.replace_all_uses_with(input_node)

backends/arm/_passes/convert_split_to_slice.py

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,11 +85,48 @@ def call(self, graph_module: torch.fx.GraphModule):
8585
graph,
8686
self.slice,
8787
(input_node, dim, starts[index], ends[index]),
88+
from_node=node,
89+
)
90+
slice_node.meta = _copy_user_node_qparams(
91+
split_node, output_node, index
8892
)
89-
slice_node.meta = split_node.meta.copy()
90-
slice_node.meta["val"] = slice_node.meta["val"][index]
9193
output_node.replace_all_uses_with(slice_node)
9294
graph.eliminate_dead_code()
9395
graph_module.recompile()
9496
graph_module = super().call(graph_module).graph_module
9597
return PassResult(graph_module, True)
98+
99+
100+
def _copy_user_node_qparams(
101+
split_node: torch.fx.Node, output_node: torch.fx.Node, index: int
102+
) -> dict:
103+
"""
104+
Construct metadata for the slice node that will replace the split output.
105+
106+
Note that output quantization parameters are copied from the user nodes
107+
of the split node. The split node itself does not have output quantization
108+
parameters.
109+
110+
Args:
111+
split_node: The split node being replaced.
112+
output_node: The getitem node that is user of the split node.
113+
index: The index of the output being processed.
114+
Returns:
115+
Updated metadata dictionary for the slice node.
116+
"""
117+
118+
def _select_index(value):
119+
if isinstance(value, (list, tuple)):
120+
return value[index]
121+
return value
122+
123+
meta = split_node.meta.copy()
124+
if "val" in meta:
125+
meta["val"] = _select_index(meta["val"])
126+
if "tensor_meta" in meta:
127+
meta["tensor_meta"] = _select_index(meta["tensor_meta"])
128+
if "input_qparams" in meta:
129+
meta["input_qparams"] = dict(meta["input_qparams"])
130+
if "output_qparams" in meta:
131+
meta["output_qparams"] = dict(output_node.meta["output_qparams"])
132+
return meta

0 commit comments

Comments
 (0)