@@ -819,35 +819,37 @@ def get_ntypes(self) -> int:
819819 """Get the number of types."""
820820 return self .ntypes
821821
822+ def _get_dim_out (self ):
823+ """Get output dimension based on model type.
824+
825+ Returns
826+ -------
827+ int
828+ Output dimension
829+ """
830+ if self .model_type == "ener" :
831+ return 1
832+ elif self .model_type in ["dipole" , "polar" ]:
833+ return 3
834+ elif self .model_type == "dos" :
835+ return self .numb_dos
836+ else :
837+ return 1
838+
822839 def init_out_stat (self , suffix : str = "" ) -> None :
823840 """Initialize the output bias and std variables."""
824841 ntypes = self .get_ntypes ()
825-
826- # Determine output dimension based on model type instead of fitting type
827- if hasattr (self , "model_type" ):
828- model_type = self .model_type
829- else :
830- # Fallback to fitting type for compatibility
831- model_type = getattr (self .fitting , "model_type" , "ener" )
832-
833- if model_type == "ener" :
834- dim_out = 1
835- elif model_type in ["dipole" , "polar" ]:
836- dim_out = 3
837- elif model_type == "dos" :
838- dim_out = getattr (self .fitting , "numb_dos" , 1 )
839- else :
840- dim_out = 1
842+ dim_out = self ._get_dim_out ()
841843
842844 # Initialize out_bias and out_std as numpy arrays, preserving existing values if set
843- if hasattr ( self , "out_bias" ) and self .out_bias is not None :
845+ if self .out_bias is not None :
844846 out_bias_data = self .out_bias .copy ()
845847 else :
846848 out_bias_data = np .zeros (
847849 [1 , ntypes , dim_out ], dtype = GLOBAL_NP_FLOAT_PRECISION
848850 )
849851
850- if hasattr ( self , "out_std" ) and self .out_std is not None :
852+ if self .out_std is not None :
851853 out_std_data = self .out_std .copy ()
852854 else :
853855 out_std_data = np .ones (
@@ -907,7 +909,7 @@ def _apply_out_bias_std(self, output, atype, natoms, coord, selected_atype=None)
907909 else :
908910 # For energy and DOS models with all atoms
909911 nloc = natoms [0 ]
910- if hasattr ( self , "numb_dos" ) :
912+ if self . model_type == "dos" :
911913 # DOS model: output shape [nframes * nloc * numb_dos]
912914 nout = self .numb_dos
913915 output_reshaped = tf .reshape (output , [nframes , nloc , nout ])
@@ -1048,41 +1050,12 @@ def serialize(self, suffix: str = "") -> dict:
10481050 dict_fit = self .fitting .serialize (suffix = suffix )
10491051 except (AttributeError , TypeError ):
10501052 # Fallback: create a minimal dict_fit with just dim_out
1051- from deepmd .tf .fit .dipole import (
1052- DipoleFittingSeA ,
1053- )
1054- from deepmd .tf .fit .dos import (
1055- DOSFitting ,
1056- )
1057- from deepmd .tf .fit .ener import (
1058- EnerFitting ,
1059- )
1060- from deepmd .tf .fit .polar import (
1061- PolarFittingSeA ,
1062- )
1063-
1064- if isinstance (self .fitting , EnerFitting ):
1065- dim_out = 1
1066- elif isinstance (self .fitting , (DipoleFittingSeA , PolarFittingSeA )):
1067- dim_out = 3
1068- elif isinstance (self .fitting , DOSFitting ):
1069- dim_out = getattr (self .fitting , "numb_dos" , 1 )
1070- else :
1071- dim_out = 1
1072-
1053+ dim_out = self ._get_dim_out ()
10731054 dict_fit = {"dim_out" : dim_out , "@variables" : {}}
10741055
10751056 # Use the actual out_bias and out_std if they exist, otherwise create defaults
10761057 if self .out_bias is not None :
10771058 out_bias = self .out_bias .copy ()
1078- elif dict_fit .get ("@variables" , {}).get ("bias_atom_e" ) is not None :
1079- # Fallback to converting bias_atom_e for backward compatibility
1080- out_bias = dict_fit ["@variables" ]["bias_atom_e" ].reshape (
1081- [1 , ntypes , dict_fit ["dim_out" ]]
1082- )
1083- dict_fit ["@variables" ]["bias_atom_e" ] = np .zeros_like (
1084- dict_fit ["@variables" ]["bias_atom_e" ]
1085- )
10861059 else :
10871060 out_bias = np .zeros (
10881061 [1 , ntypes , dict_fit ["dim_out" ]], dtype = GLOBAL_NP_FLOAT_PRECISION
0 commit comments