Skip to content

Commit 9c90d95

Browse files
committed
Add compute_out_stat flag to stat computation methods
Introduces a compute_out_stat parameter to compute_or_load_stat methods in BaseAtomicModel, DPAtomicModel, LinearEnergyAtomicModel, and PairTabAtomicModel. This allows conditional computation of output statistics, improving flexibility and control over the statistics computation process.
1 parent e446004 commit 9c90d95

4 files changed

Lines changed: 36 additions & 28 deletions

File tree

deepmd/pt/model/atomic_model/base_atomic_model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,7 @@ def compute_or_load_stat(
363363
self,
364364
merged: Union[Callable[[], list[dict]], list[dict]],
365365
stat_file_path: Optional[DPPath] = None,
366+
compute_out_stat: bool = True,
366367
) -> NoReturn:
367368
"""
368369
Compute the output statistics (e.g. energy bias) for the fitting net from packed data.

deepmd/pt/model/atomic_model/dp_atomic_model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,7 @@ def compute_or_load_stat(
285285
self,
286286
sampled_func,
287287
stat_file_path: Optional[DPPath] = None,
288+
compute_out_stat: bool = True,
288289
) -> None:
289290
"""
290291
Compute or load the statistics parameters of the model,
@@ -323,7 +324,8 @@ def wrapped_sampler():
323324
self.fitting_net.compute_input_stats(
324325
wrapped_sampler, protection=self.data_stat_protect
325326
)
326-
self.compute_or_load_out_stat(wrapped_sampler, stat_file_path)
327+
if compute_out_stat:
328+
self.compute_or_load_out_stat(wrapped_sampler, stat_file_path)
327329

328330
def get_dim_fparam(self) -> int:
329331
"""Get the number (dimension) of frame parameters of this atomic model."""

deepmd/pt/model/atomic_model/linear_atomic_model.py

Lines changed: 29 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
import functools
23
from typing import (
3-
Callable,
44
Optional,
55
Union,
66
)
@@ -319,6 +319,10 @@ def apply_out_stat(
319319
The atom types. nf x nloc
320320
321321
"""
322+
out_bias, out_std = self._fetch_out_stat(self.bias_keys)
323+
for kk in self.bias_keys:
324+
# nf x nloc x odims, out_bias: ntypes x odims
325+
ret[kk] = ret[kk] + out_bias[kk][atype]
322326
return ret
323327

324328
@staticmethod
@@ -464,34 +468,11 @@ def is_aparam_nall(self) -> bool:
464468
"""
465469
return False
466470

467-
def compute_or_load_out_stat(
468-
self,
469-
merged: Union[Callable[[], list[dict]], list[dict]],
470-
stat_file_path: Optional[DPPath] = None,
471-
) -> None:
472-
"""
473-
Compute the output statistics (e.g. energy bias) for the fitting net from packed data.
474-
475-
Parameters
476-
----------
477-
merged : Union[Callable[[], list[dict]], list[dict]]
478-
- list[dict]: A list of data samples from various data systems.
479-
Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor`
480-
originating from the `i`-th data system.
481-
- Callable[[], list[dict]]: A lazy function that returns data samples in the above format
482-
only when needed. Since the sampling process can be slow and memory-intensive,
483-
the lazy function helps by only sampling once.
484-
stat_file_path : Optional[DPPath]
485-
The path to the stat file.
486-
487-
"""
488-
for md in self.models:
489-
md.compute_or_load_out_stat(merged, stat_file_path)
490-
491471
def compute_or_load_stat(
492472
self,
493473
sampled_func,
494474
stat_file_path: Optional[DPPath] = None,
475+
compute_out_stat: bool = True,
495476
) -> None:
496477
"""
497478
Compute or load the statistics parameters of the model,
@@ -509,7 +490,29 @@ def compute_or_load_stat(
509490
The dictionary of paths to the statistics files.
510491
"""
511492
for md in self.models:
512-
md.compute_or_load_stat(sampled_func, stat_file_path)
493+
md.compute_or_load_stat(
494+
sampled_func, stat_file_path, compute_out_stat=False
495+
)
496+
497+
if stat_file_path is not None and self.type_map is not None:
498+
# descriptors and fitting net with different type_map
499+
# should not share the same parameters
500+
stat_file_path /= " ".join(self.type_map)
501+
502+
@functools.lru_cache
503+
def wrapped_sampler():
504+
sampled = sampled_func()
505+
if self.pair_excl is not None:
506+
pair_exclude_types = self.pair_excl.get_exclude_types()
507+
for sample in sampled:
508+
sample["pair_exclude_types"] = list(pair_exclude_types)
509+
if self.atom_excl is not None:
510+
atom_exclude_types = self.atom_excl.get_exclude_types()
511+
for sample in sampled:
512+
sample["atom_exclude_types"] = list(atom_exclude_types)
513+
return sampled
514+
515+
self.compute_or_load_out_stat(wrapped_sampler, stat_file_path)
513516

514517

515518
class DPZBLLinearEnergyAtomicModel(LinearEnergyAtomicModel):

deepmd/pt/model/atomic_model/pairtab_atomic_model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,7 @@ def compute_or_load_stat(
226226
self,
227227
merged: Union[Callable[[], list[dict]], list[dict]],
228228
stat_file_path: Optional[DPPath] = None,
229+
compute_out_stat: bool = True,
229230
) -> None:
230231
"""
231232
Compute the output statistics (e.g. energy bias) for the fitting net from packed data.
@@ -243,7 +244,8 @@ def compute_or_load_stat(
243244
The path to the stat file.
244245
245246
"""
246-
self.compute_or_load_out_stat(merged, stat_file_path)
247+
if compute_out_stat:
248+
self.compute_or_load_out_stat(merged, stat_file_path)
247249

248250
def forward_atomic(
249251
self,

0 commit comments

Comments
 (0)