fix(pt): replace in-place nlist masking to enable CUDA graph capture#5433
Conversation
`nlist[nlist == -1] = 0` in DescrptBlockRepformers.forward and DescrptBlockRepflows.forward is lowered by TorchScript into an `index_put_` whose value is built by `torch.tensor(0, device=...)` on every forward pass. The per-call scalar allocation triggers a CPU->GPU sync, which is forbidden during CUDA stream capture (`cudaErrorStreamCaptureUnsupported`), making `forward_lower` of DPA-2 and DPA-3 models non-capturable from the LAMMPS C++ plugin. Switch to `nlist = torch.where(nlist == -1, 0, nlist)`, matching the pattern already used in se_atten.py and se_t_tebd.py. `aten::where` with a Python scalar takes the literal through the kernel without allocating a device tensor, so the resulting frozen IR is CUDA-graph capturable. `nlist`/`a_nlist` were already rebound to local copies earlier in both functions, so dropping the in-place mutation has no observable effect on callers. Closes deepmodeling#5432.
📝 WalkthroughWalkthroughPadding sentinel handling in RepFlow and Repformers descriptor forward methods is refactored from in-place index assignment to vectorized ChangesPadding Sentinel Replacement in Descriptors
Sequence DiagramsequenceDiagram
participant Caller as DPA-2 Model
participant Descriptor as Descriptor Forward
participant OldPath as Old: torch.tensor() + index_put_()
participant NewPath as New: torch.where()
participant CUDA as CUDA Graph Capture
Caller->>Descriptor: forward_lower()
Descriptor->>NewPath: Process nlist with torch.where()
NewPath->>NewPath: Replace -1 → 0 (no dynamic allocation)
NewPath-->>CUDA: Compatible with stream capture
CUDA-->>Caller: Graph cached & replayed
Note over OldPath: Previous: torch.tensor(0, device='cuda')<br/>caused cudaErrorStreamCaptureUnsupported<br/>on every forward pass
Estimated code review effort🎯 2 (Simple) | ⏱️ ~8 minutes 🚥 Pre-merge checks | ✅ 5✅ Passed checks (5 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
🧹 Nitpick comments (1)
deepmd/pt/model/descriptor/repflows.py (1)
619-637: Remaining per-calltorch.tensor()allocations in thecomm_dictpath.The
border_opcall sites (lines 627–636) still allocate new CPU tensors (torch.tensor(real_nloc, dtype=torch.int32, device="cpu")andtorch.tensor(real_nall - real_nloc, ...)) on every forward pass. These live on CPU, so they don't producecudaErrorStreamCaptureUnsupportedtoday, but they are still per-call host allocations inside the training loop. If CUDA graph capture of the parallel/comm_dictpath is ever needed, these will be the next sync point to address (e.g., pre-allocating atorch.zeros(1, dtype=torch.int32)buffer and filling it once, or passing Python ints directly if the op accepts them).
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 2fd617c5-c738-4cdc-be39-55e791b6e3d9
📒 Files selected for processing (2)
deepmd/pt/model/descriptor/repflows.pydeepmd/pt/model/descriptor/repformers.py
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## master #5433 +/- ##
=======================================
Coverage 82.47% 82.47%
=======================================
Files 825 825
Lines 87721 87721
Branches 4206 4207 +1
=======================================
Hits 72344 72344
- Misses 14093 14094 +1
+ Partials 1284 1283 -1 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
njzjz-bot
left a comment
There was a problem hiding this comment.
LGTM. I checked the diff: this only replaces in-place boolean-index assignment on nlist/a_nlist with torch.where, which avoids mutating tensors during CUDA graph capture while preserving the effective masking behavior before downstream descriptor use. No compatibility concern from my side.
Note: I could not run local uv run ruff because the current uv dependency resolution fails on mutually pinned torch CPU/GPU extras in this environment, so this approval is based on diff inspection and the PR CI status.
Reviewed by OpenClaw 2026.4.22 (00bd2cf) (model: gpt-5.5).
Authored by OpenClaw (model: gpt-5.5)
|
I wonder whether this bug is related to #5087, since I encountered the same problem with DPA3 |
|
Thank you for the suggested changes! I applied the patches to repformers.py and reflow.py, The CUDA graph model loads and runs correctly — libcudart.so.12 loads successfully and the run Performance on a small test system (~700 atoms, L40S GPU):
No speedup at this system size, which is expected — at ~700 atoms the DPA-2 attention kernels are One note for others: verify nsel matches your actual neighbour count at your system density |
Summary
nlist[nlist == -1] = 0inDescrptBlockRepformers.forward(DPA-2) andDescrptBlockRepflows.forward(DPA-3, two occurrences) withnlist = torch.where(nlist == -1, 0, nlist).index_put_whose value wastorch.tensor(0, device=...)— a per-call scalar allocation that triggers a CPU↔GPU sync and is forbidden during CUDA stream capture (cudaErrorStreamCaptureUnsupported). This madeforward_lowerof DPA-2 / DPA-3 models non-capturable from the LAMMPS C++ plugin, blocking the ~8–12× CUDA-graph speedup reported in [Feature Request]torch.tensor()insiderepformers.forward()prevents CUDA graph capture for DPA-2 models #5432.aten::wherewith a Python scalar embeds the literal in the kernel without allocating a device tensor, so the resulting frozen IR is graph-capturable. The pattern matches whatse_atten.pyandse_t_tebd.pyalready use.Notes
nlist/a_nlistwere already rebound to local copies viatorch.where(exclude_mask != 0, ...)earlier in both functions, so dropping the in-place mutation has no observable effect on callers..pthfiles won't benefit until re-exported (their TorchScript IR is frozen).comm_dictghost-atom path inforward_lowermay still contain other CPU↔GPU sync points needed for full LAMMPS capture; that audit is out of scope here.Closes #5432.
Test plan
pytest source/tests/pt/model/test_dpa3.py— passes (exercisestorch.jit.scripton the DPA-3 descriptor → covers patchedrepflows.pylines)pytest source/tests/pt/model/test_dpa2.py— passes (numerical consistency of DPA-2 descriptor → covers patchedrepformers.pyline)torch.jit.script(DescrptDPA2(...))succeeds (verifies newtorch.where(... == -1, 0, nlist)form scripts cleanly)test_jit.pyand consistency suitesSummary by CodeRabbit