Skip to content

Commit f1d2775

Browse files
Arm backend: Avoid repeated recompiles in constant fusion
Remove the per-fusion graph recompile from FuseConstantArgsPass. On an internal VGF quant export test, this reduced FuseConstantArgsPass time from 64.8s to 5.2s by avoiding recompiles after each fused op.

The pass makes fusion decisions from the mutable FX graph. Creating a constant placeholder and calling replace_all_uses_with updates downstream node inputs immediately, so chained constant ops are still visible in the same scan.

For a chain such as const -> B -> A, FX topological order visits B before A because A consumes B. Fusing B updates A's input node immediately, so A is fusable when the scan reaches it. If A is visited before B in a valid graph, B cannot be an input to A, so fusing B later cannot make A newly fusable.

A final explicit recompile is not needed before super().call(...). That call interprets the updated graph and returns a freshly traced GraphModule, so the pass does not execute stale generated forward code after the rewrite.

Add a regression test for chained constant folding without per-fusion recompiles.

Change-Id: I65b62f23d90e663e09fcaa8f34e09ebd10eb1935
Signed-off-by: Yufeng Shi <yufeng.shi@arm.com>
1 parent 8e653a6 commit f1d2775

2 files changed

Lines changed: 61 additions & 1 deletion

File tree

backends/arm/_passes/fuse_constant_ops_pass.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -204,7 +204,6 @@ def call(self, graph_module):
204204
f"{[input_node.name for input_node in input_nodes]}"
205205
)
206206
modified |= did_fuse
207-
graph_module.recompile() # Recompile needed to catch chains of constant ops
208207
input_nodes_to_maybe_delete.update(input_nodes)
209208
except Exception as e:
210209
logger.warning(

backends/arm/test/passes/test_fuse_constant_ops_pass.py

Lines changed: 61 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -280,6 +280,67 @@ def test_fuse_constant_args_tosa_INT_cat_uses_top_level_arg_qparams() -> None:
280280
] == ["aten_cat_default_fused_const"]
281281

282282

283+
def test_fuse_constant_args_fuses_chains_without_recompile() -> None:
    """Verify that a chain of constant-only ops is folded in a single pass run.

    The graph applies ``view_copy -> permute_copy -> add(2.0)`` to a constant
    tensor and then adds the runtime input ``x``. After running
    ``FuseConstantArgsPass`` the test asserts that no ``view_copy`` or
    ``permute_copy`` call remains, that exactly one ``add.Tensor`` call is
    left, and that executing the resulting graph module still matches the
    eager PyTorch computation.
    """
    view_op = exir_ops.edge.aten.view_copy.default
    permute_op = exir_ops.edge.aten.permute_copy.default
    add_op = exir_ops.edge.aten.add.Tensor

    weight_data = torch.arange(6, dtype=torch.float32).reshape(2, 3)
    x_data = torch.ones(2, 3)

    # Build: weight (constant) -> view -> permute -> add scalar -> add x.
    builder = ProgramBuilder()
    weight = builder.placeholder(
        "weight",
        weight_data,
        input_kind=InputKind.CONSTANT_TENSOR,
    )
    x = builder.placeholder("x", x_data)
    view = builder.call_operator(view_op, (weight, [3, 2]))
    permute = builder.call_operator(permute_op, (view, [1, 0]))
    const_add = builder.call_operator(add_op, (permute, 2.0))
    runtime_add = builder.call_operator(add_op, (const_add, x))
    builder.output([runtime_add])

    exported_program = builder.get_program()
    graph_module = exported_program.graph_module

    pass_result = FuseConstantArgsPass(exported_program)(graph_module)
    assert pass_result is not None

    fused_module = pass_result.graph_module
    call_targets = [
        node.target for node in fused_module.graph.nodes if node.op == "call_function"
    ]
    assert view_op not in call_targets
    assert permute_op not in call_targets
    assert call_targets.count(add_op) == 1

    # Collect one positional argument per placeholder, in graph order: the
    # runtime input for "x", otherwise the tensor registered under the
    # placeholder's name in the state dict or the program constants.
    graph_args = []
    for placeholder in fused_module.graph.nodes:
        if placeholder.op != "placeholder":
            continue
        if placeholder.name == "x":
            arg = x_data
        elif placeholder.name in exported_program.state_dict:
            arg = exported_program.state_dict[placeholder.name]
        else:
            arg = cast(torch.Tensor, exported_program.constants[placeholder.name])
        graph_args.append(arg)

    actual = fused_module(*graph_args)
    # The module may return its outputs wrapped in a list/tuple.
    if isinstance(actual, (list, tuple)):
        actual = actual[0]

    expected = weight_data.view(3, 2).permute(1, 0) + 2.0 + x_data
    torch.testing.assert_close(actual, expected)
342+
343+
283344
def test_fuse_constant_args_identifies_tosa_dialect_targets() -> None:
284345
class FakeTosaTarget:
285346
def __str__(self) -> str:

0 commit comments

Comments
 (0)