Commit 59973be

Han Wang committed
fix(pt2): neutralise shape assertions instead of erasing them
Erasing _assert_scalar nodes from the exported FX graph (and especially calling eliminate_dead_code afterwards) disturbs the graph structure and produces NaN gradients for DPA1/se_atten in the NoPBC case on some Python/torch versions. Replace each assertion's condition with True so the node stays in the graph but never fires at runtime. This preserves the graph topology and avoids the NaN issue across all tested configurations.
1 parent afa732f commit 59973be

1 file changed

Lines changed: 9 additions & 18 deletions

deepmd/pt_expt/utils/serialization.py

@@ -17,7 +17,7 @@
 
 
 def _strip_shape_assertions(graph_module: torch.nn.Module) -> None:
-    """Remove shape-guard assertion nodes from an exported graph.
+    """Neutralise shape-guard assertion nodes in an exported graph.
 
     ``torch.export`` inserts ``aten._assert_scalar`` nodes for symbolic shape
     relationships discovered during tracing. These guards can be spurious:
@@ -29,30 +29,21 @@ def _strip_shape_assertions(graph_module: torch.nn.Module) -> None:
     like ``Ne(nnei, sum(sel))``. These are spurious because the compiled
     graph handles any ``nnei >= sum(sel)`` correctly.
 
-    The assertion messages use opaque symbolic variable names (e.g.
-    ``Ne(s22, s96)``) rather than human-readable names, so filtering by
-    message content is not reliable. Since
-    ``prefer_deferred_runtime_asserts_over_guards=True`` converts all shape
-    guards into these deferred assertions, removing all of them is safe.
-
-    .. note::
-
-        We intentionally do **not** call ``graph.eliminate_dead_code()``
-        after removing assertion nodes. Dead-code elimination can remove
-        intermediate computation nodes that share sub-expressions with the
-        autograd gradient path (traced via ``torch.autograd.grad`` inside the
-        exported function). Removing those nodes produces NaN forces for
-        models like DPA1/se_atten in the NoPBC case. The leftover "dead"
-        nodes (computing the boolean condition for the removed assertions)
-        are harmless — they just compute unused scalar values.
+    Instead of erasing the assertion nodes (which can disturb the FX graph
+    structure and produce NaN gradients on some Python/torch versions), we
+    replace each assertion's condition with ``True`` so that the node stays
+    in the graph but never fires at runtime.
     """
     graph = graph_module.graph
     for node in list(graph.nodes):
         if (
             node.op == "call_function"
             and node.target is torch.ops.aten._assert_scalar.default
         ):
-            graph.erase_node(node)
+            # Replace the condition with True so the assertion always passes
+            # but the node stays in the graph. Erasing nodes can disturb the
+            # graph structure and produce NaN on some Python/torch versions.
+            node.args = (True, node.args[1])
     graph_module.recompile()
 
 