Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 36 additions & 5 deletions exir/program/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -1163,7 +1163,33 @@ def _can_skip_using_EDGE_DO_NOT_DECOMP(
return check_op_support is None


def _gen_edge_manager_for_partitioners(
def _has_decomposable_ops(
program: "ExportedProgram",
decomp_table: dict,
) -> bool:
"""Check if any ops in the graph match the decomposition table.

Returns True if the graph contains at least one op that appears in the
decomposition table, meaning run_decompositions would actually decompose
something. Returns True for empty tables (functionalization-only path)
since we can't cheaply determine if the graph needs functionalization.
"""
if not decomp_table:
return True # empty table = functionalize, can't skip cheaply

def _graph_has_match(gm: torch.fx.GraphModule) -> bool:
for node in gm.graph.nodes:
if node.op == "call_function" and node.target in decomp_table:
return True
for _, submod, _ in get_control_flow_submodules(gm):
if _graph_has_match(submod):
return True
return False

return _graph_has_match(program.graph_module)


def _gen_edge_manager_for_partitioners( # noqa: C901
partitioner: Dict[str, List[Partitioner]],
aten_programs: Dict[str, ExportedProgram],
config: EdgeCompileConfig,
Expand Down Expand Up @@ -1198,7 +1224,8 @@ def _gen_edge_manager_for_partitioners(
table = _default_decomposition_table()
for op in config.preserve_ops:
table.pop(op, None)
program = program.run_decompositions(table)
if _has_decomposable_ops(program, table):
program = program.run_decompositions(table)

# Process each partitioner individually using their specific requirements
for curr_partitioner in partitioners_for_program:
Expand All @@ -1218,7 +1245,8 @@ def _gen_edge_manager_for_partitioners(
if table.pop(op, None) is not None:
ops_needing_preservation.append(op)

program = program.run_decompositions(table)
if _has_decomposable_ops(program, table):
program = program.run_decompositions(table)
final_ops_to_preserve.update(ops_needing_preservation)
else:
# EDGE_DO_NOT_DECOMP path for the partitioner
Expand All @@ -1232,7 +1260,8 @@ def _gen_edge_manager_for_partitioners(
table.pop(op, None)

# First pass of decompositions with this partitioner's preserved ops
program = program.run_decompositions(table)
if _has_decomposable_ops(program, table):
program = program.run_decompositions(table)

# Filter ops using EDGE_DO_NOT_DECOMP
temp_partitioner_dict = {name: [curr_partitioner]}
Expand All @@ -1245,7 +1274,9 @@ def _gen_edge_manager_for_partitioners(
final_ops_to_preserve.update(preserved_ops)

# Second pass of decompositions with this partitioner's preserved ops after filtering
program = program.run_decompositions(_default_decomposition_table())
full_table = _default_decomposition_table()
if _has_decomposable_ops(program, full_table):
program = program.run_decompositions(full_table)

# Restore ops from edge_no_decomp_namespace to aten ops
_restore_transformed_ops_to_aten_ops(program)
Expand Down
Loading