Skip to content

Commit a31fd18

Browse files
author
Han Wang
committed
feat(pt_expt): implement .pte inference pipeline with dynamic shapes
Implement the full pt_expt inference pipeline: serialize models to .pte files via torch.export, and load them for inference via DeepPot/DeepEval. Key changes: - Add DeepEval backend for .pte files (deepmd/pt_expt/infer/deep_eval.py) - Add serialize/deserialize hooks (deepmd/pt_expt/utils/serialization.py) - Wire up backend hooks in deepmd/backend/pt_expt.py - Add forward_common_lower_exportable using make_fx + torch.export - Support dynamic nframes, nloc, and nall dimensions - Fix atomic_virial_corr to use explicit loop instead of vmap Add xp_take_first_n helper to avoid torch.export contiguity guards on [:, :nloc] slices. When torch.export traces tensor[:, :nloc] on a tensor of size nall, it records a Ne(nall, nloc) guard from the view's contiguity check, which fails when nall == nloc (no PBC). Using torch.index_select instead creates a new tensor, avoiding the guard.
1 parent a3db25a commit a31fd18

13 files changed

Lines changed: 1226 additions & 34 deletions

File tree

deepmd/backend/pt_expt.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,11 @@ def deep_eval(self) -> type["DeepEvalBackend"]:
7676
type[DeepEvalBackend]
7777
The Deep Eval backend of the backend.
7878
"""
79-
raise NotImplementedError
79+
from deepmd.pt_expt.infer.deep_eval import (
80+
DeepEval,
81+
)
82+
83+
return DeepEval
8084

8185
@property
8286
def neighbor_stat(self) -> type["NeighborStat"]:
@@ -87,7 +91,11 @@ def neighbor_stat(self) -> type["NeighborStat"]:
8791
type[NeighborStat]
8892
The neighbor statistics of the backend.
8993
"""
90-
raise NotImplementedError
94+
from deepmd.dpmodel.utils.neighbor_stat import (
95+
NeighborStat,
96+
)
97+
98+
return NeighborStat
9199

92100
@property
93101
def serialize_hook(self) -> Callable[[str], dict]:
@@ -98,7 +106,11 @@ def serialize_hook(self) -> Callable[[str], dict]:
98106
Callable[[str], dict]
99107
The serialize hook of the backend.
100108
"""
101-
raise NotImplementedError
109+
from deepmd.pt_expt.utils.serialization import (
110+
serialize_from_file,
111+
)
112+
113+
return serialize_from_file
102114

103115
@property
104116
def deserialize_hook(self) -> Callable[[str, dict], None]:
@@ -109,4 +121,8 @@ def deserialize_hook(self) -> Callable[[str, dict], None]:
109121
Callable[[str, dict], None]
110122
The deserialize hook of the backend.
111123
"""
112-
raise NotImplementedError
124+
from deepmd.pt_expt.utils.serialization import (
125+
deserialize_to_file,
126+
)
127+
128+
return deserialize_to_file

deepmd/dpmodel/array_api.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,15 @@ def xp_take_along_axis(arr: Array, indices: Array, axis: int) -> Array:
3232
# torch.take_along_dim requires int64 indices
3333
if array_api_compat.is_torch_array(indices):
3434
indices = xp.astype(indices, xp.int64)
35+
if array_api_compat.is_torch_array(arr):
36+
# Use torch.gather directly for torch.export dynamic shape compatibility.
37+
# array_api_compat's take_along_axis / torch.take_along_dim specializes
38+
# the source dimension size to a constant during torch.export tracing,
39+
# breaking dynamic shape export. torch.gather is the underlying
40+
# primitive and handles symbolic shapes correctly.
41+
import torch
42+
43+
return torch.gather(arr, axis, indices)
3544
if Version(xp.__array_api_version__) >= Version("2024.12"):
3645
# see: https://github.com/data-apis/array-api-strict/blob/d086c619a58f35c38240592ef994aa19ca7beebc/array_api_strict/_indexing_functions.py#L30-L39
3746
return xp.take_along_axis(arr, indices, axis=axis)
@@ -62,6 +71,24 @@ def xp_take_along_axis(arr: Array, indices: Array, axis: int) -> Array:
6271
return xp_swapaxes(out, axis, -1)
6372

6473

74+
def xp_take_first_n(arr: Array, dim: int, n: int) -> Array:
75+
"""Take the first *n* elements along *dim*.
76+
77+
For torch tensors, uses ``torch.index_select`` so that
78+
``torch.export`` does not emit a contiguity guard that would
79+
prevent the ``nall == nloc`` (no-PBC) case from working.
80+
For numpy / jax, uses regular slicing.
81+
"""
82+
if array_api_compat.is_torch_array(arr):
83+
import torch
84+
85+
indices = torch.arange(n, dtype=torch.int64, device=arr.device)
86+
return torch.index_select(arr, dim, indices)
87+
slices = [slice(None)] * arr.ndim
88+
slices[dim] = slice(0, n)
89+
return arr[tuple(slices)]
90+
91+
6592
def xp_scatter_sum(input: Array, dim: int, index: Array, src: Array) -> Array:
6693
"""Reduces all values from the src tensor to the indices specified in the index tensor.
6794

deepmd/dpmodel/atomic_model/base_atomic_model.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
from deepmd.dpmodel.array_api import (
1515
Array,
16+
xp_take_first_n,
1617
)
1718
from deepmd.dpmodel.common import (
1819
NativeOP,
@@ -211,7 +212,7 @@ def forward_common_atomic(
211212
"""
212213
xp = array_api_compat.array_namespace(extended_coord, extended_atype, nlist)
213214
_, nloc, _ = nlist.shape
214-
atype = extended_atype[:, :nloc]
215+
atype = xp_take_first_n(extended_atype, 1, nloc)
215216
if self.pair_excl is not None:
216217
pair_mask = self.pair_excl.build_type_exclude_mask(nlist, extended_atype)
217218
# exclude neighbors in the nlist
@@ -229,7 +230,7 @@ def forward_common_atomic(
229230
ret_dict = self.apply_out_stat(ret_dict, atype)
230231

231232
# nf x nloc
232-
atom_mask = ext_atom_mask[:, :nloc]
233+
atom_mask = xp_take_first_n(ext_atom_mask, 1, nloc)
233234
if self.atom_excl is not None:
234235
atom_mask = xp.logical_and(
235236
atom_mask, self.atom_excl.build_type_exclude_mask(atype)

deepmd/dpmodel/atomic_model/dp_atomic_model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from deepmd.dpmodel.array_api import (
1010
Array,
11+
xp_take_first_n,
1112
)
1213
from deepmd.dpmodel.descriptor.base_descriptor import (
1314
BaseDescriptor,
@@ -178,7 +179,7 @@ def forward_atomic(
178179
179180
"""
180181
nframes, nloc, nnei = nlist.shape
181-
atype = extended_atype[:, :nloc]
182+
atype = xp_take_first_n(extended_atype, 1, nloc)
182183
descriptor, rot_mat, g2, h2, sw = self.descriptor(
183184
extended_coord,
184185
extended_atype,

deepmd/dpmodel/utils/env_mat.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from deepmd.dpmodel.array_api import (
1212
Array,
1313
xp_take_along_axis,
14+
xp_take_first_n,
1415
)
1516
from deepmd.dpmodel.utils.safe_gradient import (
1617
safe_for_vector_norm,
@@ -72,7 +73,7 @@ def _make_env_mat(
7273
# nf x nloc x nnei x 3
7374
coord_r = xp.reshape(coord_r, (nf, nloc, nnei, 3))
7475
# nf x nloc x 1 x 3
75-
coord_l = xp.reshape(coord[:, :nloc, ...], (nf, -1, 1, 3))
76+
coord_l = xp.reshape(xp_take_first_n(coord, 1, nloc), (nf, -1, 1, 3))
7677
# nf x nloc x nnei x 3
7778
diff = coord_r - coord_l
7879
# nf x nloc x nnei
@@ -149,7 +150,7 @@ def call(
149150
xp = array_api_compat.array_namespace(coord_ext, atype_ext, nlist)
150151
em, diff, sw = self._call(nlist, coord_ext, radial_only)
151152
nf, nloc, nnei = nlist.shape
152-
atype = atype_ext[:, :nloc]
153+
atype = xp_take_first_n(atype_ext, 1, nloc)
153154
if davg is not None:
154155
em -= xp.reshape(xp.take(davg, xp.reshape(atype, (-1,)), axis=0), em.shape)
155156
if dstd is not None:

deepmd/pt_expt/infer/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later

0 commit comments

Comments
 (0)