Skip to content

Commit b5740e5

Browse files
authored
Merge branch 'devel' into D0812_devel_head_alias
2 parents b9bd9ed + accc331 commit b5740e5

20 files changed

Lines changed: 590 additions & 68 deletions

deepmd/dpmodel/fitting/property_fitting.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@
1212
from deepmd.dpmodel.fitting.invar_fitting import (
1313
InvarFitting,
1414
)
15+
from deepmd.dpmodel.output_def import (
16+
FittingOutputDef,
17+
OutputVariableDef,
18+
)
1519
from deepmd.utils.version import (
1620
check_version_compatibility,
1721
)
@@ -108,6 +112,20 @@ def __init__(
108112
type_map=type_map,
109113
)
110114

115+
def output_def(self) -> FittingOutputDef:
116+
return FittingOutputDef(
117+
[
118+
OutputVariableDef(
119+
self.var_name,
120+
[self.dim_out],
121+
reducible=True,
122+
r_differentiable=False,
123+
c_differentiable=False,
124+
intensive=self.intensive,
125+
),
126+
]
127+
)
128+
111129
@classmethod
112130
def deserialize(cls, data: dict) -> "PropertyFittingNet":
113131
data = data.copy()

deepmd/dpmodel/model/dp_model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,7 @@ def update_sel(
4545
train_data, type_map, local_jdata["descriptor"]
4646
)
4747
return local_jdata_cpy, min_nbor_dist
48+
49+
def get_fitting_net(self):
50+
"""Get the fitting network."""
51+
return self.atomic_model.fitting

deepmd/dpmodel/model/make_model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,7 @@ def forward_common_atomic(
355355
self.atomic_output_def(),
356356
extended_coord,
357357
do_atomic_virial=do_atomic_virial,
358+
mask=atomic_ret["mask"] if "mask" in atomic_ret else None,
358359
)
359360

360361
forward_lower = call_lower

deepmd/dpmodel/model/property_model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,7 @@ def __init__(
2525
) -> None:
2626
DPModelCommon.__init__(self)
2727
DPPropertyModel_.__init__(self, *args, **kwargs)
28+
29+
def get_var_name(self) -> str:
30+
"""Get the name of the property."""
31+
return self.get_fitting_net().var_name

deepmd/dpmodel/model/transform_output.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
22

3+
from typing import (
4+
Optional,
5+
)
6+
37
import array_api_compat
48
import numpy as np
59

@@ -24,6 +28,7 @@ def fit_output_to_model_output(
2428
fit_output_def: FittingOutputDef,
2529
coord_ext: np.ndarray,
2630
do_atomic_virial: bool = False,
31+
mask: Optional[np.ndarray] = None,
2732
) -> dict[str, np.ndarray]:
2833
"""Transform the output of the fitting network to
2934
the model output.
@@ -38,9 +43,19 @@ def fit_output_to_model_output(
3843
if vdef.reducible:
3944
kk_redu = get_reduce_name(kk)
4045
# cast to energy prec before reduction
41-
model_ret[kk_redu] = xp.sum(
42-
vv.astype(GLOBAL_ENER_FLOAT_PRECISION), axis=atom_axis
43-
)
46+
if vdef.intensive:
47+
if mask is not None:
48+
model_ret[kk_redu] = xp.sum(
49+
vv.astype(GLOBAL_ENER_FLOAT_PRECISION), axis=atom_axis
50+
) / np.sum(mask, axis=-1, keepdims=True)
51+
else:
52+
model_ret[kk_redu] = xp.mean(
53+
vv.astype(GLOBAL_ENER_FLOAT_PRECISION), axis=atom_axis
54+
)
55+
else:
56+
model_ret[kk_redu] = xp.sum(
57+
vv.astype(GLOBAL_ENER_FLOAT_PRECISION), axis=atom_axis
58+
)
4459
if vdef.r_differentiable:
4560
kk_derv_r, kk_derv_c = get_deriv_name(kk)
4661
# name-holders

deepmd/jax/model/base_model.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,16 @@ def forward_common_atomic(
4646
atom_axis = -(len(shap) + 1)
4747
if vdef.reducible:
4848
kk_redu = get_reduce_name(kk)
49-
model_predict[kk_redu] = jnp.sum(vv, axis=atom_axis)
49+
if vdef.intensive:
50+
mask = atomic_ret["mask"] if "mask" in atomic_ret else None
51+
if mask is not None:
52+
model_predict[kk_redu] = jnp.sum(vv, axis=atom_axis) / jnp.sum(
53+
mask, axis=-1, keepdims=True
54+
)
55+
else:
56+
model_predict[kk_redu] = jnp.mean(vv, axis=atom_axis)
57+
else:
58+
model_predict[kk_redu] = jnp.sum(vv, axis=atom_axis)
5059
kk_derv_r, kk_derv_c = get_deriv_name(kk)
5160
if vdef.r_differentiable:
5261

deepmd/pd/train/training.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -918,7 +918,9 @@ def log_loss_valid(_task_key="Default"):
918918
self.t0 = current_time
919919
if self.rank == 0 and self.timing_in_training:
920920
eta = int(
921-
(self.num_steps - display_step_id) / self.disp_freq * train_time
921+
(self.num_steps - display_step_id)
922+
/ min(self.disp_freq, display_step_id - self.start_step)
923+
* train_time
922924
)
923925
log.info(
924926
format_training_message(

deepmd/pt/model/atomic_model/base_atomic_model.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -363,21 +363,25 @@ def compute_or_load_stat(
363363
self,
364364
merged: Union[Callable[[], list[dict]], list[dict]],
365365
stat_file_path: Optional[DPPath] = None,
366+
compute_or_load_out_stat: bool = True,
366367
) -> NoReturn:
367368
"""
368-
Compute the output statistics (e.g. energy bias) for the fitting net from packed data.
369+
Compute or load the statistics parameters of the model,
370+
such as mean and standard deviation of descriptors or the energy bias of the fitting net.
371+
When `sampled` is provided, all the statistics parameters will be calculated (or re-calculated for update),
372+
and saved in the `stat_file_path`(s).
373+
When `sampled` is not provided, it will check the existence of `stat_file_path`(s)
374+
and load the calculated statistics parameters.
369375
370376
Parameters
371377
----------
372-
merged : Union[Callable[[], list[dict]], list[dict]]
373-
- list[dict]: A list of data samples from various data systems.
374-
Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor`
375-
originating from the `i`-th data system.
376-
- Callable[[], list[dict]]: A lazy function that returns data samples in the above format
377-
only when needed. Since the sampling process can be slow and memory-intensive,
378-
the lazy function helps by only sampling once.
379-
stat_file_path : Optional[DPPath]
380-
The path to the stat file.
378+
merged
379+
The lazy sampled function to get data frames from different data systems.
380+
stat_file_path
381+
The dictionary of paths to the statistics files.
382+
compute_or_load_out_stat : bool
383+
Whether to compute the output statistics.
384+
If False, it will only compute the input statistics (e.g. mean and standard deviation of descriptors).
381385
382386
"""
383387
raise NotImplementedError

deepmd/pt/model/atomic_model/dp_atomic_model.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,7 @@ def compute_or_load_stat(
285285
self,
286286
sampled_func,
287287
stat_file_path: Optional[DPPath] = None,
288+
compute_or_load_out_stat: bool = True,
288289
) -> None:
289290
"""
290291
Compute or load the statistics parameters of the model,
@@ -300,6 +301,9 @@ def compute_or_load_stat(
300301
The lazy sampled function to get data frames from different data systems.
301302
stat_file_path
302303
The dictionary of paths to the statistics files.
304+
compute_or_load_out_stat : bool
305+
Whether to compute the output statistics.
306+
If False, it will only compute the input statistics (e.g. mean and standard deviation of descriptors).
303307
"""
304308
if stat_file_path is not None and self.type_map is not None:
305309
# descriptors and fitting net with different type_map
@@ -323,7 +327,8 @@ def wrapped_sampler():
323327
self.fitting_net.compute_input_stats(
324328
wrapped_sampler, protection=self.data_stat_protect
325329
)
326-
self.compute_or_load_out_stat(wrapped_sampler, stat_file_path)
330+
if compute_or_load_out_stat:
331+
self.compute_or_load_out_stat(wrapped_sampler, stat_file_path)
327332

328333
def get_dim_fparam(self) -> int:
329334
"""Get the number (dimension) of frame parameters of this atomic model."""

deepmd/pt/model/atomic_model/linear_atomic_model.py

Lines changed: 32 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
import functools
23
from typing import (
3-
Callable,
44
Optional,
55
Union,
66
)
@@ -319,6 +319,10 @@ def apply_out_stat(
319319
The atom types. nf x nloc
320320
321321
"""
322+
out_bias, out_std = self._fetch_out_stat(self.bias_keys)
323+
for kk in self.bias_keys:
324+
# nf x nloc x odims, out_bias: ntypes x odims
325+
ret[kk] = ret[kk] + out_bias[kk][atype]
322326
return ret
323327

324328
@staticmethod
@@ -464,34 +468,11 @@ def is_aparam_nall(self) -> bool:
464468
"""
465469
return False
466470

467-
def compute_or_load_out_stat(
468-
self,
469-
merged: Union[Callable[[], list[dict]], list[dict]],
470-
stat_file_path: Optional[DPPath] = None,
471-
) -> None:
472-
"""
473-
Compute the output statistics (e.g. energy bias) for the fitting net from packed data.
474-
475-
Parameters
476-
----------
477-
merged : Union[Callable[[], list[dict]], list[dict]]
478-
- list[dict]: A list of data samples from various data systems.
479-
Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor`
480-
originating from the `i`-th data system.
481-
- Callable[[], list[dict]]: A lazy function that returns data samples in the above format
482-
only when needed. Since the sampling process can be slow and memory-intensive,
483-
the lazy function helps by only sampling once.
484-
stat_file_path : Optional[DPPath]
485-
The path to the stat file.
486-
487-
"""
488-
for md in self.models:
489-
md.compute_or_load_out_stat(merged, stat_file_path)
490-
491471
def compute_or_load_stat(
492472
self,
493473
sampled_func,
494474
stat_file_path: Optional[DPPath] = None,
475+
compute_or_load_out_stat: bool = True,
495476
) -> None:
496477
"""
497478
Compute or load the statistics parameters of the model,
@@ -507,9 +488,34 @@ def compute_or_load_stat(
507488
The lazy sampled function to get data frames from different data systems.
508489
stat_file_path
509490
The dictionary of paths to the statistics files.
491+
compute_or_load_out_stat : bool
492+
Whether to compute the output statistics.
493+
If False, it will only compute the input statistics (e.g. mean and standard deviation of descriptors).
510494
"""
511495
for md in self.models:
512-
md.compute_or_load_stat(sampled_func, stat_file_path)
496+
md.compute_or_load_stat(
497+
sampled_func, stat_file_path, compute_or_load_out_stat=False
498+
)
499+
500+
if stat_file_path is not None and self.type_map is not None:
501+
# descriptors and fitting net with different type_map
502+
# should not share the same parameters
503+
stat_file_path /= " ".join(self.type_map)
504+
505+
@functools.lru_cache
506+
def wrapped_sampler():
507+
sampled = sampled_func()
508+
if self.pair_excl is not None:
509+
pair_exclude_types = self.pair_excl.get_exclude_types()
510+
for sample in sampled:
511+
sample["pair_exclude_types"] = list(pair_exclude_types)
512+
if self.atom_excl is not None:
513+
atom_exclude_types = self.atom_excl.get_exclude_types()
514+
for sample in sampled:
515+
sample["atom_exclude_types"] = list(atom_exclude_types)
516+
return sampled
517+
518+
self.compute_or_load_out_stat(wrapped_sampler, stat_file_path)
513519

514520

515521
class DPZBLLinearEnergyAtomicModel(LinearEnergyAtomicModel):

0 commit comments

Comments
 (0)