Skip to content

Commit 124eedd

Browse files
author
Han Wang
committed
Merge remote-tracking branch 'origin/feat-other-full-model' into feat-other-full-model
2 parents 61722b9 + c41515a commit 124eedd

6 files changed

Lines changed: 307 additions & 23 deletions

File tree

deepmd/dpmodel/atomic_model/base_atomic_model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,13 +53,15 @@ def __init__(
5353
pair_exclude_types: list[tuple[int, int]] = [],
5454
rcond: float | None = None,
5555
preset_out_bias: dict[str, Array] | None = None,
56+
data_stat_protect: float = 1e-2,
5657
) -> None:
5758
super().__init__()
5859
self.type_map = type_map
5960
self.reinit_atom_exclude(atom_exclude_types)
6061
self.reinit_pair_exclude(pair_exclude_types)
6162
self.rcond = rcond
6263
self.preset_out_bias = preset_out_bias
64+
self.data_stat_protect = data_stat_protect
6365

6466
def init_out_stat(self) -> None:
6567
"""Initialize the output bias."""

deepmd/dpmodel/atomic_model/dp_atomic_model.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,30 @@ def change_type_map(
232232
)
233233
self.fitting_net.change_type_map(type_map=type_map)
234234

235+
def compute_fitting_input_stat(
236+
self,
237+
sample_merged: Callable[[], list[dict]] | list[dict],
238+
stat_file_path: DPPath | None = None,
239+
) -> None:
240+
"""Compute the input statistics (e.g. mean and stddev) for the fittings from packed data.
241+
242+
Parameters
243+
----------
244+
sample_merged : Union[Callable[[], list[dict]], list[dict]]
245+
- list[dict]: A list of data samples from various data systems.
246+
Each element, ``merged[i]``, is a data dictionary containing
247+
``keys``: ``np.ndarray`` originating from the ``i``-th data system.
248+
- Callable[[], list[dict]]: A lazy function that returns data samples
249+
in the above format only when needed.
250+
stat_file_path : Optional[DPPath]
251+
The path to the stat file.
252+
"""
253+
self.fitting.compute_input_stats(
254+
sample_merged,
255+
protection=self.data_stat_protect,
256+
stat_file_path=stat_file_path,
257+
)
258+
235259
def serialize(self) -> dict:
236260
dd = super().serialize()
237261
dd.update(

deepmd/dpmodel/fitting/general_fitting.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -437,11 +437,14 @@ def change_type_map(
437437
self.ntypes = len(type_map)
438438
self.reinit_exclude(map_atom_exclude_types(self.exclude_types, remap_index))
439439
if has_new_type:
440+
xp = array_api_compat.array_namespace(self.bias_atom_e)
440441
extend_shape = [len(type_map), *list(self.bias_atom_e.shape[1:])]
441-
extend_bias_atom_e = np.zeros(extend_shape, dtype=self.bias_atom_e.dtype)
442-
self.bias_atom_e = np.concatenate(
443-
[self.bias_atom_e, extend_bias_atom_e], axis=0
442+
extend_bias_atom_e = xp.zeros(
443+
extend_shape,
444+
dtype=self.bias_atom_e.dtype,
445+
device=array_api_compat.device(self.bias_atom_e),
444446
)
447+
self.bias_atom_e = xp.concat([self.bias_atom_e, extend_bias_atom_e], axis=0)
445448
self.bias_atom_e = self.bias_atom_e[remap_index]
446449

447450
def __setitem__(self, key: str, value: Any) -> None:

deepmd/dpmodel/fitting/polarizability_fitting.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -237,14 +237,21 @@ def change_type_map(
237237
remap_index, has_new_type = get_index_between_two_maps(self.type_map, type_map)
238238
super().change_type_map(type_map=type_map)
239239
if has_new_type:
240+
xp = array_api_compat.array_namespace(self.scale)
240241
extend_shape = [len(type_map), *list(self.scale.shape[1:])]
241-
extend_scale = np.ones(extend_shape, dtype=self.scale.dtype)
242-
self.scale = np.concatenate([self.scale, extend_scale], axis=0)
242+
extend_scale = xp.ones(
243+
extend_shape,
244+
dtype=self.scale.dtype,
245+
device=array_api_compat.device(self.scale),
246+
)
247+
self.scale = xp.concat([self.scale, extend_scale], axis=0)
243248
extend_shape = [len(type_map), *list(self.constant_matrix.shape[1:])]
244-
extend_constant_matrix = np.zeros(
245-
extend_shape, dtype=self.constant_matrix.dtype
249+
extend_constant_matrix = xp.zeros(
250+
extend_shape,
251+
dtype=self.constant_matrix.dtype,
252+
device=array_api_compat.device(self.constant_matrix),
246253
)
247-
self.constant_matrix = np.concatenate(
254+
self.constant_matrix = xp.concat(
248255
[self.constant_matrix, extend_constant_matrix], axis=0
249256
)
250257
self.scale = self.scale[remap_index]

deepmd/dpmodel/model/make_model.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -417,6 +417,8 @@ def change_out_bias(
417417
'change-by-statistic' or 'set-by-statistic'.
418418
"""
419419
self.atomic_model.change_out_bias(merged, bias_adjust_mode=bias_adjust_mode)
420+
if bias_adjust_mode == "set-by-statistic":
421+
self.atomic_model.compute_fitting_input_stat(merged)
420422

421423
def _input_type_cast(
422424
self,
@@ -616,12 +618,17 @@ def do_grad_c(
616618
return self.atomic_model.do_grad_c(var_name)
617619

618620
def change_type_map(
619-
self, type_map: list[str], model_with_new_type_stat: Any = None
621+
self, type_map: list[str], model_with_new_type_stat: Any | None = None
620622
) -> None:
621623
"""Change the type related params to new ones, according to `type_map` and the original one in the model.
622624
If there are new types in `type_map`, statistics will be updated accordingly to `model_with_new_type_stat` for these new types.
623625
"""
624-
self.atomic_model.change_type_map(type_map=type_map)
626+
self.atomic_model.change_type_map(
627+
type_map=type_map,
628+
model_with_new_type_stat=model_with_new_type_stat.atomic_model
629+
if model_with_new_type_stat is not None
630+
else None,
631+
)
625632

626633
def serialize(self) -> dict:
627634
return self.atomic_model.serialize()

0 commit comments

Comments
 (0)