|
29 | 29 | NetworkCollection, |
30 | 30 | PairExcludeMask, |
31 | 31 | ) |
| 32 | +from deepmd.dpmodel.utils.env_mat_stat import ( |
| 33 | + EnvMatStatSe, |
| 34 | +) |
32 | 35 | from deepmd.dpmodel.utils.network import ( |
33 | 36 | LayerNorm, |
34 | 37 | NativeLayer, |
|
48 | 51 | from deepmd.utils.data_system import ( |
49 | 52 | DeepmdDataSystem, |
50 | 53 | ) |
| 54 | +from deepmd.utils.env_mat_stat import ( |
| 55 | + StatItem, |
| 56 | +) |
51 | 57 | from deepmd.utils.finetune import ( |
52 | 58 | get_index_between_two_maps, |
53 | 59 | map_pair_exclude_types, |
@@ -408,10 +414,27 @@ def dim_emb(self): |
408 | 414 | return self.get_dim_emb() |
409 | 415 |
|
def compute_input_stats(
    self,
    merged: Union[Callable[[], list[dict]], list[dict]],
    path: Optional[DPPath] = None,
) -> None:
    """
    Compute the input statistics (e.g. mean and stddev) for the descriptors from packed data.

    This method does no computation itself: both arguments are forwarded
    unchanged to ``self.se_atten.compute_input_stats`` and its result is
    returned as-is.

    Parameters
    ----------
    merged : Union[Callable[[], list[dict]], list[dict]]
        - list[dict]: A list of data samples from various data systems.
            Each element, `merged[i]`, is a data dictionary mapping keys to
            array data originating from the `i`-th data system. The concrete
            array type is not fixed by this method.
        - Callable[[], list[dict]]: A lazy function that returns data samples in the above format
            only when needed. Since the sampling process can be slow and memory-intensive,
            the lazy function helps by only sampling once.
    path : Optional[DPPath]
        The path to the stat file.

    """
    # Pure delegation; the inner call's result is passed through untouched.
    return self.se_atten.compute_input_stats(merged, path)
415 | 438 |
|
416 | 439 | def set_stat_mean_and_stddev( |
417 | 440 | self, |
def compute_input_stats(
    self,
    merged: Union[Callable[[], list[dict]], list[dict]],
    path: Optional[DPPath] = None,
) -> None:
    """
    Compute the input statistics (e.g. mean and stddev) for the descriptors from packed data.

    The collected statistics are stored on ``self.stats``. ``self.stddev``
    is always replaced by the computed value; ``self.mean`` is replaced only
    when ``self.set_davg_zero`` is falsy.

    Parameters
    ----------
    merged : Union[Callable[[], list[dict]], list[dict]]
        - list[dict]: A list of data samples from various data systems.
            Each element, `merged[i]`, is a data dictionary mapping keys to
            array data originating from the `i`-th data system. The concrete
            array type is not fixed by this method.
        - Callable[[], list[dict]]: A lazy function that returns data samples in the above format
            only when needed. Since the sampling process can be slow and memory-intensive,
            the lazy function helps by only sampling once.
    path : Optional[DPPath]
        The path to the stat file. When given, the statistics live in a
        sub-path named after the hash reported by the stat object.

    """
    env_mat_stat = EnvMatStatSe(self)
    if path is not None:
        path = path / env_mat_stat.get_hash()
    if path is None or not path.is_dir():
        # No stat directory to rely on, so samples are required. A lazy
        # sampler is invoked here and nowhere else, i.e. at most once.
        sampled = merged() if callable(merged) else merged
    else:
        # The stat directory exists: sampling is skipped entirely and an
        # empty sample list is handed over instead.
        sampled = []
    env_mat_stat.load_or_compute_stats(sampled, path)
    self.stats = env_mat_stat.stats
    mean, stddev = env_mat_stat()
    # Use the array namespace of the existing buffer so the new values are
    # created with the same array library as the current ones.
    xp = array_api_compat.array_namespace(self.stddev)
    if not self.set_davg_zero:
        self.mean = xp.asarray(mean, dtype=self.mean.dtype, copy=True)
    self.stddev = xp.asarray(stddev, dtype=self.stddev.dtype, copy=True)
def get_stats(self) -> dict[str, StatItem]:
    """
    Get the statistics of the descriptor.

    Returns
    -------
    dict[str, StatItem]
        The object currently held in ``self.stats``.

    Raises
    ------
    RuntimeError
        If ``self.stats`` is ``None``.
    """
    collected = self.stats
    if collected is not None:
        return collected
    # Reaching this point means nothing has been stored on ``self.stats``.
    raise RuntimeError("The statistics of the descriptor has not been computed.")
852 | 911 |
|
853 | 912 | def reinit_exclude( |
854 | 913 | self, |
|
0 commit comments