@@ -137,37 +137,15 @@ def __init__(
137137 self .has_aparam = self .tensors ["aparam" ] is not None
138138 self .has_spin = self .ntypes_spin > 0
139139
140- # looks ugly...
141- if self .modifier_type == "dipole_charge" :
142- from deepmd .tf .modifier import (
143- DipoleChargeModifier ,
144- )
140+ from deepmd .tf .modifier import (
141+ BaseModifier ,
142+ )
145143
146- t_mdl_name = self ._get_tensor ("modifier_attr/mdl_name:0" )
147- t_mdl_charge_map = self ._get_tensor ("modifier_attr/mdl_charge_map:0" )
148- t_sys_charge_map = self ._get_tensor ("modifier_attr/sys_charge_map:0" )
149- t_ewald_h = self ._get_tensor ("modifier_attr/ewald_h:0" )
150- t_ewald_beta = self ._get_tensor ("modifier_attr/ewald_beta:0" )
151- [mdl_name , mdl_charge_map , sys_charge_map , ewald_h , ewald_beta ] = run_sess (
152- self .sess ,
153- [
154- t_mdl_name ,
155- t_mdl_charge_map ,
156- t_sys_charge_map ,
157- t_ewald_h ,
158- t_ewald_beta ,
159- ],
160- )
161- mdl_name = mdl_name .decode ("UTF-8" )
162- mdl_charge_map = [int (ii ) for ii in mdl_charge_map .decode ("UTF-8" ).split ()]
163- sys_charge_map = [int (ii ) for ii in sys_charge_map .decode ("UTF-8" ).split ()]
164- self .dm = DipoleChargeModifier (
165- mdl_name ,
166- mdl_charge_map ,
167- sys_charge_map ,
168- ewald_h = ewald_h ,
169- ewald_beta = ewald_beta ,
170- )
144+ self .dm = None
145+ if self .modifier_type is not None :
146+ modifier = BaseModifier .get_class_by_type (self .modifier_type )
147+ modifier_params = modifier .get_params (self )
148+ self .dm = modifier .get_modifier (modifier_params )
171149
172150 def _init_tensors (self ) -> None :
173151 tensor_names = {
@@ -684,7 +662,8 @@ def _get_natoms_and_nframes(
684662 coords : np .ndarray ,
685663 atom_types : list [int ] | np .ndarray ,
686664 ) -> tuple [int , int ]:
687- natoms = len (atom_types [0 ])
665+ atom_types = np .reshape (atom_types , (- 1 ))
666+ natoms = len (atom_types )
688667 if natoms == 0 :
689668 assert coords .size == 0
690669 else :
0 commit comments