Skip to content

Commit c2e6c94

Browse files
committed
fix(tf): fix dplr Python inference (#4753)
Fix #4625. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **Bug Fixes** - Improved handling of modifier type decoding for more accurate output. - Corrected model type checking and output dictionary keys during evaluation. - Enhanced input sorting logic for atom types with multiple dimensions. - **Tests** - Updated test data and expected values to reflect recent changes. - Adjusted test logic to align with updated evaluation behavior. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn> (cherry picked from commit 2fa4064)
1 parent ab0396d commit c2e6c94

File tree

3 files changed

+52539
-17
lines changed

3 files changed

+52539
-17
lines changed

deepmd/tf/infer/deep_eval.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,9 @@ def _init_attr(self) -> None:
259259
self.numb_dos = 0
260260
self.tmap = tmap.decode("utf-8").split()
261261
if self.tensors["modifier_type"] is not None:
262-
self.modifier_type = run_sess(self.sess, [self.tensors["modifier_type"]])[0]
262+
self.modifier_type = run_sess(self.sess, [self.tensors["modifier_type"]])[
263+
0
264+
].decode()
263265
else:
264266
self.modifier_type = None
265267

@@ -761,15 +763,17 @@ def eval(
761763
odef.name: oo for oo, odef in zip(output, self.output_def.var_defs.values())
762764
}
763765
# ugly!!
764-
if self.modifier_type is not None and isinstance(self.model_type, DeepPot):
766+
if self.modifier_type is not None and issubclass(self.model_type, DeepPot):
765767
if atomic:
766768
raise RuntimeError("modifier does not support atomic modification")
767769
me, mf, mv = self.dm.eval(coords, cells, atom_types)
768-
output = list(output) # tuple to list
769-
e, f, v = output[:3]
770-
output_dict["energy_redu"] += me.reshape(e.shape)
771-
output_dict["energy_deri_r"] += mf.reshape(f.shape)
772-
output_dict["energy_deri_c_redu"] += mv.reshape(v.shape)
770+
output_dict["energy_redu"] += me.reshape(output_dict["energy_redu"].shape)
771+
output_dict["energy_derv_r"] += mf.reshape(
772+
output_dict["energy_derv_r"].shape
773+
)
774+
output_dict["energy_derv_c_redu"] += mv.reshape(
775+
output_dict["energy_derv_c_redu"].shape
776+
)
773777
return output_dict
774778

775779
def _prepare_feed_dict(
@@ -1350,6 +1354,8 @@ def sort_input(
13501354
natoms = atom_type[0].size
13511355
idx_map = np.arange(natoms) # pylint: disable=no-explicit-dtype
13521356
return coord, atom_type, idx_map
1357+
if atom_type.ndim > 1:
1358+
atom_type = atom_type[0]
13531359
if sel_atoms is not None:
13541360
selection = [False] * np.size(atom_type)
13551361
for ii in sel_atoms:

0 commit comments

Comments
 (0)