@@ -969,6 +969,21 @@ def _unsafe_adjust_original_program( # noqa: C901
969969 Directly modify the original exported program's signature and state dict
970970 based on the consumed params/buffers in the delegate.
971971 """
972+ # First pass: identify placeholder nodes that still have users in the graph.
973+ # These cannot be deleted because they are shared between the delegate and
974+ # the remaining program (e.g., due to identity ops like no-op dropout
975+ # causing parameter aliasing across partitions).
976+ nodes_to_keep = set ()
977+ for node in original_program .graph .nodes :
978+ if node .op == "placeholder" :
979+ if node .name in input_specs_to_delete and len (node .users ) > 0 :
980+ nodes_to_keep .add (node .name )
981+ else :
982+ break
983+
984+ for name in nodes_to_keep :
985+ del input_specs_to_delete [name ]
986+
972987 original_program ._graph_signature .input_specs = [
973988 input_spec
974989 for input_spec in original_program .graph_signature .input_specs
@@ -1005,14 +1020,14 @@ def _unsafe_adjust_original_program( # noqa: C901
10051020 continue
10061021
10071022 if input_spec .kind == InputKind .PARAMETER :
1008- del original_program ._state_dict [ input_target ]
1023+ original_program ._state_dict . pop ( input_target , None )
10091024 elif input_spec .kind == InputKind .BUFFER :
10101025 if input_spec .persistent :
10111026 original_program ._state_dict .pop (input_target , None )
10121027 else :
1013- del original_program ._constants [ input_spec .target ]
1028+ original_program ._constants . pop ( input_spec .target , None )
10141029 elif input_spec .kind == InputKind .CONSTANT_TENSOR :
1015- del original_program ._constants [ input_spec .target ]
1030+ original_program ._constants . pop ( input_spec .target , None )
10161031 else :
10171032 raise RuntimeError (f"Invalid input spec { input_spec } received" )
10181033
0 commit comments