@@ -909,13 +909,13 @@ def _apply_out_bias_std(self, output, atype, natoms, coord, selected_atype=None)
909909 else :
910910 # For energy and DOS models with all atoms
911911 nloc = natoms [0 ]
912+ # Get output dimension
913+ nout = self ._get_dim_out ()
912914 if self .model_type == "dos" :
913915 # DOS model: output shape [nframes * nloc * numb_dos]
914- nout = self .numb_dos
915916 output_reshaped = tf .reshape (output , [nframes , nloc , nout ])
916917 else :
917918 # Energy model: output shape [nframes * nloc]
918- nout = 1
919919 output_reshaped = tf .reshape (output , [nframes , nloc , 1 ])
920920 atype_for_gather = tf .reshape (atype , [nframes , nloc ])
921921
@@ -1045,28 +1045,26 @@ def serialize(self, suffix: str = "") -> dict:
10451045
10461046 ntypes = len (self .get_type_map ())
10471047
1048+ # Get output dimension
1049+ dim_out = self ._get_dim_out ()
1050+
10481051 # Try to serialize fitting, with fallback for uninitialized variables
10491052 try :
10501053 dict_fit = self .fitting .serialize (suffix = suffix )
10511054 except (AttributeError , TypeError ):
1052- # Fallback: create a minimal dict_fit with just dim_out
1053- dim_out = self ._get_dim_out ()
1054- dict_fit = {"dim_out" : dim_out , "@variables" : {}}
1055+ # Fallback: create a minimal dict_fit
1056+ dict_fit = {"@variables" : {}}
10551057
10561058 # Use the actual out_bias and out_std if they exist, otherwise create defaults
10571059 if self .out_bias is not None :
10581060 out_bias = self .out_bias .copy ()
10591061 else :
1060- out_bias = np .zeros (
1061- [1 , ntypes , dict_fit ["dim_out" ]], dtype = GLOBAL_NP_FLOAT_PRECISION
1062- )
1062+ out_bias = np .zeros ([1 , ntypes , dim_out ], dtype = GLOBAL_NP_FLOAT_PRECISION )
10631063
10641064 if self .out_std is not None :
10651065 out_std = self .out_std .copy ()
10661066 else :
1067- out_std = np .ones (
1068- [1 , ntypes , dict_fit ["dim_out" ]], dtype = GLOBAL_NP_FLOAT_PRECISION
1069- )
1067+ out_std = np .ones ([1 , ntypes , dim_out ], dtype = GLOBAL_NP_FLOAT_PRECISION )
10701068 return {
10711069 "@class" : "Model" ,
10721070 "type" : "standard" ,
0 commit comments