Commit d14233e
perf(pt2): optimize .pt2 C++ inference path (#5407)
## Summary
- Replace CPU-side `buildTypeSortedNlist` with `createNlistTensor(data,
expected_nnei)` — avoids distance computation and type sorting every
step; model's compiled `format_nlist` handles this on-device
- Export with `do_atomic_virial=False` by default — avoids 3 extra
`torch.autograd.grad` backward passes; add `--atomic-virial` flag to `dp
convert-backend`
- Cache `mapping_tensor` as member variable, only rebuild when `ago ==
0`
- Store `nnei` and `do_atomic_virial` in .pt2 metadata for C++ to read
at init
- Make nnei dynamic in torch.export — compiled graph accepts
variable-size neighbor lists via internal padding + sort branch
## Benchmark (V100-SXM2-16GB, 192-atom water, LAMMPS MD)
### Before this PR
.pt2 was **9x slower** than .pth due to CPU-side nlist sorting, baked-in
atomic virial backward passes, and excessive clones:
| Atoms | .pth (ms/step) | .pt2 (ms/step) | .pt2/.pth |
|------:|---------------:|---------------:|:---------:|
| 192 | 11 | 97 | **8.8x** |
### After this PR
#### DPA1 L0 (se_atten nlayer=0)
| Atoms | .pth (ms) | .pt2 (ms) | .pt2/.pth |
|------:|----------:|----------:|:---------:|
| 192 | 5.60 | 4.93 | **0.88x** |
| 384 | 6.69 | 8.45 | **1.26x** |
| 768 | 10.9 | 16.3 | **1.49x** |
| 1536 | 19.3 | 31.2 | **1.62x** |
| 3072 | 36.7 | 58.8 | **1.60x** |
| 6144 | 72.0 | 116 | **1.62x** |
| 12288 | 140 | 229 | **1.63x** |
#### DPA1 L2 (se_atten nlayer=2)
| Atoms | .pth (ms) | .pt2 (ms) | .pt2/.pth |
|------:|----------:|----------:|:---------:|
| 192 | 13.0 | 9.17 | **0.71x** |
| 384 | 22.2 | 16.2 | **0.73x** |
| 768 | 41.0 | 30.4 | **0.74x** |
| 1536 | 77.8 | 58.8 | **0.76x** |
#### DPA2 (repinit + repformer)
| Atoms | .pth (ms) | .pt2 (ms) | .pt2/.pth |
|------:|----------:|----------:|:---------:|
| 192 | 28.5 | 15.6 | **0.55x** |
| 384 | 34.6 | 28.2 | **0.81x** |
| 768 | 60.5 | 53.4 | **0.88x** |
| 1536 | 112.9 | 104 | **0.92x** |
For models with more compute (DPA1 L2, DPA2), .pt2 is **24-45% faster**
than .pth. For the smallest model (DPA1 L0), .pt2 has higher per-call
overhead that dominates at large atom counts.
## Test plan
- [x] All Python export/make_fx tests pass (74 tests)
- [x] All Python model tests pass
- [x] All C++ ctest pass (0 failures)
- [x] All 37 LAMMPS .pt2 tests pass
- [x] V100 benchmark confirms speedup for DPA1 L2 and DPA2
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit
* **New Features**
* Added `--atomic-virial` command-line flag to enable atomic virial
correction during model conversion and export operations
* Models exported with this feature now include per-atom virial
contributions for improved computational accuracy
* Atomic virial support available for all exportable model formats
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
---------
Co-authored-by: Han Wang <wang_han@iapcm.ac.cn>1 parent 6ec852d commit d14233e
43 files changed
Lines changed: 1360 additions & 302 deletions
File tree
- deepmd
- dpmodel/model
- entrypoints
- pt_expt
- model
- utils
- source
- api_cc
- include
- src
- tests
- lib/src
- tests
- infer
- pt_expt
- infer
- model
Some content is hidden
Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
614 | 614 | | |
615 | 615 | | |
616 | 616 | | |
617 | | - | |
| 617 | + | |
| 618 | + | |
| 619 | + | |
| 620 | + | |
| 621 | + | |
| 622 | + | |
| 623 | + | |
618 | 624 | | |
619 | 625 | | |
620 | 626 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1 | 1 | | |
| 2 | + | |
2 | 3 | | |
3 | 4 | | |
4 | 5 | | |
| |||
7 | 8 | | |
8 | 9 | | |
9 | 10 | | |
| 11 | + | |
| 12 | + | |
10 | 13 | | |
11 | 14 | | |
12 | 15 | | |
13 | 16 | | |
14 | 17 | | |
| 18 | + | |
15 | 19 | | |
16 | 20 | | |
17 | 21 | | |
| |||
20 | 24 | | |
21 | 25 | | |
22 | 26 | | |
23 | | - | |
| 27 | + | |
24 | 28 | | |
| 29 | + | |
| 30 | + | |
| 31 | + | |
| 32 | + | |
25 | 33 | | |
26 | 34 | | |
27 | 35 | | |
28 | 36 | | |
29 | 37 | | |
30 | 38 | | |
31 | | - | |
| 39 | + | |
| 40 | + | |
| 41 | + | |
| 42 | + | |
| 43 | + | |
| 44 | + | |
| 45 | + | |
| 46 | + | |
| 47 | + | |
| 48 | + | |
| 49 | + | |
| 50 | + | |
| 51 | + | |
| 52 | + | |
| 53 | + | |
| 54 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
915 | 915 | | |
916 | 916 | | |
917 | 917 | | |
| 918 | + | |
| 919 | + | |
| 920 | + | |
| 921 | + | |
| 922 | + | |
| 923 | + | |
| 924 | + | |
| 925 | + | |
| 926 | + | |
918 | 927 | | |
919 | 928 | | |
920 | 929 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1 | 1 | | |
| 2 | + | |
2 | 3 | | |
3 | 4 | | |
4 | 5 | | |
| |||
16 | 17 | | |
17 | 18 | | |
18 | 19 | | |
| 20 | + | |
19 | 21 | | |
20 | 22 | | |
21 | 23 | | |
| |||
137 | 139 | | |
138 | 140 | | |
139 | 141 | | |
| 142 | + | |
140 | 143 | | |
141 | 144 | | |
142 | 145 | | |
| |||
147 | 150 | | |
148 | 151 | | |
149 | 152 | | |
150 | | - | |
151 | | - | |
152 | | - | |
| 153 | + | |
| 154 | + | |
| 155 | + | |
| 156 | + | |
| 157 | + | |
| 158 | + | |
| 159 | + | |
| 160 | + | |
| 161 | + | |
| 162 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1 | 1 | | |
| 2 | + | |
2 | 3 | | |
3 | 4 | | |
4 | 5 | | |
| |||
16 | 17 | | |
17 | 18 | | |
18 | 19 | | |
| 20 | + | |
19 | 21 | | |
20 | 22 | | |
21 | 23 | | |
| |||
117 | 119 | | |
118 | 120 | | |
119 | 121 | | |
| 122 | + | |
120 | 123 | | |
121 | 124 | | |
122 | 125 | | |
| |||
127 | 130 | | |
128 | 131 | | |
129 | 132 | | |
130 | | - | |
131 | | - | |
132 | | - | |
| 133 | + | |
| 134 | + | |
| 135 | + | |
| 136 | + | |
| 137 | + | |
| 138 | + | |
| 139 | + | |
| 140 | + | |
| 141 | + | |
| 142 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1 | 1 | | |
| 2 | + | |
2 | 3 | | |
3 | 4 | | |
4 | 5 | | |
| |||
19 | 20 | | |
20 | 21 | | |
21 | 22 | | |
| 23 | + | |
22 | 24 | | |
23 | 25 | | |
24 | 26 | | |
| |||
142 | 144 | | |
143 | 145 | | |
144 | 146 | | |
| 147 | + | |
145 | 148 | | |
146 | 149 | | |
147 | 150 | | |
| |||
152 | 155 | | |
153 | 156 | | |
154 | 157 | | |
155 | | - | |
156 | | - | |
157 | | - | |
| 158 | + | |
| 159 | + | |
| 160 | + | |
| 161 | + | |
| 162 | + | |
| 163 | + | |
| 164 | + | |
| 165 | + | |
| 166 | + | |
| 167 | + | |
158 | 168 | | |
159 | 169 | | |
160 | 170 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1 | 1 | | |
| 2 | + | |
2 | 3 | | |
3 | 4 | | |
4 | 5 | | |
| |||
16 | 17 | | |
17 | 18 | | |
18 | 19 | | |
| 20 | + | |
19 | 21 | | |
20 | 22 | | |
21 | 23 | | |
| |||
139 | 141 | | |
140 | 142 | | |
141 | 143 | | |
| 144 | + | |
142 | 145 | | |
143 | 146 | | |
144 | 147 | | |
| |||
149 | 152 | | |
150 | 153 | | |
151 | 154 | | |
152 | | - | |
153 | | - | |
154 | | - | |
| 155 | + | |
| 156 | + | |
| 157 | + | |
| 158 | + | |
| 159 | + | |
| 160 | + | |
| 161 | + | |
| 162 | + | |
| 163 | + | |
| 164 | + | |
| 165 | + | |
| 166 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1 | 1 | | |
2 | 2 | | |
| 3 | + | |
3 | 4 | | |
4 | 5 | | |
5 | 6 | | |
| |||
28 | 29 | | |
29 | 30 | | |
30 | 31 | | |
| 32 | + | |
| 33 | + | |
| 34 | + | |
| 35 | + | |
| 36 | + | |
| 37 | + | |
| 38 | + | |
| 39 | + | |
| 40 | + | |
| 41 | + | |
| 42 | + | |
| 43 | + | |
| 44 | + | |
| 45 | + | |
| 46 | + | |
| 47 | + | |
| 48 | + | |
| 49 | + | |
| 50 | + | |
| 51 | + | |
| 52 | + | |
| 53 | + | |
31 | 54 | | |
32 | 55 | | |
33 | 56 | | |
| |||
346 | 369 | | |
347 | 370 | | |
348 | 371 | | |
| 372 | + | |
349 | 373 | | |
350 | 374 | | |
351 | 375 | | |
| |||
356 | 380 | | |
357 | 381 | | |
358 | 382 | | |
359 | | - | |
360 | | - | |
361 | | - | |
362 | | - | |
363 | | - | |
364 | | - | |
365 | | - | |
| 383 | + | |
| 384 | + | |
| 385 | + | |
| 386 | + | |
| 387 | + | |
| 388 | + | |
| 389 | + | |
| 390 | + | |
366 | 391 | | |
| 392 | + | |
| 393 | + | |
| 394 | + | |
| 395 | + | |
| 396 | + | |
| 397 | + | |
| 398 | + | |
| 399 | + | |
| 400 | + | |
| 401 | + | |
| 402 | + | |
| 403 | + | |
367 | 404 | | |
368 | 405 | | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1 | 1 | | |
| 2 | + | |
2 | 3 | | |
3 | 4 | | |
4 | 5 | | |
| |||
16 | 17 | | |
17 | 18 | | |
18 | 19 | | |
| 20 | + | |
19 | 21 | | |
20 | 22 | | |
21 | 23 | | |
| |||
117 | 119 | | |
118 | 120 | | |
119 | 121 | | |
| 122 | + | |
120 | 123 | | |
121 | 124 | | |
122 | 125 | | |
| |||
127 | 130 | | |
128 | 131 | | |
129 | 132 | | |
130 | | - | |
131 | | - | |
132 | | - | |
| 133 | + | |
| 134 | + | |
| 135 | + | |
| 136 | + | |
| 137 | + | |
| 138 | + | |
| 139 | + | |
| 140 | + | |
| 141 | + | |
| 142 | + | |
0 commit comments