@@ -259,7 +259,9 @@ def _init_attr(self) -> None:
259259 self .numb_dos = 0
260260 self .tmap = tmap .decode ("utf-8" ).split ()
261261 if self .tensors ["modifier_type" ] is not None :
262- self .modifier_type = run_sess (self .sess , [self .tensors ["modifier_type" ]])[0 ]
262+ self .modifier_type = run_sess (self .sess , [self .tensors ["modifier_type" ]])[
263+ 0
264+ ].decode ()
263265 else :
264266 self .modifier_type = None
265267
@@ -761,15 +763,17 @@ def eval(
761763 odef .name : oo for oo , odef in zip (output , self .output_def .var_defs .values ())
762764 }
763765 # ugly!!
764- if self .modifier_type is not None and isinstance (self .model_type , DeepPot ):
766+ if self .modifier_type is not None and issubclass (self .model_type , DeepPot ):
765767 if atomic :
766768 raise RuntimeError ("modifier does not support atomic modification" )
767769 me , mf , mv = self .dm .eval (coords , cells , atom_types )
768- output = list (output ) # tuple to list
769- e , f , v = output [:3 ]
770- output_dict ["energy_redu" ] += me .reshape (e .shape )
771- output_dict ["energy_deri_r" ] += mf .reshape (f .shape )
772- output_dict ["energy_deri_c_redu" ] += mv .reshape (v .shape )
770+ output_dict ["energy_redu" ] += me .reshape (output_dict ["energy_redu" ].shape )
771+ output_dict ["energy_derv_r" ] += mf .reshape (
772+ output_dict ["energy_derv_r" ].shape
773+ )
774+ output_dict ["energy_derv_c_redu" ] += mv .reshape (
775+ output_dict ["energy_derv_c_redu" ].shape
776+ )
773777 return output_dict
774778
775779 def _prepare_feed_dict (
@@ -1348,6 +1352,8 @@ def sort_input(
13481352 natoms = atom_type [0 ].size
13491353 idx_map = np .arange (natoms ) # pylint: disable=no-explicit-dtype
13501354 return coord , atom_type , idx_map
1355+ if atom_type .ndim > 1 :
1356+ atom_type = atom_type [0 ]
13511357 if sel_atoms is not None :
13521358 selection = [False ] * np .size (atom_type )
13531359 for ii in sel_atoms :
0 commit comments