Skip to content

Commit 2daf56c

Browse files
committed
fix ut
1 parent 478cba6 commit 2daf56c

5 files changed

Lines changed: 164 additions & 130 deletions

File tree

deepmd/pt/model/descriptor/sezm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -952,7 +952,7 @@ def forward(
952952
extended_coord = extended_coord.to(self.compute_dtype)
953953
nf, nloc, nnei = nlist.shape
954954
nall = extended_coord.shape[1]
955-
n_nodes = int(nf * nloc)
955+
n_nodes = nf * nloc
956956
charge_spin = self._canonicalize_charge_spin(
957957
charge_spin,
958958
nf=nf,

deepmd/pt/model/descriptor/sezm_nn/edge_cache.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ def build_edge_cache(
290290
Per-edge cache.
291291
"""
292292
nf, nloc, nnei = nlist.shape
293-
n_nodes = int(nf * nloc)
293+
n_nodes = nf * nloc
294294

295295
# === Step 1. Force fp32+ for geometry ===
296296
geom_dtype = get_promoted_dtype(extended_coord.dtype)
@@ -492,10 +492,10 @@ def build_edge_cache_from_edges(
492492
edge_type_feat = edge_type_feat * edge_keep_f.to(dtype=edge_type_feat.dtype)
493493

494494
# === Step 6. Source Freeze Propagation Gate (optional) ===
495-
# The sparse-edge path packs one dummy masked edge per frame so the
496-
# compiled graph sees a statically non-empty tensor. ``edge_keep_f``
497-
# rewrites any such slot to ``w=1`` inside ``compute_edge_src_gate``,
498-
# keeping the product reduction unaffected by padding.
495+
# The sparse-edge path packs masked dummy edges so the compiled graph sees
496+
# a statically non-empty, non-singular edge tensor. ``edge_keep_f`` rewrites
497+
# any such slot to ``w=1`` inside ``compute_edge_src_gate``, keeping the
498+
# product reduction unaffected by padding.
499499
edge_src_gate: torch.Tensor | None = None
500500
if bridging_switch is not None:
501501
with nvtx_range("src_gate"):

deepmd/pt/model/model/sezm_model.py

Lines changed: 47 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -87,11 +87,12 @@
8787
function:
8888
8989
* ``core_compute`` rebuilds a compact, GPU-friendly edge list from the
90-
padded DeePMD neighbor list (``build_edge_list_from_nlist``), with a
91-
single masked dummy edge appended so the edge tensor is never empty
92-
(NOTE 10). Edge vectors come from ``index_select`` on the extended
93-
coordinate tensor, which keeps the gradient path back to coordinates
94-
explicit and safe under symbolic shapes (NOTE 11).
90+
padded DeePMD neighbor list (``build_edge_list_from_nlist``), with
91+
masked dummy edges appended so the edge tensor has a non-singular
92+
symbolic lower bound (NOTE 10). Edge vectors come from
93+
``index_select`` on the extended coordinate tensor, which keeps the
94+
gradient path back to coordinates explicit and safe under symbolic
95+
shapes (NOTE 11).
9596
* The SeZM descriptor consumes the edge list and produces per-atom
9697
features.
9798
* The fitting network predicts per-atom energy; ``apply_out_stat`` adds
@@ -240,24 +241,22 @@
240241
cudagraphs capture autograd metadata only once. Higher-order
241242
gradients need fresh metadata per call, so cudagraphs would feed
242243
stale autograd state into the second backward.
243-
* ``max_fusion_size`` -- mode-dependent
244+
* ``max_fusion_size=8``
244245
Caps kernel fusion complexity so Inductor's scheduler does not
245246
time out on the large edge-level reductions inside the
246-
descriptor when nsel is big. Training uses ``64`` (the long-
247-
standing default, observed stable on every training run so far);
248-
inference uses the tighter ``8`` to dodge the Triton lowering
249-
failure described by the next bullet.
250-
* ``triton.persistent_reductions=False`` -- inference only
247+
descriptor when nsel is big. The tighter value keeps both
248+
training and inference fusions small enough for Triton IR
249+
generation on GPU backends that are sensitive to large dynamic
250+
edge graphs.
251+
* ``triton.persistent_reductions=False``
251252
Inductor's persistent-reduction scheduler fuses a ``sum`` with
252253
*all* neighbouring pointwise ops (``tanh_backward``, ``pow``,
253254
``exp``, ``mul``, ``select``, ``slice``, ``view`` ...) into one
254-
``triton_per_fused_...`` kernel. On the graph emitted by
255-
inference (``create_graph=False``, no double-detach stripping,
256-
different fused topology than training) this kernel hits Triton
257-
bug ``PassManager::run failed`` inside ``make_ttgir``. Training
258-
never produces the same fused shape and does not benefit from
259-
disabling the optimisation, so the flag is left on for training
260-
to preserve kernel quality.
255+
``triton_per_fused_...`` kernel. On SeZM's dynamic edge graph
256+
this can hit Triton bug ``PassManager::run failed`` inside
257+
``make_ttgir``. Disabling it forces the reduction into its own
258+
kernel before either training or inference can form the
259+
pathological fused IR.
261260
* ``triton.mix_order_reduction=False``
262261
Workaround for PyTorch <=2.11 bugs pytorch/pytorch#174379,
263262
#178080, #179494. All three manifest only under data-dependent
@@ -324,17 +323,20 @@
324323
In eval mode we merely detach; no ``create_graph`` is requested, so the
325324
compiled kernel never has to build a backward graph.
326325
327-
NOTE 10 -- Tail dummy edge
328-
--------------------------
326+
NOTE 10 -- Tail dummy edges
327+
---------------------------
329328
330-
``build_edge_list_from_nlist`` appends exactly one masked edge at the
331-
end of every batch. Real edge compaction happens via
329+
``build_edge_list_from_nlist`` appends two masked edges at the end of
330+
every batch. Real edge compaction happens via
332331
``torch.nonzero(valid_mask)``, whose output length is data-dependent
333332
and can be zero in sparse or single-type systems. make_fx cannot trace
334333
an "if n_edges == 0: skip" branch symbolically; without the dummy it
335334
would fall back to concrete shape specialization and break
336-
``dynamic=True``. The dummy's ``edge_mask`` is ``False`` so it
337-
contributes exactly zero to every downstream sum or gather.
335+
``dynamic=True``. A pair of dummy slots also gives Inductor's batched
336+
matmul lowering a static ``E >= 2`` edge-axis bound, avoiding
337+
data-dependent layout guards on ``E == 1``. Each dummy's ``edge_mask``
338+
is ``False`` so it contributes exactly zero to every downstream sum or
339+
gather.
338340
339341
NOTE 11 -- ``index_select`` for coordinate gradients
340342
----------------------------------------------------
@@ -1690,44 +1692,23 @@ def compute_fn(
16901692
# fresh graph is cheap and a segfault is fatal.
16911693
traced = _rebuild_graph_module(traced)
16921694

1693-
# NOTE: Inductor options are mode-dependent. Training has been
1694-
# running cleanly with ``max_fusion_size=64`` for a while, so we
1695-
# keep that path untouched to avoid destabilising it. Inference
1696-
# (``self.training is False``) has shown a Triton
1697-
# ``make_ttgir`` / ``PassManager::run failed`` on the fused
1698-
# per-reduction kernel
1699-
# ``triton_per_fused_clone_exp_mul_pow_select_slice_sum_tanh_...``;
1700-
# the kernel itself is fine, but the *fused* IR is too big /
1701-
# too complex for Triton's lowering pipeline on this version.
1702-
# So inference:
1703-
# * disables ``triton.persistent_reductions`` -- persistent
1704-
# reduction is what lets Inductor pull a ``sum`` together
1705-
# with all surrounding pointwise ops (including the
1706-
# activation-backward pointwise chain) into one
1707-
# ``per_fused_...`` kernel; turning it off forces the sum
1708-
# to emit its own kernel and stops the pathological fuse.
1709-
# * tightens ``max_fusion_size`` from 64 to 8, so even
1710-
# non-persistent fusions stay small enough for Triton IR
1711-
# generation to succeed.
1712-
# Training does not hit this path in practice (different graph
1713-
# topology under ``create_graph=True``), so we keep the looser
1714-
# options there to preserve kernel quality.
1695+
# NOTE: Conservative Inductor options keep SeZM's dynamic edge
1696+
# graph from forming overly large Triton reduction kernels
1697+
# (``make_ttgir`` / ``PassManager::run failed``) on some
1698+
# GPU/Triton combinations.
17151699
compile_options: dict[str, Any] = {
17161700
"max_autotune": False,
17171701
"shape_padding": True,
17181702
"epilogue_fusion": False,
17191703
"triton.cudagraphs": False,
1704+
"max_fusion_size": 8,
1705+
"triton.persistent_reductions": False,
17201706
# NOTE: ``mix_order_reduction`` hits multiple bugs under
17211707
# data-dependent symbolic shapes on PyTorch <=2.11
17221708
# (pytorch/pytorch#174379, #178080, #179494) -- our edge
17231709
# count is exactly that kind of shape.
17241710
"triton.mix_order_reduction": False,
17251711
}
1726-
if self.training:
1727-
compile_options["max_fusion_size"] = 64
1728-
else:
1729-
compile_options["max_fusion_size"] = 8
1730-
compile_options["triton.persistent_reductions"] = False
17311712
try:
17321713
from torch._inductor import config as inductor_config
17331714

@@ -1979,9 +1960,10 @@ def build_edge_list_from_nlist(
19791960
Build a compact edge list from DeePMD padded neighbor list.
19801961
19811962
Edge vectors are computed via ``index_select`` on ``extended_coord``
1982-
so they remain differentiable w.r.t. the input coordinates. One
1983-
masked dummy edge is always appended to avoid data-dependent empty-edge
1984-
branches that ``make_fx`` cannot trace.
1963+
so they remain differentiable w.r.t. the input coordinates. Two
1964+
masked dummy edges are always appended to avoid data-dependent empty-edge
1965+
branches that ``make_fx`` cannot trace and singular edge-axis guards
1966+
in Inductor's batched matmul lowering.
19851967
19861968
Parameters
19871969
----------
@@ -1995,11 +1977,11 @@ def build_edge_list_from_nlist(
19951977
Returns
19961978
-------
19971979
edge_index
1998-
Edge indices with shape (2, E+1) where E is valid edge count.
1980+
Edge indices with shape (2, E+2) where E is valid edge count.
19991981
edge_vec
2000-
Edge vectors with shape (E+1, 3).
1982+
Edge vectors with shape (E+2, 3).
20011983
edge_mask
2002-
Boolean mask with shape (E+1,). The trailing element is ``False``.
1984+
Boolean mask with shape (E+2). The trailing elements are ``False``.
20031985
"""
20041986
nf, nloc, nsel = nlist.shape
20051987
n_actual = nf * nloc
@@ -2051,19 +2033,22 @@ def build_edge_list_from_nlist(
20512033

20522034
valid_idx = torch.nonzero(edge_mask_actual, as_tuple=False).flatten()
20532035

2054-
# === Step 3. Compact edges + append one masked dummy ===
2055-
# NOTE: Always append exactly one masked dummy edge.
2036+
# === Step 3. Compact edges + append masked dummies ===
2037+
# NOTE: Always append two masked dummy edges.
20562038
# ``torch.nonzero(edge_mask_actual)`` produces a data-dependent
20572039
# number of valid edges, which can be zero on sparse or
20582040
# single-type systems. make_fx cannot trace an
20592041
# ``if n_edges == 0: skip`` branch symbolically; without the
20602042
# dummy it would fall back to concrete shape specialisation and
2061-
# break ``torch.compile(dynamic=True)`` for later batches. The
2043+
# break ``torch.compile(dynamic=True)`` for later batches. Two
2044+
# dummy edges keep the symbolic edge axis statically above one,
2045+
# which avoids Inductor bmm layout guards on ``E == 1``. Each
20622046
# dummy edge copies entry 0 (any in-range index is fine) and
20632047
# carries ``edge_mask=False`` so every downstream sum, gather
20642048
# or scatter ignores it.
2049+
dummy_count = 2
20652050
padded_idx = torch.cat(
2066-
[valid_idx, torch.zeros(1, dtype=torch.long, device=device)]
2051+
[valid_idx, torch.zeros(dummy_count, dtype=torch.long, device=device)]
20672052
)
20682053
src_sel = src_actual.index_select(0, padded_idx)
20692054
dst_sel = dst_actual.index_select(0, padded_idx)
@@ -2072,7 +2057,7 @@ def build_edge_list_from_nlist(
20722057
edge_mask = torch.cat(
20732058
[
20742059
torch.ones(valid_idx.shape[0], dtype=torch.bool, device=device),
2075-
torch.zeros(1, dtype=torch.bool, device=device),
2060+
torch.zeros(dummy_count, dtype=torch.bool, device=device),
20762061
]
20772062
)
20782063
return edge_index, edge_vec_sel, edge_mask

0 commit comments

Comments
 (0)