Commit 5bd0889
fix(pt): replace in-place nlist masking to enable CUDA graph capture (#5433)
## Summary
- Replace `nlist[nlist == -1] = 0` in `DescrptBlockRepformers.forward`
(DPA-2) and `DescrptBlockRepflows.forward` (DPA-3, two occurrences) with
`nlist = torch.where(nlist == -1, 0, nlist)`.
- TorchScript was lowering the indexed-assignment into an `index_put_`
whose value was `torch.tensor(0, device=...)` — a per-call scalar
allocation that triggers a CPU↔GPU sync and is forbidden during CUDA
stream capture (`cudaErrorStreamCaptureUnsupported`). This made
`forward_lower` of DPA-2 / DPA-3 models non-capturable from the LAMMPS
C++ plugin, blocking the ~8–12× CUDA-graph speedup reported in #5432.
- `aten::where` with a Python scalar embeds the literal in the kernel
without allocating a device tensor, so the resulting frozen IR is
graph-capturable. The pattern matches what `se_atten.py` and
`se_t_tebd.py` already use.
## Notes
- `nlist` / `a_nlist` were already rebound to local copies via
`torch.where(exclude_mask != 0, ...)` earlier in both functions, so
dropping the in-place mutation has no observable effect on callers.
- Existing deployed `.pth` files won't benefit until re-exported (their
TorchScript IR is frozen).
- The reporter also flagged that the `comm_dict` ghost-atom path in
`forward_lower` may still contain other CPU↔GPU sync points needed for
full LAMMPS capture; that audit is out of scope here.
Closes #5432.
## Test plan
- [x] `pytest source/tests/pt/model/test_dpa3.py` — passes (exercises
`torch.jit.script` on the DPA-3 descriptor → covers patched
`repflows.py` lines)
- [x] `pytest source/tests/pt/model/test_dpa2.py` — passes (numerical
consistency of DPA-2 descriptor → covers patched `repformers.py` line)
- [x] Inline `torch.jit.script(DescrptDPA2(...))` succeeds (verifies new
`torch.where(... == -1, 0, nlist)` form scripts cleanly)
- [ ] CI: full `test_jit.py` and consistency suites
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit
* **Refactor**
* Optimized internal tensor operations in descriptor modules for
improved performance and consistency.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
Co-authored-by: Han Wang <wang_han@iapcm.ac.cn>1 parent 0a481de commit 5bd0889
2 files changed
Lines changed: 3 additions & 3 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
496 | 496 | | |
497 | 497 | | |
498 | 498 | | |
499 | | - | |
500 | | - | |
| 499 | + | |
| 500 | + | |
501 | 501 | | |
502 | 502 | | |
503 | 503 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
457 | 457 | | |
458 | 458 | | |
459 | 459 | | |
460 | | - | |
| 460 | + | |
461 | 461 | | |
462 | 462 | | |
463 | 463 | | |
| |||
0 commit comments