Skip to content

Commit 1621fa2

Browse files
authored
Fix broken fbcode//executorch/backends/apple/mps:test - test_emformer (execut (#20432)
Differential Revision: D109012812 Pull Request resolved: #20432
1 parent edc61ce commit 1621fa2

1 file changed

Lines changed: 18 additions & 3 deletions

File tree

exir/lowered_backend_module.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -969,6 +969,21 @@ def _unsafe_adjust_original_program( # noqa: C901
969969
Directly modify the original exported program's signature and state dict
970970
based on the consumed params/buffers in the delegate.
971971
"""
972+
# First pass: identify placeholder nodes that still have users in the graph.
973+
# These cannot be deleted because they are shared between the delegate and
974+
# the remaining program (e.g., due to identity ops like no-op dropout
975+
# causing parameter aliasing across partitions).
976+
nodes_to_keep = set()
977+
for node in original_program.graph.nodes:
978+
if node.op == "placeholder":
979+
if node.name in input_specs_to_delete and len(node.users) > 0:
980+
nodes_to_keep.add(node.name)
981+
else:
982+
break
983+
984+
for name in nodes_to_keep:
985+
del input_specs_to_delete[name]
986+
972987
original_program._graph_signature.input_specs = [
973988
input_spec
974989
for input_spec in original_program.graph_signature.input_specs
@@ -1005,14 +1020,14 @@ def _unsafe_adjust_original_program( # noqa: C901
10051020
continue
10061021

10071022
if input_spec.kind == InputKind.PARAMETER:
1008-
del original_program._state_dict[input_target]
1023+
original_program._state_dict.pop(input_target, None)
10091024
elif input_spec.kind == InputKind.BUFFER:
10101025
if input_spec.persistent:
10111026
original_program._state_dict.pop(input_target, None)
10121027
else:
1013-
del original_program._constants[input_spec.target]
1028+
original_program._constants.pop(input_spec.target, None)
10141029
elif input_spec.kind == InputKind.CONSTANT_TENSOR:
1015-
del original_program._constants[input_spec.target]
1030+
original_program._constants.pop(input_spec.target, None)
10161031
else:
10171032
raise RuntimeError(f"Invalid input spec {input_spec} received")
10181033

0 commit comments

Comments
 (0)