@@ -157,7 +157,7 @@ def _post_process_stat(
157157 """Post process the statistics.
158158
159159 For global statistics, we do not have the std for each type of atoms,
160- thus fake the output std by ones for all the types.
160+ thus broadcast the global std to all the types.
161161 If the shape of out_std is already the same as out_bias,
162162 we do not need to do anything.
163163
@@ -167,7 +167,9 @@ def _post_process_stat(
167167 if vv .shape == out_std [kk ].shape :
168168 new_std [kk ] = out_std [kk ]
169169 else :
170- new_std [kk ] = np .ones_like (vv )
170+ ntypes = vv .shape [0 ]
171+ reps = [ntypes ] + [1 ] * (vv .ndim - 1 )
172+ new_std [kk ] = np .tile (out_std [kk ], reps )
171173 return out_bias , new_std
172174
173175
@@ -517,6 +519,7 @@ def _compute_output_stats_global(
517519 merged_natoms [kk ],
518520 assigned_bias = assigned_atom_ener [kk ],
519521 rcond = rcond ,
522+ intensive = intensive ,
520523 )
521524 else :
522525 # this key does not have global labels, skip it.
@@ -525,29 +528,28 @@ def _compute_output_stats_global(
525528
526529 # unbias_e is only used for print rmse
527530
528- if model_pred is None :
529- unbias_e = {
530- kk : merged_natoms [kk ] @ bias_atom_e [kk ].reshape (ntypes , - 1 )
531- for kk in bias_atom_e .keys ()
532- }
533- else :
534- unbias_e = {
535- kk : model_pred [kk ].reshape (nf [kk ], - 1 )
536- + merged_natoms [kk ] @ bias_atom_e [kk ].reshape (ntypes , - 1 )
537- for kk in bias_atom_e .keys ()
538- }
539- atom_numbs = {kk : merged_natoms [kk ].sum (- 1 ) for kk in bias_atom_e .keys ()}
531+ unbias_e = {}
532+ for kk in bias_atom_e .keys ():
533+ coeffs = merged_natoms [kk ]
534+ if intensive :
535+ total_atoms = coeffs .sum (axis = 1 , keepdims = True )
536+ coeffs = coeffs / total_atoms
537+ recon = coeffs @ bias_atom_e [kk ].reshape (ntypes , - 1 )
538+ if model_pred is not None :
539+ recon += model_pred [kk ].reshape (nf [kk ], - 1 )
540+ unbias_e [kk ] = recon
540541
541542 def rmse (x : np .ndarray ) -> float :
542543 return np .sqrt (np .mean (np .square (x )))
543544
544545 for kk in bias_atom_e .keys ():
545- rmse_ae = rmse (
546- (unbias_e [kk ].reshape (nf [kk ], - 1 ) - merged_output [kk ].reshape (nf [kk ], - 1 ))
547- / atom_numbs [kk ][:, None ]
548- )
546+ diff = unbias_e [kk ].reshape (nf [kk ], - 1 ) - merged_output [kk ].reshape (nf [kk ], - 1 )
547+ if not intensive :
548+ diff /= merged_natoms [kk ].sum (axis = - 1 , keepdims = True )
549+ rmse_ae = rmse (diff )
550+ stat_type = "per atom " if not intensive else ""
549551 log .info (
550- f"RMSE of { kk } per atom after linear regression is: { rmse_ae } in the unit of { kk } ."
552+ f"RMSE of { kk } { stat_type } after linear regression is: { rmse_ae } in the unit of { kk } ."
551553 )
552554 return bias_atom_e , std_atom_e
553555
0 commit comments