Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion deepmd/dpmodel/array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ def xp_swapaxes(a: Array, axis1: int, axis2: int) -> Array:

def xp_take_along_axis(arr: Array, indices: Array, axis: int) -> Array:
xp = array_api_compat.array_namespace(arr)
# torch.take_along_dim requires int64 indices
if array_api_compat.is_torch_array(indices):
indices = xp.astype(indices, xp.int64)
if Version(xp.__array_api_version__) >= Version("2024.12"):
# see: https://github.com/data-apis/array-api-strict/blob/d086c619a58f35c38240592ef994aa19ca7beebc/array_api_strict/_indexing_functions.py#L30-L39
return xp.take_along_axis(arr, indices, axis=axis)
Expand All @@ -48,7 +51,10 @@ def xp_take_along_axis(arr: Array, indices: Array, axis: int) -> Array:
else:
indices = xp.reshape(indices, (0, 0))

offset = (xp.arange(indices.shape[0], dtype=indices.dtype) * m)[:, xp.newaxis]
dev = array_api_compat.device(indices)
offset = (xp.arange(indices.shape[0], dtype=indices.dtype, device=dev) * m)[
:, xp.newaxis
]
indices = xp.reshape(offset + indices, (-1,))

out = xp.take(arr, indices)
Expand Down
4 changes: 3 additions & 1 deletion deepmd/dpmodel/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,9 +337,11 @@ def _compute_weight(
xp = array_api_compat.array_namespace(extended_coord, extended_atype, nlists_)
nmodels = len(self.models)
nframes, nloc, _ = nlists_[0].shape
dev = array_api_compat.device(extended_coord)
# the dtype of weights is the interface data type.
return [
xp.ones((nframes, nloc, 1), dtype=GLOBAL_NP_FLOAT_PRECISION) / nmodels
xp.ones((nframes, nloc, 1), dtype=GLOBAL_NP_FLOAT_PRECISION, device=dev)
/ nmodels
for _ in range(nmodels)
]

Expand Down
10 changes: 8 additions & 2 deletions deepmd/dpmodel/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,11 @@ def forward_atomic(
) # (nframes, nloc, nnei)

# (nframes, nloc, nnei), index type is int64.
dev = array_api_compat.device(extended_atype)
j_type = extended_atype[
xp.arange(extended_atype.shape[0], dtype=xp.int64)[:, None, None],
xp.arange(extended_atype.shape[0], dtype=xp.int64, device=dev)[
:, None, None
],
masked_nlist,
]

Expand Down Expand Up @@ -327,8 +330,11 @@ def _get_pairwise_dist(coords: Array, nlist: Array) -> Array:
The pairwise distance between the atoms (nframes, nloc, nnei).
"""
xp = array_api_compat.array_namespace(coords, nlist)
dev = array_api_compat.device(nlist)
# index type is int64
batch_indices = xp.arange(nlist.shape[0], dtype=xp.int64)[:, None, None]
batch_indices = xp.arange(nlist.shape[0], dtype=xp.int64, device=dev)[
:, None, None
]
neighbor_atoms = coords[batch_indices, nlist]
loc_atoms = coords[:, : nlist.shape[1], :]
pairwise_dr = loc_atoms[:, :, None, :] - neighbor_atoms
Expand Down
1 change: 1 addition & 0 deletions deepmd/dpmodel/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,6 +573,7 @@ def init_subclass_params(sub_data: dict | Any, sub_class: type) -> Any:
self.smooth = smooth
self.exclude_types = exclude_types
self.env_protection = env_protection
self.rcut_smth = self.repinit.get_rcut_smth()
self.trainable = trainable
self.add_tebd_to_repinit_out = add_tebd_to_repinit_out

Expand Down
10 changes: 8 additions & 2 deletions deepmd/dpmodel/descriptor/repflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,8 +595,12 @@ def call(
# n_angle x 1
a_sw = (a_sw[:, :, :, None] * a_sw[:, :, None, :])[a_nlist_mask]
else:
edge_index = xp.zeros([2, 1], dtype=nlist.dtype)
angle_index = xp.zeros([3, 1], dtype=nlist.dtype)
edge_index = xp.zeros(
[2, 1], dtype=nlist.dtype, device=array_api_compat.device(nlist)
)
angle_index = xp.zeros(
[3, 1], dtype=nlist.dtype, device=array_api_compat.device(nlist)
)

# get edge and angle embedding
# nb x nloc x nnei x e_dim [OR] n_edge x e_dim
Expand Down Expand Up @@ -1711,6 +1715,7 @@ def call(
xp.zeros(
(nb, nloc, self.nnei - self.a_sel, self.e_dim),
dtype=edge_ebd.dtype,
device=array_api_compat.device(edge_ebd),
),
],
axis=2,
Expand Down Expand Up @@ -1741,6 +1746,7 @@ def call(
xp.zeros(
(nb, nloc, self.nnei - self.a_sel),
dtype=a_nlist_mask.dtype,
device=array_api_compat.device(a_nlist_mask),
),
],
axis=-1,
Expand Down
12 changes: 8 additions & 4 deletions deepmd/dpmodel/descriptor/repformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,7 @@ def call(
g1 = self.act(atype_embd)
# nf x nloc x nnei x 1, nf x nloc x nnei x 3
if not self.direct_dist:
g2, h2 = xp.split(dmatrix, [1], axis=-1)
g2, h2 = dmatrix[..., :1], dmatrix[..., 1:]
else:
g2, h2 = safe_for_vector_norm(diff, axis=-1, keepdims=True), diff
g2 = g2 / self.rcut
Expand Down Expand Up @@ -756,10 +756,12 @@ def _cal_hg(
else:
g = _apply_switch(g, sw)
if not use_sqrt_nnei:
invnnei = (1.0 / float(nnei)) * xp.ones((nf, nloc, 1, 1), dtype=g.dtype)
invnnei = (1.0 / float(nnei)) * xp.ones(
(nf, nloc, 1, 1), dtype=g.dtype, device=array_api_compat.device(g)
)
else:
invnnei = (1.0 / (float(nnei) ** 0.5)) * xp.ones(
(nf, nloc, 1, 1), dtype=g.dtype
(nf, nloc, 1, 1), dtype=g.dtype, device=array_api_compat.device(g)
)
# nf x nloc x 3 x ng
hg = xp.matmul(xp.matrix_transpose(h), g) * invnnei
Expand Down Expand Up @@ -1655,7 +1657,9 @@ def _update_g1_conv(
invnnei = invnnei[:, :, xp.newaxis]
else:
gg1 = _apply_switch(gg1, sw)
invnnei = (1.0 / float(nnei)) * xp.ones((nf, nloc, 1), dtype=gg1.dtype)
invnnei = (1.0 / float(nnei)) * xp.ones(
(nf, nloc, 1), dtype=gg1.dtype, device=array_api_compat.device(gg1)
)
if not self.g1_out_conv:
# nf x nloc x ng2
g1_11 = xp.sum(g2 * gg1, axis=2) * invnnei
Expand Down
181 changes: 63 additions & 118 deletions deepmd/dpmodel/descriptor/se_e2_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,9 +138,8 @@ class DescrptSeA(NativeOP, BaseDescriptor):
-----------
The current implementation does not support the following features

1. type_one_side == False
2. exclude_types != []
3. spin is not None
1. exclude_types != []
2. spin is not None

References
----------
Expand Down Expand Up @@ -427,45 +426,81 @@ def call(
The smooth switch function.
"""
del mapping
xp = array_api_compat.array_namespace(coord_ext, atype_ext, nlist)
input_dtype = coord_ext.dtype
# nf x nloc x nnei x 4
rr, diff, ww = self.env_mat.call(
coord_ext, atype_ext, nlist, self.davg, self.dstd
coord_ext,
atype_ext,
nlist,
self.davg[...],
self.dstd[...],
)
nf, nloc, nnei, _ = rr.shape
sec = np.append([0], np.cumsum(self.sel))
sec = self.sel_cumsum

ng = self.neuron[-1]
gr = np.zeros([nf * nloc, ng, 4], dtype=PRECISION_DICT[self.precision])
gr = xp.zeros(
[nf * nloc, ng, 4],
dtype=input_dtype,
device=array_api_compat.device(coord_ext),
)
exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext)
# merge nf and nloc axis, so for type_one_side == False,
# we don't require atype is the same in all frames
exclude_mask = exclude_mask.reshape(nf * nloc, nnei)
rr = rr.reshape(nf * nloc, nnei, 4)
exclude_mask = xp.reshape(exclude_mask, (nf * nloc, nnei))
rr = xp.reshape(rr, (nf * nloc, nnei, 4))
rr = xp.astype(rr, self.dstd.dtype)

for embedding_idx in itertools.product(
range(self.ntypes), repeat=self.embeddings.ndim
):
if self.type_one_side:
(tt,) = embedding_idx
ti_mask = np.s_[:]
else:
ti, tt = embedding_idx
ti_mask = atype_ext[:, :nloc].ravel() == ti
mm = exclude_mask[ti_mask, sec[tt] : sec[tt + 1]]
tr = rr[ti_mask, sec[tt] : sec[tt + 1], :]
tr = tr * mm[:, :, None]
ss = tr[..., 0:1]
gg = self.cal_g(ss, embedding_idx)
gr_tmp = np.einsum("lni,lnj->lij", gg, tr)
gr[ti_mask] += gr_tmp
gr = gr.reshape(nf, nloc, ng, 4)
if self.type_one_side:
for tt in range(self.ntypes):
mm = exclude_mask[:, sec[tt] : sec[tt + 1]]
tr = rr[:, sec[tt] : sec[tt + 1], :]
tr = tr * xp.astype(mm[:, :, None], tr.dtype)
ss = tr[..., 0:1]
gg = self.cal_g(ss, (tt,))
gr += xp.sum(gg[:, :, :, None] * tr[:, :, None, :], axis=1)
else:
# Sort atoms by center type so each type forms a contiguous block.
# Slice indexing (arr[s:e]) is array-api compatible and lets us
# run cal_g only on atoms of the matching center type, keeping the
# same O(nf*nloc) total embedding cost as the original numpy code.
atype_loc = xp.reshape(atype_ext[:, :nloc], (nf * nloc,))
sort_idx = xp.argsort(atype_loc)
unsort_idx = xp.argsort(sort_idx)
rr_s = xp.take(rr, sort_idx, axis=0)
mask_s = xp.take(exclude_mask, sort_idx, axis=0)
dev = array_api_compat.device(coord_ext)
gr_s = xp.zeros([nf * nloc, ng, 4], dtype=input_dtype, device=dev)
# Per-type boundaries in sorted order
type_ends = []
offset = 0
for ti in range(self.ntypes):
offset += int(xp.sum(xp.astype(atype_loc == ti, xp.int32)))
type_ends.append(offset)
type_starts = [0, *type_ends[:-1]]
for ti in range(self.ntypes):
s, e = type_starts[ti], type_ends[ti]
if s == e:
continue
for tt in range(self.ntypes):
mm = mask_s[s:e, sec[tt] : sec[tt + 1]]
tr = rr_s[s:e, sec[tt] : sec[tt + 1], :]
tr = tr * xp.astype(mm[:, :, None], tr.dtype)
ss = tr[..., 0:1]
gg = self.cal_g(ss, (ti, tt))
gr_s[s:e] = gr_s[s:e] + xp.sum(
gg[:, :, :, None] * tr[:, :, None, :], axis=1
)
gr = xp.take(gr_s, unsort_idx, axis=0)
Comment thread
wanghan-iapcm marked this conversation as resolved.
gr = xp.reshape(gr, (nf, nloc, ng, 4))
# nf x nloc x ng x 4
gr /= self.nnei
gr1 = gr[:, :, : self.axis_neuron, :]
# nf x nloc x ng x ng1
grrg = np.einsum("flid,fljd->flij", gr, gr1)
grrg = xp.sum(gr[:, :, :, None, :] * gr1[:, :, None, :, :], axis=4)
# nf x nloc x (ng x ng1)
grrg = grrg.reshape(nf, nloc, ng * self.axis_neuron)
grrg = xp.reshape(grrg, (nf, nloc, ng * self.axis_neuron))
return grrg, gr[..., 1:], None, None, ww

def serialize(self) -> dict:
Expand Down Expand Up @@ -553,94 +588,4 @@ def update_sel(
return local_jdata_cpy, min_nbor_dist


class DescrptSeAArrayAPI(DescrptSeA):
    """Variant of ``DescrptSeA`` whose ``call`` uses only array-API operations.

    Only ``type_one_side == True`` is handled; ``call`` raises otherwise.
    """

    @cast_precision
    def call(
        self,
        coord_ext: Array,
        atype_ext: Array,
        nlist: Array,
        mapping: Array | None = None,
    ) -> tuple[Array, Array, None, None, Array]:
        """Compute the descriptor.

        Parameters
        ----------
        coord_ext
            The extended coordinates of atoms. shape: nf x (nallx3)
        atype_ext
            The extended atom types. shape: nf x nall
        nlist
            The neighbor list. shape: nf x nloc x nnei
        mapping
            The index mapping from extended to local region. Not used by this
            descriptor.

        Returns
        -------
        descriptor
            The descriptor. shape: nf x nloc x (ng x axis_neuron)
        gr
            The rotationally equivariant and permutationally invariant single
            particle representation. shape: nf x nloc x ng x 3
        g2
            The rotationally invariant pair-particle representation.
            This descriptor returns None.
        h2
            The rotationally equivariant pair-particle representation.
            This descriptor returns None.
        sw
            The smooth switch function.

        Raises
        ------
        NotImplementedError
            If ``self.type_one_side`` is False.
        """
        if not self.type_one_side:
            raise NotImplementedError(
                "type_one_side == False is not supported in DescrptSeAArrayAPI"
            )
        del mapping
        xp = array_api_compat.array_namespace(coord_ext, atype_ext, nlist)
        input_dtype = coord_ext.dtype
        # rr: nf x nloc x nnei x 4; the second return value is not needed here.
        rr, _, ww = self.env_mat.call(
            coord_ext,
            atype_ext,
            nlist,
            self.davg[...],
            self.dstd[...],
        )
        nf, nloc, nnei, _ = rr.shape
        # sec[tt]:sec[tt + 1] is the slice of the neighbor axis used for type tt.
        sec = self.sel_cumsum

        ng = self.neuron[-1]
        # Allocate the accumulator on the same device as the input coordinates.
        gr = xp.zeros(
            [nf * nloc, ng, 4],
            dtype=input_dtype,
            device=array_api_compat.device(coord_ext),
        )
        exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext)
        # Merge the nf and nloc axes so every atom is handled as one row.
        exclude_mask = xp.reshape(exclude_mask, (nf * nloc, nnei))
        rr = xp.reshape(rr, (nf * nloc, nnei, 4))
        rr = xp.astype(rr, self.dstd.dtype)

        for embedding_idx in itertools.product(
            range(self.ntypes), repeat=self.embeddings.ndim
        ):
            # Each index is unpacked as a 1-tuple holding the neighbor type.
            (tt,) = embedding_idx
            mm = exclude_mask[:, sec[tt] : sec[tt + 1]]
            tr = rr[:, sec[tt] : sec[tt + 1], :]
            # Zero out excluded neighbors before building the embedding input.
            tr = tr * xp.astype(mm[:, :, None], tr.dtype)
            ss = tr[..., 0:1]
            gg = self.cal_g(ss, embedding_idx)
            # Broadcast-multiply and sum over the neighbor axis; equivalent to
            # einsum("lni,lnj->lij", gg, tr).
            gr_tmp = xp.sum(gg[:, :, :, None] * tr[:, :, None, :], axis=1)
            gr += gr_tmp
        gr = xp.reshape(gr, (nf, nloc, ng, 4))
        # nf x nloc x ng x 4
        gr /= self.nnei
        gr1 = gr[:, :, : self.axis_neuron, :]
        # nf x nloc x ng x ng1; equivalent to einsum("flid,fljd->flij", gr, gr1).
        grrg = xp.sum(gr[:, :, :, None, :] * gr1[:, :, None, :, :], axis=4)
        # nf x nloc x (ng x ng1)
        grrg = xp.reshape(grrg, (nf, nloc, ng * self.axis_neuron))
        return grrg, gr[..., 1:], None, None, ww
DescrptSeAArrayAPI = DescrptSeA
7 changes: 6 additions & 1 deletion deepmd/dpmodel/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,7 +540,12 @@ def _format_nlist(
ret = xp.concat(
[
nlist,
-1 * xp.ones([n_nf, n_nloc, nnei - n_nnei], dtype=nlist.dtype),
-1
* xp.ones(
[n_nf, n_nloc, nnei - n_nnei],
dtype=nlist.dtype,
device=array_api_compat.device(nlist),
),
],
axis=-1,
)
Expand Down
Loading
Loading