Skip to content

Commit a9c9746

Browse files
Andrew Pullin authored and facebook-github-bot committed
Minor speedup for model lowering: Skip redundant run_decompositions when no ops match decomp table (#18496)
Summary:
Adds an early-exit check to `_gen_edge_manager_for_partitioners`: before calling `program.run_decompositions(table)`, scan the graph for ops that appear in the decomposition table. If none are found, skip the call entirely.

Each `run_decompositions` call performs a full re-export of the program via `make_fx()`, re-tracing every node through FakeTensor dispatch. On the EDGE_DO_NOT_DECOMP path this function is called up to 3 times; the early exit eliminates at least one redundant call where the previous pass already decomposed all matching ops.

The check recursively walks control flow submodules (cond/map/scan) to avoid incorrectly skipping when decomposable ops are nested.

## Benchmark

Model: small CNN feature extractor (~50K params, 9 conv layers with LayerNorm, targeting Ethos-U55 via the ARM/TOSA lowering pipeline).
Graph: ~1200 nodes.

lower() before: 82 s
lower() after: 71 s
Delta: -11 s (-13 %)

Differential Revision: D96489903
1 parent 9576316 commit a9c9746

1 file changed

Lines changed: 36 additions & 5 deletions

File tree

exir/program/_program.py

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1163,7 +1163,33 @@ def _can_skip_using_EDGE_DO_NOT_DECOMP(
11631163
return check_op_support is None
11641164

11651165

1166-
def _gen_edge_manager_for_partitioners(
1166+
def _has_decomposable_ops(
    program: "ExportedProgram",
    decomp_table: dict,
) -> bool:
    """Report whether running decompositions with ``decomp_table`` could change the graph.

    The program's graph module is scanned for ``call_function`` nodes whose
    target is a key of ``decomp_table``. Control flow submodules (as yielded
    by ``get_control_flow_submodules``) are searched recursively so that a
    decomposable op nested inside one of them is not missed.

    Args:
        program: The exported program whose ``graph_module`` is inspected.
        decomp_table: Decomposition table; only key membership is consulted.

    Returns:
        True if at least one matching op is found, or if ``decomp_table`` is
        empty. An empty table is the functionalization-only path, and there is
        no cheap way to tell whether the graph still needs functionalization,
        so the caller must not skip in that case. False otherwise.
    """
    # Empty table: be conservative and tell the caller to run decompositions.
    if not decomp_table:
        return True

    def _contains_decomposable(module: torch.fx.GraphModule) -> bool:
        # Check this module's own nodes first so a direct hit returns
        # before any submodule lookup happens.
        direct_hit = any(
            candidate.op == "call_function" and candidate.target in decomp_table
            for candidate in module.graph.nodes
        )
        if direct_hit:
            return True
        # Otherwise descend into each control flow submodule in order.
        return any(
            _contains_decomposable(child)
            for _, child, _ in get_control_flow_submodules(module)
        )

    return _contains_decomposable(program.graph_module)
1190+
1191+
1192+
def _gen_edge_manager_for_partitioners( # noqa: C901
11671193
partitioner: Dict[str, List[Partitioner]],
11681194
aten_programs: Dict[str, ExportedProgram],
11691195
config: EdgeCompileConfig,
@@ -1198,7 +1224,8 @@ def _gen_edge_manager_for_partitioners(
11981224
table = _default_decomposition_table()
11991225
for op in config.preserve_ops:
12001226
table.pop(op, None)
1201-
program = program.run_decompositions(table)
1227+
if _has_decomposable_ops(program, table):
1228+
program = program.run_decompositions(table)
12021229

12031230
# Process each partitioner individually using their specific requirements
12041231
for curr_partitioner in partitioners_for_program:
@@ -1218,7 +1245,8 @@ def _gen_edge_manager_for_partitioners(
12181245
if table.pop(op, None) is not None:
12191246
ops_needing_preservation.append(op)
12201247

1221-
program = program.run_decompositions(table)
1248+
if _has_decomposable_ops(program, table):
1249+
program = program.run_decompositions(table)
12221250
final_ops_to_preserve.update(ops_needing_preservation)
12231251
else:
12241252
# EDGE_DO_NOT_DECOMP path for the partitioner
@@ -1232,7 +1260,8 @@ def _gen_edge_manager_for_partitioners(
12321260
table.pop(op, None)
12331261

12341262
# First pass of decompositions with this partitioner's preserved ops
1235-
program = program.run_decompositions(table)
1263+
if _has_decomposable_ops(program, table):
1264+
program = program.run_decompositions(table)
12361265

12371266
# Filter ops using EDGE_DO_NOT_DECOMP
12381267
temp_partitioner_dict = {name: [curr_partitioner]}
@@ -1245,7 +1274,9 @@ def _gen_edge_manager_for_partitioners(
12451274
final_ops_to_preserve.update(preserved_ops)
12461275

12471276
# Second pass of decompositions with this partitioner's preserved ops after filtering
1248-
program = program.run_decompositions(_default_decomposition_table())
1277+
full_table = _default_decomposition_table()
1278+
if _has_decomposable_ops(program, full_table):
1279+
program = program.run_decompositions(full_table)
12491280

12501281
# Restore ops from edge_no_decomp_namespace to aten ops
12511282
_restore_transformed_ops_to_aten_ops(program)

0 commit comments

Comments
 (0)