Skip to content

Commit 07483a7

Browse files
Merge branch 'devel' into 1108_default_fparam_stat
Signed-off-by: Chengqian Zhang <100290172+Chengqian-Zhang@users.noreply.github.com>
2 parents d6120a0 + e98dc5a commit 07483a7

10 files changed

Lines changed: 256 additions & 10 deletions

File tree

deepmd/dpmodel/fitting/general_fitting.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
)
55
from typing import (
66
Any,
7+
Callable,
78
Optional,
89
Union,
910
)
@@ -221,6 +222,71 @@ def __init__(
221222
],
222223
)
223224

225+
def compute_input_stats(
    self,
    merged: Union[Callable[[], list[dict]], list[dict]],
    protection: float = 1e-2,
) -> None:
    """
    Compute the input statistics (mean and inverse stddev) of fparam and aparam.

    The results are stored in ``self.fparam_avg`` / ``self.fparam_inv_std`` and
    ``self.aparam_avg`` / ``self.aparam_inv_std``, cast to the dtype those
    attributes already have. Nothing is computed (and a lazy ``merged`` is not
    evaluated) when both ``self.numb_fparam`` and ``self.numb_aparam`` are zero.

    Parameters
    ----------
    merged : Union[Callable[[], list[dict]], list[dict]]
        - list[dict]: A list of data samples from various data systems.
            Each element, ``merged[i]``, is a data dictionary mapping keys to
            ``numpy.ndarray`` originating from the ``i``-th data system. The
            keys ``"fparam"`` and ``"aparam"`` are read when the corresponding
            ``numb_*`` attribute is positive.
        - Callable[[], list[dict]]: A lazy function that returns data samples in the
            above format only when needed. Since the sampling process can be slow and
            memory-intensive, the lazy function helps by only sampling once.
    protection : float
        Divide-by-zero protection: any standard deviation smaller than this
        value is replaced by it before the inverse is taken.
    """
    if self.numb_fparam == 0 and self.numb_aparam == 0:
        # no frame/atomic parameters to normalize: skip data statistics
        return
    sampled = merged() if callable(merged) else merged
    # stat fparam
    if self.numb_fparam > 0:
        cat_data = np.concatenate([frame["fparam"] for frame in sampled], axis=0)
        cat_data = np.reshape(cat_data, [-1, self.numb_fparam])
        fparam_avg = np.mean(cat_data, axis=0)
        fparam_std = np.std(cat_data, axis=0, ddof=0)  # ddof=0 for population std
        fparam_std = np.where(
            fparam_std < protection,
            np.array(protection, dtype=fparam_std.dtype),
            fparam_std,
        )
        fparam_inv_std = 1.0 / fparam_std
        self.fparam_avg = fparam_avg.astype(self.fparam_avg.dtype)
        self.fparam_inv_std = fparam_inv_std.astype(self.fparam_inv_std.dtype)
    # stat aparam
    if self.numb_aparam > 0:
        # accumulate per-system sums so the systems never have to be concatenated
        sys_sumv = []
        sys_sumv2 = []
        sys_sumn = []
        for frame in sampled:
            ss = np.reshape(frame["aparam"], [-1, self.numb_aparam])
            sys_sumv.append(np.sum(ss, axis=0))
            sys_sumv2.append(np.sum(ss * ss, axis=0))
            sys_sumn.append(ss.shape[0])
        sumv = np.sum(np.stack(sys_sumv), axis=0)
        sumv2 = np.sum(np.stack(sys_sumv2), axis=0)
        sumn = sum(sys_sumn)
        aparam_avg = sumv / sumn
        # E[x^2] - E[x]^2 can come out slightly negative through round-off when
        # the data is (nearly) constant. sqrt would then give NaN, and since
        # `NaN < protection` is False the clamp below would not repair it.
        aparam_var = np.maximum(sumv2 / sumn - aparam_avg**2, 0.0)
        aparam_std = np.sqrt(aparam_var)
        aparam_std = np.where(
            aparam_std < protection,
            np.array(protection, dtype=aparam_std.dtype),
            aparam_std,
        )
        aparam_inv_std = 1.0 / aparam_std
        self.aparam_avg = aparam_avg.astype(self.aparam_avg.dtype)
        self.aparam_inv_std = aparam_inv_std.astype(self.aparam_inv_std.dtype)
289+
224290
@abstractmethod
225291
def _net_out_dim(self) -> int:
226292
"""Set the FittingNet output dim."""

deepmd/pd/model/atomic_model/base_atomic_model.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -515,6 +515,24 @@ def change_out_bias(
515515
else:
516516
raise RuntimeError("Unknown bias_adjust_mode mode: " + bias_adjust_mode)
517517

518+
def compute_fitting_input_stat(
    self,
    sample_merged: Union[Callable[[], list[dict]], list[dict]],
) -> None:
    """Compute the fitting input statistics (e.g. mean and stddev) from packed data.

    This implementation is a no-op: it neither evaluates nor inspects
    ``sample_merged``.

    Parameters
    ----------
    sample_merged : Union[Callable[[], list[dict]], list[dict]]
        - list[dict]: A list of data samples from various data systems.
            Each element, ``sample_merged[i]``, is a data dictionary mapping keys to
            `paddle.Tensor` originating from the `i`-th data system.
        - Callable[[], list[dict]]: A lazy function that returns data samples in the
            above format only when needed. Since the sampling process can be slow and
            memory-intensive, the lazy function helps by only sampling once.
    """
    return None
535+
518536
def _get_forward_wrapper_func(self) -> Callable[..., paddle.Tensor]:
519537
"""Get a forward wrapper of the atomic model for output bias calculation."""
520538

deepmd/pd/model/atomic_model/dp_atomic_model.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -397,12 +397,30 @@ def wrapped_sampler():
397397
return sampled
398398

399399
self.descriptor.compute_input_stats(wrapped_sampler, stat_file_path)
400-
self.fitting_net.compute_input_stats(
401-
wrapped_sampler, protection=self.data_stat_protect
402-
)
400+
self.compute_fitting_input_stat(wrapped_sampler)
403401
if compute_or_load_out_stat:
404402
self.compute_or_load_out_stat(wrapped_sampler, stat_file_path)
405403

404+
def compute_fitting_input_stat(
    self,
    sample_merged: Union[Callable[[], list[dict]], list[dict]],
) -> None:
    """Delegate the fitting input statistics (e.g. mean and stddev) to the fitting net.

    Forwards ``sample_merged`` unchanged to
    ``self.fitting_net.compute_input_stats`` and uses ``self.data_stat_protect``
    as the ``protection`` argument.

    Parameters
    ----------
    sample_merged : Union[Callable[[], list[dict]], list[dict]]
        - list[dict]: A list of data samples from various data systems.
            Each element, ``sample_merged[i]``, is a data dictionary mapping keys to
            `paddle.Tensor` originating from the `i`-th data system.
        - Callable[[], list[dict]]: A lazy function that returns data samples in the
            above format only when needed. Since the sampling process can be slow and
            memory-intensive, the lazy function helps by only sampling once.
    """
    compute_stats = self.fitting_net.compute_input_stats
    compute_stats(sample_merged, protection=self.data_stat_protect)
423+
406424
def get_dim_fparam(self) -> int:
407425
"""Get the number (dimension) of frame parameters of this atomic model."""
408426
return self.fitting_net.get_dim_fparam()

deepmd/pd/model/model/make_model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,8 @@ def change_out_bias(
228228
merged,
229229
bias_adjust_mode=bias_adjust_mode,
230230
)
231+
if bias_adjust_mode == "set-by-statistic":
232+
self.atomic_model.compute_fitting_input_stat(merged)
231233

232234
def forward_common_lower(
233235
self,

deepmd/pt/model/atomic_model/base_atomic_model.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -493,6 +493,24 @@ def change_out_bias(
493493
else:
494494
raise RuntimeError("Unknown bias_adjust_mode mode: " + bias_adjust_mode)
495495

496+
def compute_fitting_input_stat(
    self,
    sample_merged: Union[Callable[[], list[dict]], list[dict]],
) -> None:
    """Compute the fitting input statistics (e.g. mean and stddev) from packed data.

    This implementation is a no-op: it neither evaluates nor inspects
    ``sample_merged``.

    Parameters
    ----------
    sample_merged : Union[Callable[[], list[dict]], list[dict]]
        - list[dict]: A list of data samples from various data systems.
            Each element, ``sample_merged[i]``, is a data dictionary mapping keys to
            `torch.Tensor` originating from the `i`-th data system.
        - Callable[[], list[dict]]: A lazy function that returns data samples in the
            above format only when needed. Since the sampling process can be slow and
            memory-intensive, the lazy function helps by only sampling once.
    """
    return None
513+
496514
def _get_forward_wrapper_func(self) -> Callable[..., torch.Tensor]:
497515
"""Get a forward wrapper of the atomic model for output bias calculation."""
498516

deepmd/pt/model/atomic_model/dp_atomic_model.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
Any,
66
Callable,
77
Optional,
8+
Union,
89
)
910

1011
import torch
@@ -337,14 +338,30 @@ def wrapped_sampler() -> list[dict]:
337338
return sampled
338339

339340
self.descriptor.compute_input_stats(wrapped_sampler, stat_file_path)
340-
self.fitting_net.compute_input_stats(
341-
wrapped_sampler,
342-
protection=self.data_stat_protect,
343-
stat_file_path=stat_file_path,
344-
)
341+
self.compute_fitting_input_stat(wrapped_sampler)
345342
if compute_or_load_out_stat:
346343
self.compute_or_load_out_stat(wrapped_sampler, stat_file_path)
347344

345+
def compute_fitting_input_stat(
    self,
    sample_merged: Union[Callable[[], list[dict]], list[dict]],
) -> None:
    """Delegate the fitting input statistics (e.g. mean and stddev) to the fitting net.

    Forwards ``sample_merged`` unchanged to
    ``self.fitting_net.compute_input_stats`` and uses ``self.data_stat_protect``
    as the ``protection`` argument.

    Parameters
    ----------
    sample_merged : Union[Callable[[], list[dict]], list[dict]]
        - list[dict]: A list of data samples from various data systems.
            Each element, ``sample_merged[i]``, is a data dictionary mapping keys to
            `torch.Tensor` originating from the `i`-th data system.
        - Callable[[], list[dict]]: A lazy function that returns data samples in the
            above format only when needed. Since the sampling process can be slow and
            memory-intensive, the lazy function helps by only sampling once.
    """
    # NOTE(review): the inline call that this method replaces also forwarded
    # `stat_file_path` to the fitting net; confirm that dropping it is intended.
    compute_stats = self.fitting_net.compute_input_stats
    compute_stats(sample_merged, protection=self.data_stat_protect)
364+
348365
def get_dim_fparam(self) -> int:
349366
"""Get the number (dimension) of frame parameters of this atomic model."""
350367
return self.fitting_net.get_dim_fparam()

deepmd/pt/model/model/make_model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,8 @@ def change_out_bias(
232232
merged,
233233
bias_adjust_mode=bias_adjust_mode,
234234
)
235+
if bias_adjust_mode == "set-by-statistic":
236+
self.atomic_model.compute_fitting_input_stat(merged)
235237

236238
def forward_common_lower(
237239
self,
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
import unittest
3+
4+
import numpy as np
5+
6+
from deepmd.dpmodel.descriptor import (
7+
DescrptSeA,
8+
)
9+
from deepmd.dpmodel.fitting import (
10+
EnergyFittingNet,
11+
)
12+
13+
14+
def _make_fake_data_pt(sys_natoms, sys_nframes, avgs, stds):
15+
merged_output_stat = []
16+
nsys = len(sys_natoms)
17+
ndof = len(avgs)
18+
for ii in range(nsys):
19+
sys_dict = {}
20+
tmp_data_f = []
21+
tmp_data_a = []
22+
for jj in range(ndof):
23+
rng = np.random.default_rng(2025 * ii + 220 * jj)
24+
tmp_data_f.append(
25+
rng.normal(loc=avgs[jj], scale=stds[jj], size=(sys_nframes[ii], 1))
26+
)
27+
rng = np.random.default_rng(220 * ii + 1636 * jj)
28+
tmp_data_a.append(
29+
rng.normal(
30+
loc=avgs[jj], scale=stds[jj], size=(sys_nframes[ii], sys_natoms[ii])
31+
)
32+
)
33+
tmp_data_f = np.transpose(tmp_data_f, (1, 2, 0))
34+
tmp_data_a = np.transpose(tmp_data_a, (1, 2, 0))
35+
sys_dict["fparam"] = tmp_data_f
36+
sys_dict["aparam"] = tmp_data_a
37+
merged_output_stat.append(sys_dict)
38+
return merged_output_stat
39+
40+
41+
def _brute_fparam_pt(data, ndim):
42+
adata = [ii["fparam"] for ii in data]
43+
all_data = []
44+
for ii in adata:
45+
tmp = np.reshape(ii, [-1, ndim])
46+
if len(all_data) == 0:
47+
all_data = np.array(tmp)
48+
else:
49+
all_data = np.concatenate((all_data, tmp), axis=0)
50+
avg = np.average(all_data, axis=0)
51+
std = np.std(all_data, axis=0)
52+
return avg, std
53+
54+
55+
def _brute_aparam_pt(data, ndim):
56+
adata = [ii["aparam"] for ii in data]
57+
all_data = []
58+
for ii in adata:
59+
tmp = np.reshape(ii, [-1, ndim])
60+
if len(all_data) == 0:
61+
all_data = np.array(tmp)
62+
else:
63+
all_data = np.concatenate((all_data, tmp), axis=0)
64+
avg = np.average(all_data, axis=0)
65+
std = np.std(all_data, axis=0)
66+
return avg, std
67+
68+
69+
class TestEnerFittingStat(unittest.TestCase):
    """Check the fparam/aparam statistics of the energy fitting against brute force."""

    def test(self) -> None:
        protection = 1e-2
        avgs = [0, 10, 100]
        # the last component has a std below `protection`, so it exercises the clamp
        stds = [2, 0.4, 0.00001]
        sys_natoms = [10, 100]
        sys_nframes = [5, 2]
        ndim = len(avgs)

        descrpt = DescrptSeA(6.0, 5.8, [46, 92], neuron=[25, 50, 100], axis_neuron=16)
        fitting = EnergyFittingNet(
            descrpt.get_ntypes(),
            descrpt.get_dim_out(),
            neuron=[240, 240, 240],
            resnet_dt=True,
            numb_fparam=3,
            numb_aparam=3,
        )

        all_data = _make_fake_data_pt(sys_natoms, sys_nframes, avgs, stds)
        fparam_ref_avg, fparam_ref_std = _brute_fparam_pt(all_data, ndim)
        aparam_ref_avg, aparam_ref_std = _brute_aparam_pt(all_data, ndim)
        fitting.compute_input_stats(all_data, protection=protection)

        # a std clamped up to `protection` caps the inverse std at 1 / protection
        inv_cap = 1.0 / protection
        fparam_ref_inv = np.minimum(1.0 / fparam_ref_std, inv_cap)
        aparam_ref_inv = np.minimum(1.0 / aparam_ref_std, inv_cap)

        np.testing.assert_almost_equal(fparam_ref_avg, fitting.fparam_avg)
        np.testing.assert_almost_equal(fparam_ref_inv, fitting.fparam_inv_std)
        np.testing.assert_almost_equal(aparam_ref_avg, fitting.aparam_avg)
        np.testing.assert_almost_equal(aparam_ref_inv, fitting.aparam_inv_std)

source/tests/pd/test_training.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,11 @@ def test_dp_train(self) -> None:
8989
state_dict_trained[state_key].numpy(),
9090
state_dict_finetuned_empty[state_key].numpy(),
9191
)
92-
if "fitting_net" not in state_key:
92+
if (
93+
("fitting_net" not in state_key)
94+
or ("fparam" in state_key)
95+
or ("aparam" in state_key)
96+
):
9397
np.testing.assert_allclose(
9498
state_dict_trained[state_key].numpy(),
9599
state_dict_finetuned_random[state_key].numpy(),
@@ -190,6 +194,7 @@ def setUp(self) -> None:
190194
self.config["training"]["save_freq"] = 1
191195
self.set_path = Path(__file__).parent / "water/data/data_0" / "set.000"
192196
shutil.copyfile(self.set_path / "energy.npy", self.set_path / "fparam.npy")
197+
self.config["model"]["data_stat_nbatch"] = 100
193198

194199
def tearDown(self) -> None:
195200
(self.set_path / "fparam.npy").unlink(missing_ok=True)

source/tests/pt/test_training.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,11 @@ def test_dp_train(self) -> None:
9292
state_dict_trained[state_key],
9393
state_dict_finetuned_empty[state_key],
9494
)
95-
if "fitting_net" not in state_key:
95+
if (
96+
("fitting_net" not in state_key)
97+
or ("fparam" in state_key)
98+
or ("aparam" in state_key)
99+
):
96100
torch.testing.assert_close(
97101
state_dict_trained[state_key],
98102
state_dict_finetuned_random[state_key],
@@ -256,6 +260,7 @@ def setUp(self) -> None:
256260
self.config["training"]["save_freq"] = 1
257261
self.set_path = Path(__file__).parent / "water/data/data_0" / "set.000"
258262
shutil.copyfile(self.set_path / "energy.npy", self.set_path / "fparam.npy")
263+
self.config["model"]["data_stat_nbatch"] = 100
259264

260265
def tearDown(self) -> None:
261266
(self.set_path / "fparam.npy").unlink(missing_ok=True)

0 commit comments

Comments
 (0)