Commit afa732f

Author: Han Wang
fix(pt2): remove eliminate_dead_code from _strip_shape_assertions
Dead-code elimination after removing _assert_scalar nodes incorrectly removes intermediate computation nodes that share sub-expressions with the autograd gradient path, producing NaN forces for DPA1/se_atten in the NoPBC case. Remove the eliminate_dead_code() call; the leftover nodes are harmless unused scalar computations.
1 parent b0096e0 commit afa732f

1 file changed

Lines changed: 11 additions & 1 deletion

deepmd/pt_expt/utils/serialization.py

@@ -34,6 +34,17 @@ def _strip_shape_assertions(graph_module: torch.nn.Module) -> None:
     message content is not reliable. Since
     ``prefer_deferred_runtime_asserts_over_guards=True`` converts all shape
     guards into these deferred assertions, removing all of them is safe.
+
+    .. note::
+
+        We intentionally do **not** call ``graph.eliminate_dead_code()``
+        after removing assertion nodes. Dead-code elimination can remove
+        intermediate computation nodes that share sub-expressions with the
+        autograd gradient path (traced via ``torch.autograd.grad`` inside the
+        exported function). Removing those nodes produces NaN forces for
+        models like DPA1/se_atten in the NoPBC case. The leftover "dead"
+        nodes (computing the boolean condition for the removed assertions)
+        are harmless — they just compute unused scalar values.
     """
     graph = graph_module.graph
     for node in list(graph.nodes):
@@ -42,7 +53,6 @@ def _strip_shape_assertions(graph_module: torch.nn.Module) -> None:
             and node.target is torch.ops.aten._assert_scalar.default
         ):
             graph.erase_node(node)
-    graph.eliminate_dead_code()
     graph_module.recompile()
 
 

0 commit comments

Comments
 (0)