Skip to content

Commit 97e9c9e

Browse files
wanghan-iapcm and Han Wang authored
feat(c++,pt-expt): add .pt2 (AOTInductor) C/C++ inference with DPA1/DPA2/DPA3 support (#5298)
Add C/C++ inference support for the `.pt2` (torch.export / AOTInductor) backend, covering all major descriptor types: SE_E2_A, DPA1, DPA2, and DPA3. ### C/C++ inference backend (`DeepPotPTExpt`) - New `DeepPotPTExpt` backend that loads `.pt2` models via `torch::inductor::AOTIModelContainerRunnerCpu` - Supports PBC, NoPbc, fparam/aparam, multi-frame batching, atomic energy/virial, LAMMPS neighbor list (with ghost atoms, 2rc padding, type selection) - Registered alongside existing PT/TF/JAX/PD backends via the `.pt2` file extension ### dpmodel fixes for torch.export compatibility - Replace `[:, :nloc]` slicing with `xp_take_first_n()` in DPA1, DPA2, DPA3, and repflows/repformers — the original slicing creates `Ne(nall, nloc)` shape constraints that fail when `nall == nloc` (NoPbc case) - Replace flat `(nf*nall,)` indexing in `dpa1.py` and `exclude_mask.py` with `xp_take_along_axis` - Replace `xp.reshape(mapping, (nframes, -1, 1))` with `xp.expand_dims` in repflows/repformers — the `-1` resolves to `nall` during tracing ### pt_expt serialization - `.pt2` export via `torch.export.export` → `aot_compile` → package as zip - Python inference via `torch._inductor.aoti_load_package` ### Bug fix in all C++ backends - Fix ghost-to-local mapping when virtual atoms are present — the old code `mapping[ii] = lmp_list.mapping[fwd_map[ii]]` used post-filter indices as original indices; fixed to `mapping[ii] = fwd_map[lmp_list.mapping[bkw_map[ii]]]` - Fix use-after-free in `DeepPotPTExpt.cc` where `torch::from_blob` referenced a local vector after it went out of scope ### Test infrastructure - Model generation scripts (`gen_dpa1.py`, `gen_dpa2.py`, `gen_dpa3.py`, `gen_fparam_aparam.py`) that build from dpmodel config → serialize → export to both `.pth` and `.pt2` with identical weights - Remove pre-committed `.pth` files; regenerate in CI via `convert-models.sh` - C++ tests for all descriptor types: SE_E2_A, DPA1, DPA2, DPA3 (both `.pth` and `.pt2`, PBC + NoPbc, double + float) - 
Python unit tests for pt_expt inference (`test_deep_eval.py`) <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Added support for PyTorch exportable (.pt2) models and runtime detection, enabling AOTInductor-based inference across interfaces. * **Bug Fixes** * Improved neighbor/embedding extraction and broadcasting to increase backend export compatibility and robustness. * **Tests** * Added extensive C++ and Python test suites and reference-generation scripts to validate .pt2 inference paths and cross-format consistency. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: Han Wang <wang_han@iapcm.ac.cn>
1 parent 034e613 commit 97e9c9e

44 files changed

Lines changed: 7560 additions & 132 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

deepmd/backend/pt_expt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ class PyTorchExportableBackend(Backend):
4141
| Backend.Feature.IO
4242
)
4343
"""The features of the backend."""
44-
suffixes: ClassVar[list[str]] = [".pte"]
44+
suffixes: ClassVar[list[str]] = [".pte", ".pt2"]
4545
"""The suffixes of the backend."""
4646

4747
def is_available(self) -> bool:

deepmd/dpmodel/descriptor/dpa1.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from deepmd.dpmodel.array_api import (
2121
Array,
2222
xp_take_along_axis,
23+
xp_take_first_n,
2324
)
2425
from deepmd.dpmodel.common import (
2526
cast_precision,
@@ -536,7 +537,7 @@ def call(
536537
(nf, nall, self.tebd_dim),
537538
)
538539
# nfnl x tebd_dim
539-
atype_embd = atype_embd_ext[:, :nloc, :]
540+
atype_embd = xp_take_first_n(atype_embd_ext, 1, nloc)
540541
grrg, g2, h2, rot_mat, sw = self.se_atten(
541542
nlist,
542543
coord_ext,
@@ -1086,7 +1087,7 @@ def call(
10861087
self.stddev[...],
10871088
)
10881089
nf, nloc, nnei, _ = dmatrix.shape
1089-
atype = atype_ext[:, :nloc]
1090+
atype = xp_take_first_n(atype_ext, 1, nloc)
10901091
exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext)
10911092
# nfnl x nnei
10921093
exclude_mask = xp.reshape(exclude_mask, (nf * nloc, nnei))
@@ -1105,6 +1106,12 @@ def call(
11051106
nlist_masked = xp.where(nlist_mask, nlist, xp.zeros_like(nlist))
11061107
ng = self.neuron[-1]
11071108
nt = self.tebd_dim
1109+
1110+
# Gather neighbor info using xp_take_along_axis along axis=1.
1111+
# This avoids flat (nf*nall,) indexing that creates Ne(nall, nloc)
1112+
# constraints in torch.export, breaking NoPbc (nall == nloc).
1113+
nlist_2d = xp.reshape(nlist_masked, (nf, nloc * nnei)) # (nf, nloc*nnei)
1114+
11081115
# nfnl x nnei x 4
11091116
rr = xp.reshape(dmatrix, (nf * nloc, nnei, 4))
11101117
rr = rr * xp.astype(exclude_mask[:, :, None], rr.dtype)
@@ -1113,15 +1120,16 @@ def call(
11131120
if self.tebd_input_mode in ["concat"]:
11141121
# nfnl x tebd_dim
11151122
atype_embd = xp.reshape(
1116-
atype_embd_ext[:, :nloc, :], (nf * nloc, self.tebd_dim)
1123+
xp_take_first_n(atype_embd_ext, 1, nloc), (nf * nloc, self.tebd_dim)
11171124
)
11181125
# nfnl x nnei x tebd_dim
11191126
atype_embd_nnei = xp.tile(atype_embd[:, xp.newaxis, :], (1, nnei, 1))
1120-
index = xp.tile(
1121-
xp.reshape(nlist_masked, (nf, -1, 1)), (1, 1, self.tebd_dim)
1127+
# Gather neighbor type embeddings: (nf, nall, tebd_dim) -> (nf, nloc*nnei, tebd_dim)
1128+
nlist_idx_tebd = xp.tile(nlist_2d[:, :, xp.newaxis], (1, 1, self.tebd_dim))
1129+
atype_embd_nlist = xp_take_along_axis(
1130+
atype_embd_ext, nlist_idx_tebd, axis=1
11221131
)
11231132
# nfnl x nnei x tebd_dim
1124-
atype_embd_nlist = xp_take_along_axis(atype_embd_ext, index, axis=1)
11251133
atype_embd_nlist = xp.reshape(
11261134
atype_embd_nlist, (nf * nloc, nnei, self.tebd_dim)
11271135
)
@@ -1140,10 +1148,9 @@ def call(
11401148
assert self.embeddings_strip is not None
11411149
assert type_embedding is not None
11421150
ntypes_with_padding = type_embedding.shape[0]
1143-
# nf x (nl x nnei)
1144-
nlist_index = xp.reshape(nlist_masked, (nf, nloc * nnei))
1145-
# nf x (nl x nnei)
1146-
nei_type = xp_take_along_axis(atype_ext, nlist_index, axis=1)
1151+
# Gather neighbor types: (nf, nall) -> (nf, nloc*nnei)
1152+
nei_type = xp_take_along_axis(atype_ext, nlist_2d, axis=1)
1153+
nei_type = xp.reshape(nei_type, (-1,)) # (nf * nloc * nnei,)
11471154
# (nf x nl x nnei) x ng
11481155
nei_type_index = xp.tile(xp.reshape(nei_type, (-1, 1)), (1, ng))
11491156
if self.type_one_side:

deepmd/dpmodel/descriptor/dpa2.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from deepmd.dpmodel.array_api import (
1616
Array,
1717
xp_take_along_axis,
18+
xp_take_first_n,
1819
)
1920
from deepmd.dpmodel.common import (
2021
cast_precision,
@@ -878,7 +879,7 @@ def call(
878879
xp.take(type_embedding, xp.reshape(atype_ext, (-1,)), axis=0),
879880
(nframes, nall, self.tebd_dim),
880881
)
881-
g1_inp = g1_ext[:, :nloc, :]
882+
g1_inp = xp_take_first_n(g1_ext, 1, nloc)
882883
g1, _, _, _, _ = self.repinit(
883884
nlist_dict[
884885
get_multiple_nlist_key(self.repinit.get_rcut(), self.repinit.get_nsel())
@@ -912,9 +913,7 @@ def call(
912913
g1 = g1 + self.tebd_transform(g1_inp)
913914
# mapping g1
914915
assert mapping is not None
915-
mapping_ext = xp.tile(
916-
xp.reshape(mapping, (nframes, nall, 1)), (1, 1, g1.shape[-1])
917-
)
916+
mapping_ext = xp.tile(xp.expand_dims(mapping, axis=-1), (1, 1, g1.shape[-1]))
918917
g1_ext = xp_take_along_axis(g1, mapping_ext, axis=1)
919918
# repformer
920919
g1, g2, h2, rot_mat, sw = self.repformers(

deepmd/dpmodel/descriptor/dpa3.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
)
1111
from deepmd.dpmodel.array_api import (
1212
Array,
13+
xp_take_first_n,
1314
)
1415
from deepmd.dpmodel.common import (
1516
cast_precision,
@@ -653,7 +654,11 @@ def call(
653654
type_embedding = self.type_embedding.call()
654655
if self.use_loc_mapping:
655656
node_ebd_ext = xp.reshape(
656-
xp.take(type_embedding, xp.reshape(atype_ext[:, :nloc], (-1,)), axis=0),
657+
xp.take(
658+
type_embedding,
659+
xp.reshape(xp_take_first_n(atype_ext, 1, nloc), (-1,)),
660+
axis=0,
661+
),
657662
(nframes, nloc, self.tebd_dim),
658663
)
659664
else:
@@ -682,7 +687,7 @@ def call(
682687
sys_cs_embd = self.cs_activation_fn(self.mix_cs_mlp.call(cs_cat))
683688
node_ebd_ext = node_ebd_ext + xp.expand_dims(sys_cs_embd, axis=1)
684689

685-
node_ebd_inp = node_ebd_ext[:, :nloc, :]
690+
node_ebd_inp = xp_take_first_n(node_ebd_ext, 1, nloc)
686691
# repflows
687692
node_ebd, edge_ebd, h2, rot_mat, sw = self.repflows(
688693
nlist,

deepmd/dpmodel/descriptor/repflows.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from deepmd.dpmodel.array_api import (
1414
Array,
1515
xp_take_along_axis,
16+
xp_take_first_n,
1617
)
1718
from deepmd.dpmodel.common import (
1819
to_numpy_array,
@@ -562,7 +563,7 @@ def call(
562563

563564
# get node embedding
564565
# nb x nloc x tebd_dim
565-
atype_embd = atype_embd_ext[:, :nloc, :]
566+
atype_embd = xp_take_first_n(atype_embd_ext, 1, nloc)
566567
assert list(atype_embd.shape) == [nframes, nloc, self.n_dim]
567568

568569
node_ebd = self.act(atype_embd)
@@ -641,7 +642,7 @@ def call(
641642
angle_ebd = self.angle_embd(angle_input)
642643

643644
# nb x nall x n_dim
644-
mapping = xp.tile(xp.reshape(mapping, (nframes, -1, 1)), (1, 1, self.n_dim))
645+
mapping = xp.tile(xp.expand_dims(mapping, axis=-1), (1, 1, self.n_dim))
645646
for idx, ll in enumerate(self.layers):
646647
# node_ebd: nb x nloc x n_dim
647648
# node_ebd_ext: nb x nall x n_dim
@@ -1421,7 +1422,7 @@ def call(
14211422
n_edge = (
14221423
int(xp.sum(xp.astype(nlist_mask, xp.int32))) if self.use_dynamic_sel else 0
14231424
)
1424-
node_ebd = node_ebd_ext[:, :nloc, :]
1425+
node_ebd = xp_take_first_n(node_ebd_ext, 1, nloc)
14251426
assert (nb, nloc) == node_ebd.shape[:2]
14261427
if not self.use_dynamic_sel:
14271428
assert (nb, nloc, nnei) == h2.shape[:3]

deepmd/dpmodel/descriptor/repformers.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from deepmd.dpmodel.array_api import (
1717
Array,
1818
xp_take_along_axis,
19+
xp_take_first_n,
1920
)
2021
from deepmd.dpmodel.common import (
2122
to_numpy_array,
@@ -499,7 +500,7 @@ def call(
499500
sw = xp.reshape(sw, (nf, nloc, nnei))
500501
sw = xp.where(nlist_mask, sw, xp.zeros_like(sw))
501502
# nf x nloc x tebd_dim
502-
atype_embd = atype_embd_ext[:, :nloc, :]
503+
atype_embd = xp_take_first_n(atype_embd_ext, 1, nloc)
503504
assert list(atype_embd.shape) == [nf, nloc, self.g1_dim]
504505

505506
g1 = self.act(atype_embd)
@@ -516,7 +517,7 @@ def call(
516517
# if a neighbor is real or not is indicated by nlist_mask
517518
nlist = xp.where(nlist == -1, xp.zeros_like(nlist), nlist)
518519
# nf x nall x ng1
519-
mapping = xp.tile(xp.reshape(mapping, (nf, -1, 1)), (1, 1, self.g1_dim))
520+
mapping = xp.tile(xp.expand_dims(mapping, axis=-1), (1, 1, self.g1_dim))
520521
for idx, ll in enumerate(self.layers):
521522
# g1: nf x nloc x ng1
522523
# g1_ext: nf x nall x ng1
@@ -1765,9 +1766,8 @@ def call(
17651766
)
17661767

17671768
nf, nloc, nnei, _ = g2.shape
1768-
nall = g1_ext.shape[1]
17691769
# g1, _ = xp.split(g1_ext, [nloc], axis=1)
1770-
g1 = g1_ext[:, :nloc, :]
1770+
g1 = xp_take_first_n(g1_ext, 1, nloc)
17711771
assert (nf, nloc) == g1.shape[:2]
17721772
assert (nf, nloc, nnei) == h2.shape[:3]
17731773

deepmd/dpmodel/model/make_model.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111

1212
from deepmd.dpmodel.array_api import (
1313
Array,
14+
xp_take_along_axis,
15+
xp_take_first_n,
1416
)
1517
from deepmd.dpmodel.atomic_model.base_atomic_model import (
1618
BaseAtomicModel,
@@ -589,7 +591,6 @@ def _format_nlist(
589591
xp = array_api_compat.array_namespace(extended_coord, nlist)
590592
n_nf, n_nloc, n_nnei = nlist.shape
591593
extended_coord = extended_coord.reshape([n_nf, -1, 3])
592-
nall = extended_coord.shape[1]
593594
rcut = self.get_rcut()
594595

595596
if n_nnei < nnei:
@@ -612,14 +613,14 @@ def _format_nlist(
612613
# make a copy before revise
613614
m_real_nei = nlist >= 0
614615
ret = xp.where(m_real_nei, nlist, 0)
615-
coord0 = extended_coord[:, :n_nloc, :]
616+
coord0 = xp_take_first_n(extended_coord, 1, n_nloc)
616617
index = xp.tile(ret.reshape(n_nf, n_nloc * n_nnei, 1), (1, 1, 3))
617-
coord1 = xp.take_along_axis(extended_coord, index, axis=1)
618+
coord1 = xp_take_along_axis(extended_coord, index, axis=1)
618619
coord1 = coord1.reshape(n_nf, n_nloc, n_nnei, 3)
619620
rr = xp.linalg.norm(coord0[:, :, None, :] - coord1, axis=-1)
620621
rr = xp.where(m_real_nei, rr, float("inf"))
621622
rr, ret_mapping = xp.sort(rr, axis=-1), xp.argsort(rr, axis=-1)
622-
ret = xp.take_along_axis(ret, ret_mapping, axis=2)
623+
ret = xp_take_along_axis(ret, ret_mapping, axis=2)
623624
ret = xp.where(rr > rcut, -1, ret)
624625
ret = ret[..., :nnei]
625626
# not extra_nlist_sort and n_nnei <= nnei:

deepmd/dpmodel/utils/exclude_mask.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from deepmd.dpmodel.array_api import (
77
Array,
88
xp_take_along_axis,
9+
xp_take_first_n,
910
)
1011

1112

@@ -131,18 +132,22 @@ def build_type_exclude_mask(
131132
],
132133
axis=-1,
133134
)
134-
type_i = xp.reshape(atype_ext[:, :nloc], (nf, nloc)) * (self.ntypes + 1)
135-
# nf x nloc x nnei
136-
index = xp.reshape(
137-
xp.where(nlist == -1, xp.full_like(nlist, nall), nlist), (nf, nloc * nnei)
135+
type_i = xp.reshape(xp_take_first_n(atype_ext, 1, nloc), (nf, nloc)) * (
136+
self.ntypes + 1
138137
)
139-
type_j = xp_take_along_axis(ae, index, axis=1)
138+
# Map -1 entries to nall (the virtual atom index in ae)
139+
nlist_for_type = xp.where(nlist == -1, xp.full_like(nlist, nall), nlist)
140+
# Gather neighbor types using xp_take_along_axis along axis=1.
141+
# This avoids flat (nf*(nall+1),) indexing that creates Ne(nall, nloc)
142+
# constraints in torch.export, breaking NoPbc (nall == nloc).
143+
nlist_for_gather = xp.reshape(nlist_for_type, (nf, nloc * nnei))
144+
type_j = xp_take_along_axis(ae, nlist_for_gather, axis=1)
140145
type_j = xp.reshape(type_j, (nf, nloc, nnei))
141146
type_ij = type_i[:, :, None] + type_j
142-
# nf x (nloc x nnei)
143-
type_ij = xp.reshape(type_ij, (nf, nloc * nnei))
147+
# (nf * nloc * nnei,)
148+
type_ij_flat = xp.reshape(type_ij, (-1,))
144149
mask = xp.reshape(
145-
xp.take(self.type_mask[...], xp.reshape(type_ij, (-1,))),
150+
xp.take(self.type_mask[...], type_ij_flat),
146151
(nf, nloc, nnei),
147152
)
148153
return mask

deepmd/dpmodel/utils/nlist.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from deepmd.dpmodel.array_api import (
66
Array,
77
xp_take_along_axis,
8+
xp_take_first_n,
89
)
910

1011
from .region import (
@@ -243,8 +244,7 @@ def build_multiple_neighbor_list(
243244
nlist = xp.concat([nlist, pad], axis=-1)
244245
nsel = nsels[-1]
245246
coord1 = xp.reshape(coord, (nb, -1, 3))
246-
nall = coord1.shape[1]
247-
coord0 = coord1[:, :nloc, :]
247+
coord0 = xp_take_first_n(coord1, 1, nloc)
248248
nlist_mask = nlist == -1
249249
tnlist_0 = xp.where(nlist_mask, xp.zeros_like(nlist), nlist)
250250
index = xp.tile(xp.reshape(tnlist_0, (nb, nloc * nsel, 1)), (1, 1, 3))

deepmd/pt/model/descriptor/se_t_tebd.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,23 @@
7373
extend_descrpt_stat,
7474
)
7575

76+
if not hasattr(torch.ops.deepmd, "tabulate_fusion_se_t_tebd"):
77+
78+
def tabulate_fusion_se_t_tebd(
79+
argument0: torch.Tensor,
80+
argument1: torch.Tensor,
81+
argument2: torch.Tensor,
82+
argument3: torch.Tensor,
83+
argument4: int,
84+
) -> list[torch.Tensor]:
85+
raise NotImplementedError(
86+
"tabulate_fusion_se_t_tebd is not available since customized PyTorch OP library is not built when freezing the model. "
87+
"See documentation for model compression for details."
88+
)
89+
90+
# Note: this hack cannot actually save a model that can be run using LAMMPS.
91+
torch.ops.deepmd.tabulate_fusion_se_t_tebd = tabulate_fusion_se_t_tebd
92+
7693

7794
@BaseDescriptor.register("se_e3_tebd")
7895
class DescrptSeTTebd(BaseDescriptor, torch.nn.Module):

0 commit comments

Comments (0)