Skip to content

Commit 4bd8ac8

Browse files
authored
Merge branch 'devel' into D0516_dynamic_sel
2 parents 6b77202 + cb78ec0 commit 4bd8ac8

6 files changed

Lines changed: 52546 additions & 20 deletions

File tree

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ repos:
2929
exclude: ^source/3rdparty
3030
- repo: https://github.com/astral-sh/ruff-pre-commit
3131
# Ruff version.
32-
rev: v0.11.9
32+
rev: v0.11.10
3333
hooks:
3434
- id: ruff
3535
args: ["--fix"]

deepmd/pd/model/model/transform_output.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,9 @@ def task_deriv_one(
8282
assert extended_force is not None
8383
extended_force = -extended_force
8484
if do_virial:
85-
extended_virial = extended_force.unsqueeze(-1) @ extended_coord.unsqueeze(-2)
85+
extended_virial = paddle.einsum(
86+
"...ik,...ij->...ikj", extended_force, extended_coord
87+
)
8688
# the correction sums to zero, which does not contribute to global virial
8789
if do_atomic_virial:
8890
extended_virial_corr = atomic_virial_corr(extended_coord, atom_energy)

deepmd/pt/model/model/transform_output.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,9 @@ def task_deriv_one(
8585
assert extended_force is not None
8686
extended_force = -extended_force
8787
if do_virial:
88-
extended_virial = extended_force.unsqueeze(-1) @ extended_coord.unsqueeze(-2)
88+
extended_virial = torch.einsum(
89+
"...ik,...ij->...ikj", extended_force, extended_coord
90+
)
8991
# the correction sums to zero, which does not contribute to global virial
9092
if do_atomic_virial:
9193
extended_virial_corr = atomic_virial_corr(extended_coord, atom_energy)

deepmd/tf/infer/deep_eval.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,9 @@ def _init_attr(self) -> None:
259259
self.numb_dos = 0
260260
self.tmap = tmap.decode("utf-8").split()
261261
if self.tensors["modifier_type"] is not None:
262-
self.modifier_type = run_sess(self.sess, [self.tensors["modifier_type"]])[0]
262+
self.modifier_type = run_sess(self.sess, [self.tensors["modifier_type"]])[
263+
0
264+
].decode()
263265
else:
264266
self.modifier_type = None
265267

@@ -761,15 +763,17 @@ def eval(
761763
odef.name: oo for oo, odef in zip(output, self.output_def.var_defs.values())
762764
}
763765
# ugly!!
764-
if self.modifier_type is not None and isinstance(self.model_type, DeepPot):
766+
if self.modifier_type is not None and issubclass(self.model_type, DeepPot):
765767
if atomic:
766768
raise RuntimeError("modifier does not support atomic modification")
767769
me, mf, mv = self.dm.eval(coords, cells, atom_types)
768-
output = list(output) # tuple to list
769-
e, f, v = output[:3]
770-
output_dict["energy_redu"] += me.reshape(e.shape)
771-
output_dict["energy_deri_r"] += mf.reshape(f.shape)
772-
output_dict["energy_deri_c_redu"] += mv.reshape(v.shape)
770+
output_dict["energy_redu"] += me.reshape(output_dict["energy_redu"].shape)
771+
output_dict["energy_derv_r"] += mf.reshape(
772+
output_dict["energy_derv_r"].shape
773+
)
774+
output_dict["energy_derv_c_redu"] += mv.reshape(
775+
output_dict["energy_derv_c_redu"].shape
776+
)
773777
return output_dict
774778

775779
def _prepare_feed_dict(
@@ -1348,6 +1352,8 @@ def sort_input(
13481352
natoms = atom_type[0].size
13491353
idx_map = np.arange(natoms) # pylint: disable=no-explicit-dtype
13501354
return coord, atom_type, idx_map
1355+
if atom_type.ndim > 1:
1356+
atom_type = atom_type[0]
13511357
if sel_atoms is not None:
13521358
selection = [False] * np.size(atom_type)
13531359
for ii in sel_atoms:

0 commit comments

Comments
 (0)