Skip to content

Commit 5bd0889

Browse files
wanghan-iapcmHan Wang
andauthored
fix(pt): replace in-place nlist masking to enable CUDA graph capture (#5433)
## Summary - Replace `nlist[nlist == -1] = 0` in `DescrptBlockRepformers.forward` (DPA-2) and `DescrptBlockRepflows.forward` (DPA-3, two occurrences) with `nlist = torch.where(nlist == -1, 0, nlist)`. - TorchScript was lowering the indexed-assignment into an `index_put_` whose value was `torch.tensor(0, device=...)` — a per-call scalar allocation that triggers a CPU↔GPU sync and is forbidden during CUDA stream capture (`cudaErrorStreamCaptureUnsupported`). This made `forward_lower` of DPA-2 / DPA-3 models non-capturable from the LAMMPS C++ plugin, blocking the ~8–12× CUDA-graph speedup reported in #5432. - `aten::where` with a Python scalar embeds the literal in the kernel without allocating a device tensor, so the resulting frozen IR is graph-capturable. The pattern matches what `se_atten.py` and `se_t_tebd.py` already use. ## Notes - `nlist` / `a_nlist` were already rebound to local copies via `torch.where(exclude_mask != 0, ...)` earlier in both functions, so dropping the in-place mutation has no observable effect on callers. - Existing deployed `.pth` files won't benefit until re-exported (their TorchScript IR is frozen). - The reporter also flagged that the `comm_dict` ghost-atom path in `forward_lower` may still contain other CPU↔GPU sync points needed for full LAMMPS capture; that audit is out of scope here. Closes #5432. ## Test plan - [x] `pytest source/tests/pt/model/test_dpa3.py` — passes (exercises `torch.jit.script` on the DPA-3 descriptor → covers patched `repflows.py` lines) - [x] `pytest source/tests/pt/model/test_dpa2.py` — passes (numerical consistency of DPA-2 descriptor → covers patched `repformers.py` line) - [x] Inline `torch.jit.script(DescrptDPA2(...))` succeeds (verifies new `torch.where(... == -1, 0, nlist)` form scripts cleanly) - [ ] CI: full `test_jit.py` and consistency suites <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Refactor** * Optimized internal tensor operations in descriptor modules for improved performance and consistency. <!-- end of auto-generated comment: release notes by coderabbit.ai --> Co-authored-by: Han Wang <wang_han@iapcm.ac.cn>
1 parent 0a481de commit 5bd0889

2 files changed

Lines changed: 3 additions & 3 deletions

File tree

deepmd/pt/model/descriptor/repflows.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -496,8 +496,8 @@ def forward(
496496
a_sw = a_sw.masked_fill(~a_nlist_mask, 0.0)
497497
# set all padding positions to index of 0
498498
# if the a neighbor is real or not is indicated by nlist_mask
499-
nlist[nlist == -1] = 0
500-
a_nlist[a_nlist == -1] = 0
499+
nlist = torch.where(nlist == -1, 0, nlist)
500+
a_nlist = torch.where(a_nlist == -1, 0, a_nlist)
501501

502502
# get node embedding
503503
# [nframes, nloc, tebd_dim]

deepmd/pt/model/descriptor/repformers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -457,7 +457,7 @@ def forward(
457457

458458
# set all padding positions to index of 0
459459
# if the a neighbor is real or not is indicated by nlist_mask
460-
nlist[nlist == -1] = 0
460+
nlist = torch.where(nlist == -1, 0, nlist)
461461
# nb x nall x ng1
462462
if comm_dict is None:
463463
assert mapping is not None

0 commit comments

Comments
 (0)