Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions deepmd/backend/pt_expt.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,11 @@ def deep_eval(self) -> type["DeepEvalBackend"]:
type[DeepEvalBackend]
The Deep Eval backend of the backend.
"""
raise NotImplementedError
from deepmd.pt_expt.infer.deep_eval import (
DeepEval,
)

return DeepEval

@property
def neighbor_stat(self) -> type["NeighborStat"]:
Expand All @@ -87,7 +91,11 @@ def neighbor_stat(self) -> type["NeighborStat"]:
type[NeighborStat]
The neighbor statistics of the backend.
"""
raise NotImplementedError
from deepmd.pt_expt.utils.neighbor_stat import (
NeighborStat,
)

return NeighborStat

@property
def serialize_hook(self) -> Callable[[str], dict]:
Expand All @@ -98,7 +106,11 @@ def serialize_hook(self) -> Callable[[str], dict]:
Callable[[str], dict]
The serialize hook of the backend.
"""
raise NotImplementedError
from deepmd.pt_expt.utils.serialization import (
serialize_from_file,
)

return serialize_from_file

@property
def deserialize_hook(self) -> Callable[[str, dict], None]:
Expand All @@ -109,4 +121,8 @@ def deserialize_hook(self) -> Callable[[str, dict], None]:
Callable[[str, dict], None]
The deserialize hook of the backend.
"""
raise NotImplementedError
from deepmd.pt_expt.utils.serialization import (
deserialize_to_file,
)

return deserialize_to_file
27 changes: 27 additions & 0 deletions deepmd/dpmodel/array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,15 @@ def xp_take_along_axis(arr: Array, indices: Array, axis: int) -> Array:
# torch.take_along_dim requires int64 indices
if array_api_compat.is_torch_array(indices):
indices = xp.astype(indices, xp.int64)
if array_api_compat.is_torch_array(arr):
# Use torch.gather directly for torch.export dynamic shape compatibility.
# array_api_compat's take_along_axis / torch.take_along_dim specializes
# the source dimension size to a constant during torch.export tracing,
# breaking dynamic shape export. torch.gather is the underlying
# primitive and handles symbolic shapes correctly.
import torch

return torch.gather(arr, axis, indices)
if Version(xp.__array_api_version__) >= Version("2024.12"):
# see: https://github.com/data-apis/array-api-strict/blob/d086c619a58f35c38240592ef994aa19ca7beebc/array_api_strict/_indexing_functions.py#L30-L39
return xp.take_along_axis(arr, indices, axis=axis)
Expand Down Expand Up @@ -62,6 +71,24 @@ def xp_take_along_axis(arr: Array, indices: Array, axis: int) -> Array:
return xp_swapaxes(out, axis, -1)


def xp_take_first_n(arr: Array, dim: int, n: int) -> Array:
"""Take the first *n* elements along *dim*.

For torch tensors, uses ``torch.index_select`` so that
``torch.export`` does not emit a contiguity guard that would
prevent the ``nall == nloc`` (no-PBC) case from working.
For numpy / jax, uses regular slicing.
"""
if array_api_compat.is_torch_array(arr):
import torch

indices = torch.arange(n, dtype=torch.int64, device=arr.device)
return torch.index_select(arr, dim, indices)
slices = [slice(None)] * arr.ndim
slices[dim] = slice(0, n)
return arr[tuple(slices)]


def xp_scatter_sum(input: Array, dim: int, index: Array, src: Array) -> Array:
"""Reduces all values from the src tensor to the indices specified in the index tensor.

Expand Down
5 changes: 3 additions & 2 deletions deepmd/dpmodel/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from deepmd.dpmodel.array_api import (
Array,
xp_take_first_n,
)
from deepmd.dpmodel.common import (
NativeOP,
Expand Down Expand Up @@ -250,7 +251,7 @@ def forward_common_atomic(
"""
xp = array_api_compat.array_namespace(extended_coord, extended_atype, nlist)
_, nloc, _ = nlist.shape
atype = extended_atype[:, :nloc]
atype = xp_take_first_n(extended_atype, 1, nloc)
if self.pair_excl is not None:
pair_mask = self.pair_excl.build_type_exclude_mask(nlist, extended_atype)
# exclude neighbors in the nlist
Expand All @@ -268,7 +269,7 @@ def forward_common_atomic(
ret_dict = self.apply_out_stat(ret_dict, atype)

# nf x nloc
atom_mask = ext_atom_mask[:, :nloc]
atom_mask = xp_take_first_n(ext_atom_mask, 1, nloc)
if self.atom_excl is not None:
atom_mask = xp.logical_and(
atom_mask, self.atom_excl.build_type_exclude_mask(atype)
Expand Down
3 changes: 2 additions & 1 deletion deepmd/dpmodel/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from deepmd.dpmodel.array_api import (
Array,
xp_take_first_n,
)
from deepmd.dpmodel.descriptor.base_descriptor import (
BaseDescriptor,
Expand Down Expand Up @@ -178,7 +179,7 @@ def forward_atomic(

"""
nframes, nloc, nnei = nlist.shape
atype = extended_atype[:, :nloc]
atype = xp_take_first_n(extended_atype, 1, nloc)
descriptor, rot_mat, g2, h2, sw = self.descriptor(
extended_coord,
extended_atype,
Expand Down
5 changes: 3 additions & 2 deletions deepmd/dpmodel/utils/env_mat.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from deepmd.dpmodel.array_api import (
Array,
xp_take_along_axis,
xp_take_first_n,
)
from deepmd.dpmodel.utils.safe_gradient import (
safe_for_vector_norm,
Expand Down Expand Up @@ -76,7 +77,7 @@ def _make_env_mat(
# nf x nloc x nnei x 3
coord_r = xp.reshape(coord_r, (nf, nloc, nnei, 3))
# nf x nloc x 1 x 3
coord_l = xp.reshape(coord[:, :nloc, ...], (nf, -1, 1, 3))
coord_l = xp.reshape(xp_take_first_n(coord, 1, nloc), (nf, -1, 1, 3))
# nf x nloc x nnei x 3
diff = coord_r - coord_l
# nf x nloc x nnei
Expand Down Expand Up @@ -153,7 +154,7 @@ def call(
xp = array_api_compat.array_namespace(coord_ext, atype_ext, nlist)
em, diff, sw = self._call(nlist, coord_ext, radial_only)
nf, nloc, nnei = nlist.shape
atype = atype_ext[:, :nloc]
atype = xp_take_first_n(atype_ext, 1, nloc)
if davg is not None:
em -= xp.reshape(xp.take(davg, xp.reshape(atype, (-1,)), axis=0), em.shape)
if dstd is not None:
Expand Down
1 change: 1 addition & 0 deletions deepmd/pt_expt/infer/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
Loading