Skip to content

Commit f513b7b

Browse files
committed
fix jax reshape list
1 parent 005783f commit f513b7b

12 files changed

Lines changed: 32 additions & 32 deletions

File tree

deepmd/dpmodel/descriptor/dpa1.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -520,7 +520,7 @@ def call(
520520
type_embedding = self.type_embedding.call()
521521
# nf x nall x tebd_dim
522522
atype_embd_ext = xp.reshape(
523-
xp.take(type_embedding, xp.reshape(atype_ext, [-1]), axis=0),
523+
xp.take(type_embedding, xp.reshape(atype_ext, (-1,)), axis=0),
524524
(nf, nall, self.tebd_dim),
525525
)
526526
# nfnl x tebd_dim

deepmd/dpmodel/descriptor/dpa2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -841,7 +841,7 @@ def call(
841841
type_embedding = self.type_embedding.call()
842842
# repinit
843843
g1_ext = xp.reshape(
844-
xp.take(type_embedding, xp.reshape(atype_ext, [-1]), axis=0),
844+
xp.take(type_embedding, xp.reshape(atype_ext, (-1,)), axis=0),
845845
(nframes, nall, self.tebd_dim),
846846
)
847847
g1_inp = g1_ext[:, :nloc, :]

deepmd/dpmodel/descriptor/dpa3.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -562,12 +562,12 @@ def call(
562562
type_embedding = self.type_embedding.call()
563563
if self.use_loc_mapping:
564564
node_ebd_ext = xp.reshape(
565-
xp.take(type_embedding, xp.reshape(atype_ext[:, :nloc], [-1]), axis=0),
565+
xp.take(type_embedding, xp.reshape(atype_ext[:, :nloc], (-1,)), axis=0),
566566
(nframes, nloc, self.tebd_dim),
567567
)
568568
else:
569569
node_ebd_ext = xp.reshape(
570-
xp.take(type_embedding, xp.reshape(atype_ext, [-1]), axis=0),
570+
xp.take(type_embedding, xp.reshape(atype_ext, (-1,)), axis=0),
571571
(nframes, nall, self.tebd_dim),
572572
)
573573
node_ebd_inp = node_ebd_ext[:, :nloc, :]

deepmd/dpmodel/descriptor/se_t_tebd.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,7 @@ def call(
358358
type_embedding = self.type_embedding.call()
359359
# nf x nall x tebd_dim
360360
atype_embd_ext = xp.reshape(
361-
xp.take(type_embedding, xp.reshape(atype_ext, [-1]), axis=0),
361+
xp.take(type_embedding, xp.reshape(atype_ext, (-1,)), axis=0),
362362
(nf, nall, self.tebd_dim),
363363
)
364364
# nfnl x tebd_dim

deepmd/dpmodel/fitting/polarizability_fitting.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ def call(
289289
]
290290
# out = out * self.scale[atype, ...]
291291
scale_atype = xp.reshape(
292-
xp.take(xp.astype(self.scale, out.dtype), xp.reshape(atype, [-1]), axis=0),
292+
xp.take(xp.astype(self.scale, out.dtype), xp.reshape(atype, (-1,)), axis=0),
293293
(*atype.shape, 1),
294294
)
295295
out = out * scale_atype
@@ -315,7 +315,7 @@ def call(
315315
bias = xp.reshape(
316316
xp.take(
317317
xp.astype(self.constant_matrix, out.dtype),
318-
xp.reshape(atype, [-1]),
318+
xp.reshape(atype, (-1,)),
319319
axis=0,
320320
),
321321
(nframes, nloc),

deepmd/dpmodel/loss/ener.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -132,18 +132,18 @@ def call(
132132
atom_ener_coeff = xp.reshape(atom_ener_coeff, xp.shape(atom_ener))
133133
energy = xp.sum(atom_ener_coeff * atom_ener, 1)
134134
if self.has_f or self.has_pf or self.relative_f or self.has_gf:
135-
force_reshape = xp.reshape(force, [-1])
136-
force_hat_reshape = xp.reshape(force_hat, [-1])
135+
force_reshape = xp.reshape(force, (-1,))
136+
force_hat_reshape = xp.reshape(force_hat, (-1,))
137137
diff_f = force_hat_reshape - force_reshape
138138
else:
139139
diff_f = None
140140

141141
if self.relative_f is not None:
142-
force_hat_3 = xp.reshape(force_hat, [-1, 3])
143-
norm_f = xp.reshape(xp.norm(force_hat_3, axis=1), [-1, 1]) + self.relative_f
144-
diff_f_3 = xp.reshape(diff_f, [-1, 3])
142+
force_hat_3 = xp.reshape(force_hat, (-1, 3))
143+
norm_f = xp.reshape(xp.norm(force_hat_3, axis=1), (-1, 1)) + self.relative_f
144+
diff_f_3 = xp.reshape(diff_f, (-1, 3))
145145
diff_f_3 = diff_f_3 / norm_f
146-
diff_f = xp.reshape(diff_f_3, [-1])
146+
diff_f = xp.reshape(diff_f_3, (-1,))
147147

148148
atom_norm = 1.0 / natoms
149149
atom_norm_ener = 1.0 / natoms
@@ -184,15 +184,15 @@ def call(
184184
loss += pref_f * l2_force_loss
185185
else:
186186
l_huber_loss = custom_huber_loss(
187-
xp.reshape(force, [-1]),
188-
xp.reshape(force_hat, [-1]),
187+
xp.reshape(force, (-1,)),
188+
xp.reshape(force_hat, (-1,)),
189189
delta=self.huber_delta,
190190
)
191191
loss += pref_f * l_huber_loss
192192
more_loss["rmse_f"] = self.display_if_exist(l2_force_loss, find_force)
193193
if self.has_v:
194-
virial_reshape = xp.reshape(virial, [-1])
195-
virial_hat_reshape = xp.reshape(virial_hat, [-1])
194+
virial_reshape = xp.reshape(virial, (-1,))
195+
virial_hat_reshape = xp.reshape(virial_hat, (-1,))
196196
l2_virial_loss = xp.mean(
197197
xp.square(virial_hat_reshape - virial_reshape),
198198
)
@@ -207,8 +207,8 @@ def call(
207207
loss += pref_v * l_huber_loss
208208
more_loss["rmse_v"] = self.display_if_exist(l2_virial_loss, find_virial)
209209
if self.has_ae:
210-
atom_ener_reshape = xp.reshape(atom_ener, [-1])
211-
atom_ener_hat_reshape = xp.reshape(atom_ener_hat, [-1])
210+
atom_ener_reshape = xp.reshape(atom_ener, (-1,))
211+
atom_ener_hat_reshape = xp.reshape(atom_ener_hat, (-1,))
212212
l2_atom_ener_loss = xp.mean(
213213
xp.square(atom_ener_hat_reshape - atom_ener_reshape),
214214
)
@@ -225,7 +225,7 @@ def call(
225225
l2_atom_ener_loss, find_atom_ener
226226
)
227227
if self.has_pf:
228-
atom_pref_reshape = xp.reshape(atom_pref, [-1])
228+
atom_pref_reshape = xp.reshape(atom_pref, (-1,))
229229
l2_pref_force_loss = xp.mean(
230230
xp.multiply(xp.square(diff_f), atom_pref_reshape),
231231
)
@@ -236,10 +236,10 @@ def call(
236236
if self.has_gf:
237237
find_drdq = label_dict["find_drdq"]
238238
drdq = label_dict["drdq"]
239-
force_reshape_nframes = xp.reshape(force, [-1, natoms[0] * 3])
240-
force_hat_reshape_nframes = xp.reshape(force_hat, [-1, natoms[0] * 3])
239+
force_reshape_nframes = xp.reshape(force, (-1, natoms[0] * 3))
240+
force_hat_reshape_nframes = xp.reshape(force_hat, (-1, natoms[0] * 3))
241241
drdq_reshape = xp.reshape(
242-
drdq, [-1, natoms[0] * 3, self.numb_generalized_coord]
242+
drdq, (-1, natoms[0] * 3, self.numb_generalized_coord)
243243
)
244244
gen_force_hat = xp.einsum(
245245
"bij,bi->bj", drdq_reshape, force_hat_reshape_nframes

deepmd/dpmodel/model/transform_output.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def communicate_extended_output(
100100
if vdef.r_differentiable:
101101
if model_ret[kk_derv_r] is not None:
102102
derv_r_ext_dims = list(vdef.shape) + [3] # noqa:RUF005
103-
mapping = xp.reshape(mapping, (mldims + [1] * len(derv_r_ext_dims)))
103+
mapping = xp.reshape(mapping, tuple(mldims + [1] * len(derv_r_ext_dims)))
104104
mapping = xp.tile(mapping, [1] * len(mldims) + derv_r_ext_dims)
105105
force = xp.zeros(vldims + derv_r_ext_dims, dtype=vv.dtype)
106106
force = xp_scatter_sum(

deepmd/dpmodel/utils/env_mat_stat.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ def iter(
189189
for type_i in range(self.descriptor.get_ntypes()):
190190
dd = env_mat[type_idx[type_i, ...]]
191191
dd = xp.reshape(
192-
dd, [-1, self.last_dim]
192+
dd, (-1, self.last_dim)
193193
) # typen_atoms * unmasked_nnei, 4
194194
env_mats = {}
195195
env_mats[f"r_{type_i}"] = dd[:, :1]

deepmd/dpmodel/utils/exclude_mask.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def build_type_exclude_mask(
5353
xp = array_api_compat.array_namespace(atype)
5454
nf, natom = atype.shape
5555
return xp.reshape(
56-
xp.take(self.type_mask[...], xp.reshape(atype, [-1]), axis=0),
56+
xp.take(self.type_mask[...], xp.reshape(atype, (-1,)), axis=0),
5757
(nf, natom),
5858
)
5959

deepmd/dpmodel/utils/neighbor_stat.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,8 @@ def call(
8282
nall = coord1.shape[1] // 3
8383
coord0 = coord1[:, : nloc * 3]
8484
diff = (
85-
xp.reshape(coord1, [nframes, -1, 3])[:, None, :, :]
86-
- xp.reshape(coord0, [nframes, -1, 3])[:, :, None, :]
85+
xp.reshape(coord1, (nframes, -1, 3))[:, None, :, :]
86+
- xp.reshape(coord0, (nframes, -1, 3))[:, :, None, :]
8787
)
8888
assert list(diff.shape) == [nframes, nloc, nall, 3]
8989
# remove the diagonal elements

0 commit comments

Comments
 (0)