Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion deepmd/dpmodel/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ def _pair_tabulated_inter(

uu -= idx
table_coef = self._extract_spline_coefficient(
i_type, j_type, idx, self.tab_data, nspline
i_type, j_type, idx, self.tab_data[...], nspline
)
table_coef = xp.reshape(table_coef, (nframes, nloc, nnei, 4))
ener = self._calculate_ener(table_coef, uu)
Expand Down
6 changes: 5 additions & 1 deletion deepmd/dpmodel/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -951,7 +951,11 @@ def call(
xp = array_api_compat.array_namespace(nlist, coord_ext, atype_ext)
# nf x nloc x nnei x 4
dmatrix, diff, sw = self.env_mat.call(
coord_ext, atype_ext, nlist, self.mean, self.stddev
coord_ext,
atype_ext,
nlist,
self.mean[...],
self.stddev[...],
)
nf, nloc, nnei, _ = dmatrix.shape
atype = atype_ext[:, :nloc]
Expand Down
6 changes: 5 additions & 1 deletion deepmd/dpmodel/descriptor/repflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,11 @@ def call(
nlist = xp.where(exclude_mask, nlist, xp.full_like(nlist, -1))
# nb x nloc x nnei x 4, nb x nloc x nnei x 3, nb x nloc x nnei x 1
dmatrix, diff, sw = self.env_mat_edge.call(
coord_ext, atype_ext, nlist, self.mean, self.stddev
coord_ext,
atype_ext,
nlist,
self.mean[...],
self.stddev[...],
)
# nb x nloc x nnei
nlist_mask = nlist != -1
Expand Down
6 changes: 5 additions & 1 deletion deepmd/dpmodel/descriptor/repformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,11 @@ def call(
nlist = xp.where(exclude_mask, nlist, xp.full_like(nlist, -1))
# nf x nloc x nnei x 4
dmatrix, diff, sw = self.env_mat.call(
coord_ext, atype_ext, nlist, self.mean, self.stddev
coord_ext,
atype_ext,
nlist,
self.mean[...],
self.stddev[...],
)
nf, nloc, nnei, _ = dmatrix.shape
# nf x nloc x nnei
Expand Down
6 changes: 5 additions & 1 deletion deepmd/dpmodel/descriptor/se_e2_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,7 +591,11 @@ def call(
input_dtype = coord_ext.dtype
# nf x nloc x nnei x 4
rr, diff, ww = self.env_mat.call(
coord_ext, atype_ext, nlist, self.davg, self.dstd
coord_ext,
atype_ext,
nlist,
self.davg[...],
self.dstd[...],
)
nf, nloc, nnei, _ = rr.shape
sec = self.sel_cumsum
Expand Down
7 changes: 6 additions & 1 deletion deepmd/dpmodel/descriptor/se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,12 @@ def call(
del mapping
# nf x nloc x nnei x 1
rr, diff, ww = self.env_mat.call(
coord_ext, atype_ext, nlist, self.davg, self.dstd, True
coord_ext,
atype_ext,
nlist,
self.davg[...],
self.dstd[...],
True,
)
nf, nloc, nnei, _ = rr.shape
sec = self.sel_cumsum
Expand Down
6 changes: 5 additions & 1 deletion deepmd/dpmodel/descriptor/se_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,11 @@ def call(
xp = array_api_compat.array_namespace(coord_ext, atype_ext, nlist)
# nf x nloc x nnei x 4
rr, diff, ww = self.env_mat.call(
coord_ext, atype_ext, nlist, self.davg, self.dstd
coord_ext,
atype_ext,
nlist,
self.davg[...],
self.dstd[...],
)
nf, nloc, nnei, _ = rr.shape
sec = self.sel_cumsum
Expand Down
6 changes: 5 additions & 1 deletion deepmd/dpmodel/descriptor/se_t_tebd.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,7 +733,11 @@ def call(
xp = array_api_compat.array_namespace(nlist, coord_ext, atype_ext)
# nf x nloc x nnei x 4
dmatrix, diff, sw = self.env_mat.call(
coord_ext, atype_ext, nlist, self.mean, self.stddev
coord_ext,
atype_ext,
nlist,
self.mean[...],
self.stddev[...],
)
nf, nloc, nnei, _ = dmatrix.shape
exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext)
Expand Down
12 changes: 8 additions & 4 deletions deepmd/dpmodel/fitting/general_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ def _call_common(
f"get an input fparam of dim {fparam.shape[-1]}, "
f"which is not consistent with {self.numb_fparam}."
)
fparam = (fparam - self.fparam_avg) * self.fparam_inv_std
fparam = (fparam - self.fparam_avg[...]) * self.fparam_inv_std[...]
fparam = xp.tile(
xp.reshape(fparam, [nf, 1, self.numb_fparam]), (1, nloc, 1)
)
Expand All @@ -432,7 +432,7 @@ def _call_common(
f"which is not consistent with {self.numb_aparam}."
)
aparam = xp.reshape(aparam, [nf, nloc, self.numb_aparam])
aparam = (aparam - self.aparam_avg) * self.aparam_inv_std
aparam = (aparam - self.aparam_avg[...]) * self.aparam_inv_std[...]
xx = xp.concat(
[xx, aparam],
axis=-1,
Expand All @@ -445,7 +445,9 @@ def _call_common(

if self.dim_case_embd > 0:
assert self.case_embd is not None
case_embd = xp.tile(xp.reshape(self.case_embd, [1, 1, -1]), [nf, nloc, 1])
case_embd = xp.tile(
xp.reshape(self.case_embd[...], [1, 1, -1]), [nf, nloc, 1]
)
xx = xp.concat(
[xx, case_embd],
axis=-1,
Expand Down Expand Up @@ -482,7 +484,9 @@ def _call_common(
outs -= self.nets[()](xx_zeros)
outs += xp.reshape(
xp.take(
xp.astype(self.bias_atom_e, outs.dtype), xp.reshape(atype, [-1]), axis=0
xp.astype(self.bias_atom_e[...], outs.dtype),
xp.reshape(atype, [-1]),
axis=0,
),
[nf, nloc, net_dim_out],
)
Expand Down
6 changes: 4 additions & 2 deletions deepmd/dpmodel/utils/exclude_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ def build_type_exclude_mask(
xp = array_api_compat.array_namespace(atype)
nf, natom = atype.shape
return xp.reshape(
xp.take(self.type_mask, xp.reshape(atype, [-1]), axis=0), (nf, natom)
xp.take(self.type_mask[...], xp.reshape(atype, [-1]), axis=0),
(nf, natom),
)


Expand Down Expand Up @@ -131,7 +132,8 @@ def build_type_exclude_mask(
# nf x (nloc x nnei)
type_ij = xp.reshape(type_ij, (nf, nloc * nnei))
mask = xp.reshape(
xp.take(self.type_mask, xp.reshape(type_ij, (-1,))), (nf, nloc, nnei)
xp.take(self.type_mask[...], xp.reshape(type_ij, (-1,))),
(nf, nloc, nnei),
)
return mask

Expand Down
4 changes: 2 additions & 2 deletions deepmd/dpmodel/utils/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,9 +259,9 @@ def call(self, x: np.ndarray) -> np.ndarray:
xp = array_api_compat.array_namespace(x)
fn = get_activation_fn(self.activation_function)
y = (
xp.matmul(x, self.w) + self.b
xp.matmul(x, self.w[...]) + self.b[...]
if self.b is not None
else xp.matmul(x, self.w)
else xp.matmul(x, self.w[...])
)
if y.dtype != x.dtype:
# workaround for bfloat16
Expand Down