Skip to content

Commit a31b886

Browse files
authored
feat(array-api): env mat stat (deepmodeling#4729)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Added comprehensive input statistics computation and retrieval methods for multiple descriptor classes, supporting mean and standard deviation calculations with lazy data loading and file caching. - Introduced utilities for environment matrix statistics to enhance descriptor analysis and caching efficiency. - Added new test coverage for descriptor input statistics across multiple backends, incorporating preprocessing of input statistics before evaluation. - **Bug Fixes** - Improved array namespace compatibility in utility functions for consistent array operations. - Updated test logic to exclude specific keys from data comparisons, improving test reliability. - **Chores** - Removed unused imports for cleaner code maintenance. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn>
1 parent cc44b86 commit a31b886

14 files changed

Lines changed: 967 additions & 55 deletions

File tree

deepmd/dpmodel/descriptor/dpa1.py

Lines changed: 68 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@
2929
NetworkCollection,
3030
PairExcludeMask,
3131
)
32+
from deepmd.dpmodel.utils.env_mat_stat import (
33+
EnvMatStatSe,
34+
)
3235
from deepmd.dpmodel.utils.network import (
3336
LayerNorm,
3437
NativeLayer,
@@ -48,6 +51,9 @@
4851
from deepmd.utils.data_system import (
4952
DeepmdDataSystem,
5053
)
54+
from deepmd.utils.env_mat_stat import (
55+
StatItem,
56+
)
5157
from deepmd.utils.finetune import (
5258
get_index_between_two_maps,
5359
map_pair_exclude_types,
@@ -408,10 +414,27 @@ def dim_emb(self):
408414
return self.get_dim_emb()
409415

410416
def compute_input_stats(
411-
self, merged: list[dict], path: Optional[DPPath] = None
412-
) -> NoReturn:
413-
"""Update mean and stddev for descriptor elements."""
414-
raise NotImplementedError
417+
self,
418+
merged: Union[Callable[[], list[dict]], list[dict]],
419+
path: Optional[DPPath] = None,
420+
):
421+
"""
422+
Compute the input statistics (e.g. mean and stddev) for the descriptors from packed data.
423+
424+
Parameters
425+
----------
426+
merged : Union[Callable[[], list[dict]], list[dict]]
427+
- list[dict]: A list of data samples from various data systems.
428+
Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor`
429+
originating from the `i`-th data system.
430+
- Callable[[], list[dict]]: A lazy function that returns data samples in the above format
431+
only when needed. Since the sampling process can be slow and memory-intensive,
432+
the lazy function helps by only sampling once.
433+
path : Optional[DPPath]
434+
The path to the stat file.
435+
436+
"""
437+
return self.se_atten.compute_input_stats(merged, path)
415438

416439
def set_stat_mean_and_stddev(
417440
self,
@@ -842,13 +865,49 @@ def compute_input_stats(
842865
self,
843866
merged: Union[Callable[[], list[dict]], list[dict]],
844867
path: Optional[DPPath] = None,
845-
) -> NoReturn:
846-
"""Compute the input statistics (e.g. mean and stddev) for the descriptors from packed data."""
847-
raise NotImplementedError
868+
) -> None:
869+
"""
870+
Compute the input statistics (e.g. mean and stddev) for the descriptors from packed data.
871+
872+
Parameters
873+
----------
874+
merged : Union[Callable[[], list[dict]], list[dict]]
875+
- list[dict]: A list of data samples from various data systems.
876+
Each element, `merged[i]`, is a data dictionary containing `keys`: `paddle.Tensor`
877+
originating from the `i`-th data system.
878+
- Callable[[], list[dict]]: A lazy function that returns data samples in the above format
879+
only when needed. Since the sampling process can be slow and memory-intensive,
880+
the lazy function helps by only sampling once.
881+
path : Optional[DPPath]
882+
The path to the stat file.
848883
849-
def get_stats(self) -> NoReturn:
884+
"""
885+
env_mat_stat = EnvMatStatSe(self)
886+
if path is not None:
887+
path = path / env_mat_stat.get_hash()
888+
if path is None or not path.is_dir():
889+
if callable(merged):
890+
# only get data for once
891+
sampled = merged()
892+
else:
893+
sampled = merged
894+
else:
895+
sampled = []
896+
env_mat_stat.load_or_compute_stats(sampled, path)
897+
self.stats = env_mat_stat.stats
898+
mean, stddev = env_mat_stat()
899+
xp = array_api_compat.array_namespace(self.stddev)
900+
if not self.set_davg_zero:
901+
self.mean = xp.asarray(mean, dtype=self.mean.dtype, copy=True)
902+
self.stddev = xp.asarray(stddev, dtype=self.stddev.dtype, copy=True)
903+
904+
def get_stats(self) -> dict[str, StatItem]:
850905
"""Get the statistics of the descriptor."""
851-
raise NotImplementedError
906+
if self.stats is None:
907+
raise RuntimeError(
908+
"The statistics of the descriptor has not been computed."
909+
)
910+
return self.stats
852911

853912
def reinit_exclude(
854913
self,

deepmd/dpmodel/descriptor/dpa2.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
22
from typing import (
3+
Callable,
34
NoReturn,
45
Optional,
56
Union,
@@ -737,10 +738,31 @@ def dim_emb(self):
737738
return self.get_dim_emb()
738739

739740
def compute_input_stats(
740-
self, merged: list[dict], path: Optional[DPPath] = None
741-
) -> NoReturn:
742-
"""Update mean and stddev for descriptor elements."""
743-
raise NotImplementedError
741+
self,
742+
merged: Union[Callable[[], list[dict]], list[dict]],
743+
path: Optional[DPPath] = None,
744+
):
745+
"""
746+
Compute the input statistics (e.g. mean and stddev) for the descriptors from packed data.
747+
748+
Parameters
749+
----------
750+
merged : Union[Callable[[], list[dict]], list[dict]]
751+
- list[dict]: A list of data samples from various data systems.
752+
Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor`
753+
originating from the `i`-th data system.
754+
- Callable[[], list[dict]]: A lazy function that returns data samples in the above format
755+
only when needed. Since the sampling process can be slow and memory-intensive,
756+
the lazy function helps by only sampling once.
757+
path : Optional[DPPath]
758+
The path to the stat file.
759+
760+
"""
761+
descrpt_list = [self.repinit, self.repformers]
762+
if self.use_three_body:
763+
descrpt_list.append(self.repinit_three_body)
764+
for ii, descrpt in enumerate(descrpt_list):
765+
descrpt.compute_input_stats(merged, path)
744766

745767
def set_stat_mean_and_stddev(
746768
self,

deepmd/dpmodel/descriptor/dpa3.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
22
from typing import (
3-
NoReturn,
43
Optional,
54
Union,
65
)
@@ -444,11 +443,11 @@ def dim_emb(self):
444443
"""Returns the embedding dimension g2."""
445444
return self.get_dim_emb()
446445

447-
def compute_input_stats(
448-
self, merged: list[dict], path: Optional[DPPath] = None
449-
) -> NoReturn:
446+
def compute_input_stats(self, merged: list[dict], path: Optional[DPPath] = None):
450447
"""Update mean and stddev for descriptor elements."""
451-
raise NotImplementedError
448+
descrpt_list = [self.repflows]
449+
for ii, descrpt in enumerate(descrpt_list):
450+
descrpt.compute_input_stats(merged, path)
452451

453452
def set_stat_mean_and_stddev(
454453
self,

deepmd/dpmodel/descriptor/repflows.py

Lines changed: 47 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
22
from typing import (
33
Callable,
4-
NoReturn,
54
Optional,
65
Union,
76
)
@@ -23,6 +22,9 @@
2322
EnvMat,
2423
PairExcludeMask,
2524
)
25+
from deepmd.dpmodel.utils.env_mat_stat import (
26+
EnvMatStatSe,
27+
)
2628
from deepmd.dpmodel.utils.network import (
2729
NativeLayer,
2830
get_activation_fn,
@@ -33,6 +35,9 @@
3335
from deepmd.dpmodel.utils.seed import (
3436
child_seed,
3537
)
38+
from deepmd.utils.env_mat_stat import (
39+
StatItem,
40+
)
3641
from deepmd.utils.path import (
3742
DPPath,
3843
)
@@ -349,13 +354,49 @@ def compute_input_stats(
349354
self,
350355
merged: Union[Callable[[], list[dict]], list[dict]],
351356
path: Optional[DPPath] = None,
352-
) -> NoReturn:
353-
"""Compute the input statistics (e.g. mean and stddev) for the descriptors from packed data."""
354-
raise NotImplementedError
357+
) -> None:
358+
"""
359+
Compute the input statistics (e.g. mean and stddev) for the descriptors from packed data.
355360
356-
def get_stats(self) -> NoReturn:
361+
Parameters
362+
----------
363+
merged : Union[Callable[[], list[dict]], list[dict]]
364+
- list[dict]: A list of data samples from various data systems.
365+
Each element, `merged[i]`, is a data dictionary containing `keys`: `paddle.Tensor`
366+
originating from the `i`-th data system.
367+
- Callable[[], list[dict]]: A lazy function that returns data samples in the above format
368+
only when needed. Since the sampling process can be slow and memory-intensive,
369+
the lazy function helps by only sampling once.
370+
path : Optional[DPPath]
371+
The path to the stat file.
372+
373+
"""
374+
env_mat_stat = EnvMatStatSe(self)
375+
if path is not None:
376+
path = path / env_mat_stat.get_hash()
377+
if path is None or not path.is_dir():
378+
if callable(merged):
379+
# only get data for once
380+
sampled = merged()
381+
else:
382+
sampled = merged
383+
else:
384+
sampled = []
385+
env_mat_stat.load_or_compute_stats(sampled, path)
386+
self.stats = env_mat_stat.stats
387+
mean, stddev = env_mat_stat()
388+
xp = array_api_compat.array_namespace(self.stddev)
389+
if not self.set_davg_zero:
390+
self.mean = xp.asarray(mean, dtype=self.mean.dtype, copy=True)
391+
self.stddev = xp.asarray(stddev, dtype=self.stddev.dtype, copy=True)
392+
393+
def get_stats(self) -> dict[str, StatItem]:
357394
"""Get the statistics of the descriptor."""
358-
raise NotImplementedError
395+
if self.stats is None:
396+
raise RuntimeError(
397+
"The statistics of the descriptor has not been computed."
398+
)
399+
return self.stats
359400

360401
def reinit_exclude(
361402
self,

deepmd/dpmodel/descriptor/repformers.py

Lines changed: 47 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
22
from typing import (
33
Callable,
4-
NoReturn,
54
Optional,
65
Union,
76
)
@@ -23,6 +22,9 @@
2322
EnvMat,
2423
PairExcludeMask,
2524
)
25+
from deepmd.dpmodel.utils.env_mat_stat import (
26+
EnvMatStatSe,
27+
)
2628
from deepmd.dpmodel.utils.network import (
2729
LayerNorm,
2830
NativeLayer,
@@ -34,6 +36,9 @@
3436
from deepmd.dpmodel.utils.seed import (
3537
child_seed,
3638
)
39+
from deepmd.utils.env_mat_stat import (
40+
StatItem,
41+
)
3742
from deepmd.utils.path import (
3843
DPPath,
3944
)
@@ -370,13 +375,49 @@ def compute_input_stats(
370375
self,
371376
merged: Union[Callable[[], list[dict]], list[dict]],
372377
path: Optional[DPPath] = None,
373-
) -> NoReturn:
374-
"""Compute the input statistics (e.g. mean and stddev) for the descriptors from packed data."""
375-
raise NotImplementedError
378+
) -> None:
379+
"""
380+
Compute the input statistics (e.g. mean and stddev) for the descriptors from packed data.
376381
377-
def get_stats(self) -> NoReturn:
382+
Parameters
383+
----------
384+
merged : Union[Callable[[], list[dict]], list[dict]]
385+
- list[dict]: A list of data samples from various data systems.
386+
Each element, `merged[i]`, is a data dictionary containing `keys`: `paddle.Tensor`
387+
originating from the `i`-th data system.
388+
- Callable[[], list[dict]]: A lazy function that returns data samples in the above format
389+
only when needed. Since the sampling process can be slow and memory-intensive,
390+
the lazy function helps by only sampling once.
391+
path : Optional[DPPath]
392+
The path to the stat file.
393+
394+
"""
395+
env_mat_stat = EnvMatStatSe(self)
396+
if path is not None:
397+
path = path / env_mat_stat.get_hash()
398+
if path is None or not path.is_dir():
399+
if callable(merged):
400+
# only get data for once
401+
sampled = merged()
402+
else:
403+
sampled = merged
404+
else:
405+
sampled = []
406+
env_mat_stat.load_or_compute_stats(sampled, path)
407+
self.stats = env_mat_stat.stats
408+
mean, stddev = env_mat_stat()
409+
xp = array_api_compat.array_namespace(self.stddev)
410+
if not self.set_davg_zero:
411+
self.mean = xp.asarray(mean, dtype=self.mean.dtype, copy=True)
412+
self.stddev = xp.asarray(stddev, dtype=self.stddev.dtype, copy=True)
413+
414+
def get_stats(self) -> dict[str, StatItem]:
378415
"""Get the statistics of the descriptor."""
379-
raise NotImplementedError
416+
if self.stats is None:
417+
raise RuntimeError(
418+
"The statistics of the descriptor has not been computed."
419+
)
420+
return self.stats
380421

381422
def reinit_exclude(
382423
self,

0 commit comments

Comments
 (0)