@@ -135,7 +135,7 @@ def _post_process_stat(
135135 """Post process the statistics.
136136
137137 For global statistics, we do not have the std for each type of atoms,
138- thus fake the output std by ones for all the types.
138+ thus broadcast the global std to all the types.
139139 If the shape of out_std is already the same as out_bias,
140140 we do not need to do anything.
141141 """
@@ -144,7 +144,9 @@ def _post_process_stat(
144144 if vv .shape == out_std [kk ].shape :
145145 new_std [kk ] = out_std [kk ]
146146 else :
147- new_std [kk ] = np .ones_like (vv )
147+ ntypes = vv .shape [0 ]
148+ reps = [ntypes ] + [1 ] * (vv .ndim - 1 )
149+ new_std [kk ] = np .tile (out_std [kk ], reps )
148150 return out_bias , new_std
149151
150152
@@ -481,6 +483,7 @@ def _compute_output_stats_global(
481483 merged_natoms [kk ],
482484 assigned_bias = assigned_atom_ener [kk ],
483485 rcond = rcond ,
486+ intensive = intensive ,
484487 )
485488 else :
486489 # this key does not have global labels, skip it.
@@ -491,26 +494,25 @@ def _compute_output_stats_global(
491494 def rmse (x : np .ndarray ) -> float :
492495 return np .sqrt (np .mean (np .square (x )))
493496
494- if model_pred is None :
495- unbias_e = {
496- kk : merged_natoms [kk ] @ bias_atom_e [kk ].reshape (ntypes , - 1 )
497- for kk in bias_atom_e .keys ()
498- }
499- else :
500- unbias_e = {
501- kk : model_pred [kk ].reshape (nf [kk ], - 1 )
502- + merged_natoms [kk ] @ bias_atom_e [kk ].reshape (ntypes , - 1 )
503- for kk in bias_atom_e .keys ()
504- }
505- atom_numbs = {kk : merged_natoms [kk ].sum (- 1 ) for kk in bias_atom_e .keys ()}
497+ unbias_e = {}
498+ for kk in bias_atom_e .keys ():
499+ coeffs = merged_natoms [kk ]
500+ if intensive :
501+ total_atoms = coeffs .sum (axis = 1 , keepdims = True )
502+ coeffs = coeffs / total_atoms
503+ recon = coeffs @ bias_atom_e [kk ].reshape (ntypes , - 1 )
504+ if model_pred is not None :
505+ recon += model_pred [kk ].reshape (nf [kk ], - 1 )
506+ unbias_e [kk ] = recon
506507
507508 for kk in bias_atom_e .keys ():
508- rmse_ae = rmse (
509- (unbias_e [kk ].reshape (nf [kk ], - 1 ) - merged_output [kk ].reshape (nf [kk ], - 1 ))
510- / atom_numbs [kk ][:, None ]
511- )
509+ diff = unbias_e [kk ].reshape (nf [kk ], - 1 ) - merged_output [kk ].reshape (nf [kk ], - 1 )
510+ if not intensive :
511+ diff /= merged_natoms [kk ].sum (axis = - 1 , keepdims = True )
512+ rmse_ae = rmse (diff )
513+ stat_type = "per atom " if not intensive else ""
512514 log .info (
513- f"RMSE of { kk } per atom after linear regression is: { rmse_ae } in the unit of { kk } ."
515+ f"RMSE of { kk } { stat_type } after linear regression is: { rmse_ae } in the unit of { kk } ."
514516 )
515517 return bias_atom_e , std_atom_e
516518
0 commit comments