Skip to content

Commit 9c9f463

Browse files
committed
optimize tf modifier in deepeval
1 parent b98f6c5 commit 9c9f463

2 files changed

Lines changed: 45 additions & 31 deletions

File tree

deepmd/tf/infer/deep_eval.py

Lines changed: 10 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -137,37 +137,15 @@ def __init__(
137137
self.has_aparam = self.tensors["aparam"] is not None
138138
self.has_spin = self.ntypes_spin > 0
139139

140-
# looks ugly...
141-
if self.modifier_type == "dipole_charge":
142-
from deepmd.tf.modifier import (
143-
DipoleChargeModifier,
144-
)
140+
from deepmd.tf.modifier import (
141+
BaseModifier,
142+
)
145143

146-
t_mdl_name = self._get_tensor("modifier_attr/mdl_name:0")
147-
t_mdl_charge_map = self._get_tensor("modifier_attr/mdl_charge_map:0")
148-
t_sys_charge_map = self._get_tensor("modifier_attr/sys_charge_map:0")
149-
t_ewald_h = self._get_tensor("modifier_attr/ewald_h:0")
150-
t_ewald_beta = self._get_tensor("modifier_attr/ewald_beta:0")
151-
[mdl_name, mdl_charge_map, sys_charge_map, ewald_h, ewald_beta] = run_sess(
152-
self.sess,
153-
[
154-
t_mdl_name,
155-
t_mdl_charge_map,
156-
t_sys_charge_map,
157-
t_ewald_h,
158-
t_ewald_beta,
159-
],
160-
)
161-
mdl_name = mdl_name.decode("UTF-8")
162-
mdl_charge_map = [int(ii) for ii in mdl_charge_map.decode("UTF-8").split()]
163-
sys_charge_map = [int(ii) for ii in sys_charge_map.decode("UTF-8").split()]
164-
self.dm = DipoleChargeModifier(
165-
mdl_name,
166-
mdl_charge_map,
167-
sys_charge_map,
168-
ewald_h=ewald_h,
169-
ewald_beta=ewald_beta,
170-
)
144+
self.dm = None
145+
if self.modifier_type is not None:
146+
modifier = BaseModifier.get_class_by_type(self.modifier_type)
147+
modifier_params = modifier.get_params(self)
148+
self.dm = modifier.get_modifier(modifier_params)
171149

172150
def _init_tensors(self) -> None:
173151
tensor_names = {
@@ -684,7 +662,8 @@ def _get_natoms_and_nframes(
684662
coords: np.ndarray,
685663
atom_types: list[int] | np.ndarray,
686664
) -> tuple[int, int]:
687-
natoms = len(atom_types[0])
665+
atom_types = np.reshape(atom_types, (-1))
666+
natoms = len(atom_types)
688667
if natoms == 0:
689668
assert coords.size == 0
690669
else:

deepmd/tf/modifier/dipole_charge.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
op_module,
1414
tf,
1515
)
16+
from deepmd.tf.infer import (
17+
DeepEval,
18+
)
1619
from deepmd.tf.infer.deep_dipole import DeepDipoleOld as DeepDipole
1720
from deepmd.tf.infer.ewald_recp import (
1821
EwaldRecp,
@@ -487,3 +490,35 @@ def modify_data(self, data: dict, data_sys: DeepmdData) -> None:
487490
data["force"] -= tot_f.reshape(data["force"].shape)
488491
if "find_virial" in data and data["find_virial"] == 1.0:
489492
data["virial"] -= tot_v.reshape(data["virial"].shape)
493+
494+
@staticmethod
495+
def get_params(model: DeepEval):
496+
t_mdl_name = model._get_tensor("modifier_attr/mdl_name:0")
497+
t_mdl_charge_map = model._get_tensor("modifier_attr/mdl_charge_map:0")
498+
t_sys_charge_map = model._get_tensor("modifier_attr/sys_charge_map:0")
499+
t_ewald_h = model._get_tensor("modifier_attr/ewald_h:0")
500+
t_ewald_beta = model._get_tensor("modifier_attr/ewald_beta:0")
501+
[mdl_name, mdl_charge_map, sys_charge_map, ewald_h, ewald_beta] = run_sess(
502+
model.sess,
503+
[
504+
t_mdl_name,
505+
t_mdl_charge_map,
506+
t_sys_charge_map,
507+
t_ewald_h,
508+
t_ewald_beta,
509+
],
510+
)
511+
model_charge_map = [
512+
int(float(ii)) for ii in mdl_charge_map.decode("UTF-8").split()
513+
]
514+
sys_charge_map = [
515+
int(float(ii)) for ii in sys_charge_map.decode("UTF-8").split()
516+
]
517+
modifier_params = {
518+
"model_name": mdl_name.decode("UTF-8"),
519+
"model_charge_map": model_charge_map,
520+
"sys_charge_map": sys_charge_map,
521+
"ewald_h": ewald_h,
522+
"ewald_beta": ewald_beta,
523+
}
524+
return modifier_params

0 commit comments

Comments
 (0)