Skip to content

Commit f834202

Browse files
author
Han Wang
committed
fix(pt_expt): rebuild FX graph after detach node removal to avoid segfaults
After Graph.erase_node(), stale C-level prev/next pointers may remain on neighbouring Node objects. Dynamo re-tracing can dereference them and segfault. Rebuild the module into a fresh graph to eliminate the stale pointers.
1 parent bacd312 commit f834202

1 file changed

Lines changed: 20 additions & 0 deletions

File tree

deepmd/pt_expt/train/training.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,23 @@ def _remove_detach_nodes(gm: torch.fx.GraphModule) -> None:
164164
gm.recompile()
165165

166166

167+
def _rebuild_graph_module(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
168+
"""Copy all nodes into a fresh ``torch.fx.Graph``.
169+
170+
After ``Graph.erase_node()`` the C-level prev/next pointers on
171+
neighbouring ``Node`` objects may become stale. When ``torch.compile``
172+
(dynamo) later re-traces the graph it walks these pointers, which can
173+
cause segfaults. Rebuilding into a new graph eliminates stale pointers.
174+
"""
175+
old_graph = gm.graph
176+
new_graph = torch.fx.Graph()
177+
val_map: dict[torch.fx.Node, torch.fx.Node] = {}
178+
for node in old_graph.nodes:
179+
val_map[node] = new_graph.node_copy(node, lambda n: val_map[n])
180+
new_graph.lint()
181+
return torch.fx.GraphModule(gm, new_graph)
182+
183+
167184
def _trace_and_compile(
168185
model: torch.nn.Module,
169186
ext_coord: torch.Tensor,
@@ -272,6 +289,9 @@ def _expand(t: torch.Tensor | None) -> torch.Tensor | None:
272289
# second-order gradient flow (d(force)/d(params) for force training).
273290
# Removing them restores correct higher-order derivatives.
274291
_remove_detach_nodes(traced_lower)
292+
# Rebuild into a fresh graph to eliminate stale C-level node pointers
293+
# left by erase_node(), which can cause segfaults during dynamo re-trace.
294+
traced_lower = _rebuild_graph_module(traced_lower)
275295

276296
if not was_training:
277297
model.eval()

0 commit comments

Comments
 (0)