11# SPDX-License-Identifier: LGPL-3.0-or-later
2+ import functools
23from typing import (
3- Callable ,
44 Optional ,
55 Union ,
66)
@@ -319,6 +319,10 @@ def apply_out_stat(
319319 The atom types. nf x nloc
320320
321321 """
322+ out_bias , out_std = self ._fetch_out_stat (self .bias_keys )
323+ for kk in self .bias_keys :
324+ # nf x nloc x odims, out_bias: ntypes x odims
325+ ret [kk ] = ret [kk ] + out_bias [kk ][atype ]
322326 return ret
323327
324328 @staticmethod
@@ -464,34 +468,11 @@ def is_aparam_nall(self) -> bool:
464468 """
465469 return False
466470
467- def compute_or_load_out_stat (
468- self ,
469- merged : Union [Callable [[], list [dict ]], list [dict ]],
470- stat_file_path : Optional [DPPath ] = None ,
471- ) -> None :
472- """
473- Compute the output statistics (e.g. energy bias) for the fitting net from packed data.
474-
475- Parameters
476- ----------
477- merged : Union[Callable[[], list[dict]], list[dict]]
478- - list[dict]: A list of data samples from various data systems.
479- Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor`
480- originating from the `i`-th data system.
481- - Callable[[], list[dict]]: A lazy function that returns data samples in the above format
482- only when needed. Since the sampling process can be slow and memory-intensive,
483- the lazy function helps by only sampling once.
484- stat_file_path : Optional[DPPath]
485- The path to the stat file.
486-
487- """
488- for md in self .models :
489- md .compute_or_load_out_stat (merged , stat_file_path )
490-
def compute_or_load_stat(
    self,
    sampled_func,
    stat_file_path: Optional[DPPath] = None,
    compute_or_load_out_stat: bool = True,
) -> None:
    """
    Compute or load the statistics parameters of the model.

    Each sub-model in ``self.models`` is asked for its own statistics with
    its output statistics switched off. The output statistics are then
    handled once for the combined model via
    ``self.compute_or_load_out_stat``, unless the caller disables them.

    Parameters
    ----------
    sampled_func
        The lazy sampled function to get data frames from different data systems.
    stat_file_path
        Location of the statistics files, if any.
    compute_or_load_out_stat : bool
        Whether to compute the output statistics.
        If False, it will only compute the input statistics
        (e.g. mean and standard deviation of descriptors).
    """
    for md in self.models:
        md.compute_or_load_stat(
            sampled_func, stat_file_path, compute_or_load_out_stat=False
        )

    # Honour the flag: previously it was accepted but never consulted, so
    # the output statistics were computed even when the caller opted out.
    if not compute_or_load_out_stat:
        return

    if stat_file_path is not None and self.type_map is not None:
        # descriptors and fitting net with different type_map
        # should not share the same parameters
        stat_file_path /= " ".join(self.type_map)

    @functools.lru_cache
    def wrapped_sampler():
        # Cached so that ``sampled_func`` is invoked at most once through
        # this wrapper; the returned samples are tagged in place.
        sampled = sampled_func()
        if self.pair_excl is not None:
            pair_exclude_types = self.pair_excl.get_exclude_types()
            for sample in sampled:
                sample["pair_exclude_types"] = list(pair_exclude_types)
        if self.atom_excl is not None:
            atom_exclude_types = self.atom_excl.get_exclude_types()
            for sample in sampled:
                sample["atom_exclude_types"] = list(atom_exclude_types)
        return sampled

    self.compute_or_load_out_stat(wrapped_sampler, stat_file_path)
513519
514520
515521class DPZBLLinearEnergyAtomicModel (LinearEnergyAtomicModel ):
0 commit comments