Skip to content

Commit d14233e

Browse files
wanghan-iapcm (Han Wang)
and co-authors authored
perf(pt2): optimize .pt2 C++ inference path (#5407)
## Summary - Replace CPU-side `buildTypeSortedNlist` with `createNlistTensor(data, expected_nnei)` — avoids distance computation and type sorting every step; model's compiled `format_nlist` handles this on-device - Export with `do_atomic_virial=False` by default — avoids 3 extra `torch.autograd.grad` backward passes; add `--atomic-virial` flag to `dp convert-backend` - Cache `mapping_tensor` as member variable, only rebuild when `ago == 0` - Store `nnei` and `do_atomic_virial` in .pt2 metadata for C++ to read at init - Make nnei dynamic in torch.export — compiled graph accepts variable-size neighbor lists via internal padding + sort branch ## Benchmark (V100-SXM2-16GB, 192-atom water, LAMMPS MD) ### Before this PR .pt2 was **9x slower** than .pth due to CPU-side nlist sorting, baked-in atomic virial backward passes, and excessive clones: | Atoms | .pth (ms/step) | .pt2 (ms/step) | .pt2/.pth | |------:|---------------:|---------------:|:---------:| | 192 | 11 | 97 | **8.8x** | ### After this PR #### DPA1 L0 (se_atten nlayer=0) | Atoms | .pth (ms) | .pt2 (ms) | .pt2/.pth | |------:|----------:|----------:|:---------:| | 192 | 5.60 | 4.93 | **0.88x** | | 384 | 6.69 | 8.45 | **1.26x** | | 768 | 10.9 | 16.3 | **1.49x** | | 1536 | 19.3 | 31.2 | **1.62x** | | 3072 | 36.7 | 58.8 | **1.60x** | | 6144 | 72.0 | 116 | **1.62x** | | 12288 | 140 | 229 | **1.63x** | #### DPA1 L2 (se_atten nlayer=2) | Atoms | .pth (ms) | .pt2 (ms) | .pt2/.pth | |------:|----------:|----------:|:---------:| | 192 | 13.0 | 9.17 | **0.71x** | | 384 | 22.2 | 16.2 | **0.73x** | | 768 | 41.0 | 30.4 | **0.74x** | | 1536 | 77.8 | 58.8 | **0.76x** | #### DPA2 (repinit + repformer) | Atoms | .pth (ms) | .pt2 (ms) | .pt2/.pth | |------:|----------:|----------:|:---------:| | 192 | 28.5 | 15.6 | **0.55x** | | 384 | 34.6 | 28.2 | **0.81x** | | 768 | 60.5 | 53.4 | **0.88x** | | 1536 | 112.9 | 104 | **0.92x** | For models with more compute (DPA1 L2, DPA2), .pt2 is **24-45% faster** than .pth. 
For the smallest model (DPA1 L0), .pt2 has higher per-atom overhead that dominates at large atom counts. ## Test plan - [x] All Python export/make_fx tests pass (74 tests) - [x] All Python model tests pass - [x] All C++ ctest pass (0 failures) - [x] All 37 LAMMPS .pt2 tests pass - [x] V100 benchmark confirms speedup for DPA1 L2 and DPA2 <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Added `--atomic-virial` command-line flag to enable atomic virial correction during model conversion and export operations * Models exported with this feature now include per-atom virial contributions for improved computational accuracy * Atomic virial support available for all exportable model formats <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: Han Wang <wang_han@iapcm.ac.cn>
1 parent 6ec852d commit d14233e

43 files changed

Lines changed: 1360 additions & 302 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

deepmd/dpmodel/model/make_model.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -614,7 +614,13 @@ def _format_nlist(
614614
axis=-1,
615615
)
616616

617-
if n_nnei > nnei or extra_nlist_sort:
617+
# Order matters for torch.export: Python evaluates `or` left-to-right
618+
# with short-circuit. When `extra_nlist_sort=True` (Python bool) is
619+
# on the left, the right-hand `n_nnei > nnei` is not evaluated, so no
620+
# symbolic guard is registered on the dynamic `n_nnei` dimension.
621+
# Swapping the operands would force the SymInt comparison to run and
622+
# emit an `_assert_scalar` node in the exported graph.
623+
if extra_nlist_sort or n_nnei > nnei:
618624
n_nf, n_nloc, n_nnei = nlist.shape
619625
# make a copy before revise
620626
m_real_nei = nlist >= 0
Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
import logging
23
from typing import (
34
Any,
45
)
@@ -7,11 +8,14 @@
78
Backend,
89
)
910

11+
log = logging.getLogger(__name__)
12+
1013

1114
def convert_backend(
1215
*, # Enforce keyword-only arguments
1316
INPUT: str,
1417
OUTPUT: str,
18+
atomic_virial: bool = False,
1519
**kwargs: Any,
1620
) -> None:
1721
"""Convert a model file from one backend to another.
@@ -20,12 +24,31 @@ def convert_backend(
2024
----------
2125
INPUT : str
2226
The input model file.
23-
INPUT : str
27+
OUTPUT : str
2428
The output model file.
29+
atomic_virial : bool
30+
If True, export .pt2/.pte models with per-atom virial correction.
31+
This adds ~2.5x inference cost. Default False. Ignored
32+
(with a warning) for backends that don't support the flag.
2533
"""
2634
inp_backend: Backend = Backend.detect_backend_by_model(INPUT)()
2735
out_backend: Backend = Backend.detect_backend_by_model(OUTPUT)()
2836
inp_hook = inp_backend.serialize_hook
2937
out_hook = out_backend.deserialize_hook
3038
data = inp_hook(INPUT)
31-
out_hook(OUTPUT, data)
39+
# Forward atomic_virial to pt_expt deserialize_to_file if applicable;
40+
# warn and skip the flag for backends that don't accept it so that
41+
# scripts passing --atomic-virial indiscriminately don't break.
42+
import inspect
43+
44+
sig = inspect.signature(out_hook)
45+
if "do_atomic_virial" in sig.parameters:
46+
out_hook(OUTPUT, data, do_atomic_virial=atomic_virial)
47+
else:
48+
if atomic_virial:
49+
log.warning(
50+
"--atomic-virial is only meaningful for pt_expt .pt2/.pte "
51+
"outputs; ignoring it for output backend %s",
52+
out_backend.name,
53+
)
54+
out_hook(OUTPUT, data)

deepmd/main.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -915,6 +915,15 @@ def main_parser() -> argparse.ArgumentParser:
915915
)
916916
parser_convert_backend.add_argument("INPUT", help="The input model file.")
917917
parser_convert_backend.add_argument("OUTPUT", help="The output model file.")
918+
parser_convert_backend.add_argument(
919+
"--atomic-virial",
920+
action="store_true",
921+
default=False,
922+
help="Export .pt2/.pte models with per-atom virial correction. "
923+
"This adds ~2.5x inference cost but is required for "
924+
"LAMMPS compute/atom virial output. "
925+
"Ignored (with a warning) for other output backends.",
926+
)
918927

919928
# * show model ******************************************************************
920929
parser_show = subparsers.add_parser(

deepmd/pt_expt/model/dipole_model.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
import types
23
from typing import (
34
Any,
45
)
@@ -16,6 +17,7 @@
1617
)
1718

1819
from .make_model import (
20+
_pad_nlist_for_export,
1921
make_model,
2022
)
2123
from .model import (
@@ -137,6 +139,7 @@ def fn(
137139
aparam: torch.Tensor | None,
138140
) -> dict[str, torch.Tensor]:
139141
extended_coord = extended_coord.detach().requires_grad_(True)
142+
nlist = _pad_nlist_for_export(nlist)
140143
return model.forward_lower(
141144
extended_coord,
142145
extended_atype,
@@ -147,6 +150,13 @@ def fn(
147150
do_atomic_virial=do_atomic_virial,
148151
)
149152

150-
return make_fx(fn, **make_fx_kwargs)(
151-
extended_coord, extended_atype, nlist, mapping, fparam, aparam
152-
)
153+
# See make_model.py for the rationale of the pad + monkeypatch.
154+
_orig_need_sort = model.need_sorted_nlist_for_lower
155+
model.need_sorted_nlist_for_lower = types.MethodType(lambda self: True, model)
156+
try:
157+
traced = make_fx(fn, **make_fx_kwargs)(
158+
extended_coord, extended_atype, nlist, mapping, fparam, aparam
159+
)
160+
finally:
161+
model.need_sorted_nlist_for_lower = _orig_need_sort
162+
return traced

deepmd/pt_expt/model/dos_model.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
import types
23
from typing import (
34
Any,
45
)
@@ -16,6 +17,7 @@
1617
)
1718

1819
from .make_model import (
20+
_pad_nlist_for_export,
1921
make_model,
2022
)
2123
from .model import (
@@ -117,6 +119,7 @@ def fn(
117119
aparam: torch.Tensor | None,
118120
) -> dict[str, torch.Tensor]:
119121
extended_coord = extended_coord.detach().requires_grad_(True)
122+
nlist = _pad_nlist_for_export(nlist)
120123
return model.forward_lower(
121124
extended_coord,
122125
extended_atype,
@@ -127,6 +130,13 @@ def fn(
127130
do_atomic_virial=do_atomic_virial,
128131
)
129132

130-
return make_fx(fn, **make_fx_kwargs)(
131-
extended_coord, extended_atype, nlist, mapping, fparam, aparam
132-
)
133+
# See make_model.py for the rationale of the pad + monkeypatch.
134+
_orig_need_sort = model.need_sorted_nlist_for_lower
135+
model.need_sorted_nlist_for_lower = types.MethodType(lambda self: True, model)
136+
try:
137+
traced = make_fx(fn, **make_fx_kwargs)(
138+
extended_coord, extended_atype, nlist, mapping, fparam, aparam
139+
)
140+
finally:
141+
model.need_sorted_nlist_for_lower = _orig_need_sort
142+
return traced

deepmd/pt_expt/model/dp_linear_model.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
import types
23
from typing import (
34
Any,
45
)
@@ -19,6 +20,7 @@
1920
)
2021

2122
from .make_model import (
23+
_pad_nlist_for_export,
2224
make_model,
2325
)
2426
from .model import (
@@ -142,6 +144,7 @@ def fn(
142144
aparam: torch.Tensor | None,
143145
) -> dict[str, torch.Tensor]:
144146
extended_coord = extended_coord.detach().requires_grad_(True)
147+
nlist = _pad_nlist_for_export(nlist)
145148
return model.forward_lower(
146149
extended_coord,
147150
extended_atype,
@@ -152,9 +155,16 @@ def fn(
152155
do_atomic_virial=do_atomic_virial,
153156
)
154157

155-
return make_fx(fn, **make_fx_kwargs)(
156-
extended_coord, extended_atype, nlist, mapping, fparam, aparam
157-
)
158+
# See make_model.py for the rationale of the pad + monkeypatch.
159+
_orig_need_sort = model.need_sorted_nlist_for_lower
160+
model.need_sorted_nlist_for_lower = types.MethodType(lambda self: True, model)
161+
try:
162+
traced = make_fx(fn, **make_fx_kwargs)(
163+
extended_coord, extended_atype, nlist, mapping, fparam, aparam
164+
)
165+
finally:
166+
model.need_sorted_nlist_for_lower = _orig_need_sort
167+
return traced
158168

159169
@classmethod
160170
def update_sel(

deepmd/pt_expt/model/dp_zbl_model.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
import types
23
from typing import (
34
Any,
45
)
@@ -16,6 +17,7 @@
1617
)
1718

1819
from .make_model import (
20+
_pad_nlist_for_export,
1921
make_model,
2022
)
2123
from .model import (
@@ -139,6 +141,7 @@ def fn(
139141
aparam: torch.Tensor | None,
140142
) -> dict[str, torch.Tensor]:
141143
extended_coord = extended_coord.detach().requires_grad_(True)
144+
nlist = _pad_nlist_for_export(nlist)
142145
return model.forward_lower(
143146
extended_coord,
144147
extended_atype,
@@ -149,6 +152,15 @@ def fn(
149152
do_atomic_virial=do_atomic_virial,
150153
)
151154

152-
return make_fx(fn, **make_fx_kwargs)(
153-
extended_coord, extended_atype, nlist, mapping, fparam, aparam
154-
)
155+
# Force `_format_nlist`'s sort branch into the compiled graph so the
156+
# exported model tolerates oversized nlists at runtime — see
157+
# make_model.py for the full rationale.
158+
_orig_need_sort = model.need_sorted_nlist_for_lower
159+
model.need_sorted_nlist_for_lower = types.MethodType(lambda self: True, model)
160+
try:
161+
traced = make_fx(fn, **make_fx_kwargs)(
162+
extended_coord, extended_atype, nlist, mapping, fparam, aparam
163+
)
164+
finally:
165+
model.need_sorted_nlist_for_lower = _orig_need_sort
166+
return traced

deepmd/pt_expt/model/make_model.py

Lines changed: 44 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
22
import math
3+
import types
34
from typing import (
45
Any,
56
)
@@ -28,6 +29,28 @@
2829
)
2930

3031

32+
def _pad_nlist_for_export(nlist: torch.Tensor) -> torch.Tensor:
33+
"""Append a single ``-1`` column to ``nlist`` for export-time tracing.
34+
35+
Used inside ``forward_common_lower_exportable`` (and its spin counterpart)
36+
so that ``_format_nlist``'s terminal slice ``ret[..., :nnei]`` truncates
37+
to a statically sized output. Without the extra column, torch.export
38+
cannot prove the ``ret.shape[-1] == nnei`` assertion at trace time and
39+
would specialise the dynamic ``nnei`` dim to the sample value.
40+
41+
Combined with the short-circuit order in ``_format_nlist``
42+
(``extra_nlist_sort`` on the left) and the ``need_sorted_nlist_for_lower``
43+
override during tracing, this keeps the compiled graph's ``nnei`` axis
44+
fully dynamic and free of symbolic shape guards.
45+
"""
46+
pad = -torch.ones(
47+
(*nlist.shape[:2], 1),
48+
dtype=nlist.dtype,
49+
device=nlist.device,
50+
)
51+
return torch.cat([nlist, pad], dim=-1)
52+
53+
3154
def _cal_hessian_ext(
3255
model: Any,
3356
kk: str,
@@ -346,6 +369,7 @@ def fn(
346369
aparam: torch.Tensor | None,
347370
) -> dict[str, torch.Tensor]:
348371
extended_coord = extended_coord.detach().requires_grad_(True)
372+
nlist = _pad_nlist_for_export(nlist)
349373
return model.forward_common_lower(
350374
extended_coord,
351375
extended_atype,
@@ -356,13 +380,26 @@ def fn(
356380
do_atomic_virial=do_atomic_virial,
357381
)
358382

359-
return make_fx(fn, **make_fx_kwargs)(
360-
extended_coord,
361-
extended_atype,
362-
nlist,
363-
mapping,
364-
fparam,
365-
aparam,
383+
# Force `_format_nlist`'s sort branch into the compiled graph so the
384+
# exported model tolerates oversized nlists at runtime (LAMMPS builds
385+
# nlists with rcut+skin). Combined with the short-circuit order in
386+
# `_format_nlist`, no symbolic guard on the dynamic `nnei` axis is
387+
# emitted.
388+
_orig_need_sort = model.need_sorted_nlist_for_lower
389+
model.need_sorted_nlist_for_lower = types.MethodType(
390+
lambda self: True, model
366391
)
392+
try:
393+
traced = make_fx(fn, **make_fx_kwargs)(
394+
extended_coord,
395+
extended_atype,
396+
nlist,
397+
mapping,
398+
fparam,
399+
aparam,
400+
)
401+
finally:
402+
model.need_sorted_nlist_for_lower = _orig_need_sort
403+
return traced
367404

368405
return CM

deepmd/pt_expt/model/polar_model.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
import types
23
from typing import (
34
Any,
45
)
@@ -16,6 +17,7 @@
1617
)
1718

1819
from .make_model import (
20+
_pad_nlist_for_export,
1921
make_model,
2022
)
2123
from .model import (
@@ -117,6 +119,7 @@ def fn(
117119
aparam: torch.Tensor | None,
118120
) -> dict[str, torch.Tensor]:
119121
extended_coord = extended_coord.detach().requires_grad_(True)
122+
nlist = _pad_nlist_for_export(nlist)
120123
return model.forward_lower(
121124
extended_coord,
122125
extended_atype,
@@ -127,6 +130,13 @@ def fn(
127130
do_atomic_virial=do_atomic_virial,
128131
)
129132

130-
return make_fx(fn, **make_fx_kwargs)(
131-
extended_coord, extended_atype, nlist, mapping, fparam, aparam
132-
)
133+
# See make_model.py for the rationale of the pad + monkeypatch.
134+
_orig_need_sort = model.need_sorted_nlist_for_lower
135+
model.need_sorted_nlist_for_lower = types.MethodType(lambda self: True, model)
136+
try:
137+
traced = make_fx(fn, **make_fx_kwargs)(
138+
extended_coord, extended_atype, nlist, mapping, fparam, aparam
139+
)
140+
finally:
141+
model.need_sorted_nlist_for_lower = _orig_need_sort
142+
return traced

0 commit comments

Comments
 (0)