Skip to content

Commit e98dc5a

Browse files
fix(finetune): calculate fitting stat when using random fitting in finetuning process (#4928)
In the finetuning process, the computation of the fitting stat was skipped by the previous code. There are two situations: 1. Finetuning from a pretrained model's branch: the pretrained model also has `fparam` or `aparam`, with the same meaning as in the finetuning task. The keys `fparam_avg`/`fparam_inv_std`/`aparam_avg`/`aparam_inv_std` are loaded from the pretrained model, which is correct. 2. Finetuning using a RANDOM fitting: the fitting stat should be calculated in this situation, but its computation was skipped, which is an error. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** - Automatic computation of input statistics is now performed during bias adjustment in "set-by-statistic" mode; public API extended to support computing fitting input statistics. * **Tests** - Training tests now compare additional parameter categories to the random-finetuned baseline. - Added test-only batching configuration for data-statistics and new unit tests validating fitting input-statistics calculation. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent da452d7 commit e98dc5a

File tree

10 files changed

+256
-8
lines changed

10 files changed

+256
-8
lines changed

deepmd/dpmodel/fitting/general_fitting.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
)
55
from typing import (
66
Any,
7+
Callable,
78
Optional,
89
Union,
910
)
@@ -221,6 +222,71 @@ def __init__(
221222
],
222223
)
223224

225+
def compute_input_stats(
    self,
    merged: Union[Callable[[], list[dict]], list[dict]],
    protection: float = 1e-2,
) -> None:
    """
    Compute the input statistics (e.g. mean and stddev) for the fittings from packed data.

    Parameters
    ----------
    merged : Union[Callable[[], list[dict]], list[dict]]
        - list[dict]: A list of data samples from various data systems.
            Each element, `merged[i]`, is a data dictionary containing `keys`: `numpy.ndarray`
            originating from the `i`-th data system.
        - Callable[[], list[dict]]: A lazy function that returns data samples in the above format
            only when needed. Since the sampling process can be slow and memory-intensive,
            the lazy function helps by only sampling once.
    protection : float
        Divided-by-zero protection: stddevs below this value are clamped to it,
        so the stored inverse stddev never exceeds 1/protection.
    """
    if self.numb_fparam == 0 and self.numb_aparam == 0:
        # no fitting input parameters: nothing to accumulate
        return
    # materialize the (possibly expensive) sample at most once
    sampled = merged() if callable(merged) else merged
    # stat fparam: frame parameters are small, so concatenate and reduce directly
    if self.numb_fparam > 0:
        cat_data = np.concatenate([frame["fparam"] for frame in sampled], axis=0)
        cat_data = np.reshape(cat_data, [-1, self.numb_fparam])
        fparam_avg = np.mean(cat_data, axis=0)
        fparam_std = np.std(cat_data, axis=0, ddof=0)  # ddof=0 for population std
        # clamp tiny stddevs so the inverse cannot blow up
        fparam_std = np.where(
            fparam_std < protection,
            np.array(protection, dtype=fparam_std.dtype),
            fparam_std,
        )
        fparam_inv_std = 1.0 / fparam_std
        self.fparam_avg = fparam_avg.astype(self.fparam_avg.dtype)
        self.fparam_inv_std = fparam_inv_std.astype(self.fparam_inv_std.dtype)
    # stat aparam: atomic parameters can be large, so accumulate streaming
    # per-system sums instead of concatenating everything.
    if self.numb_aparam > 0:
        sys_sumv = []
        sys_sumv2 = []
        sys_sumn = []
        for ss_ in [frame["aparam"] for frame in sampled]:
            ss = np.reshape(ss_, [-1, self.numb_aparam])
            sys_sumv.append(np.sum(ss, axis=0))
            sys_sumv2.append(np.sum(ss * ss, axis=0))
            sys_sumn.append(ss.shape[0])
        sumv = np.sum(np.stack(sys_sumv), axis=0)
        sumv2 = np.sum(np.stack(sys_sumv2), axis=0)
        sumn = sum(sys_sumn)
        aparam_avg = sumv / sumn
        # E[x^2] - E[x]^2 can come out slightly negative due to floating-point
        # cancellation; clamp at zero so sqrt never produces NaN (a NaN would
        # silently bypass the protection check below).
        aparam_var = np.maximum(sumv2 / sumn - (sumv / sumn) ** 2, 0.0)
        aparam_std = np.sqrt(aparam_var)
        aparam_std = np.where(
            aparam_std < protection,
            np.array(protection, dtype=aparam_std.dtype),
            aparam_std,
        )
        aparam_inv_std = 1.0 / aparam_std
        self.aparam_avg = aparam_avg.astype(self.aparam_avg.dtype)
        self.aparam_inv_std = aparam_inv_std.astype(self.aparam_inv_std.dtype)
289+
224290
@abstractmethod
225291
def _net_out_dim(self) -> int:
226292
"""Set the FittingNet output dim."""

deepmd/pd/model/atomic_model/base_atomic_model.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -515,6 +515,24 @@ def change_out_bias(
515515
else:
516516
raise RuntimeError("Unknown bias_adjust_mode mode: " + bias_adjust_mode)
517517

518+
def compute_fitting_input_stat(
    self,
    sample_merged: Union[Callable[[], list[dict]], list[dict]],
) -> None:
    """Hook for computing input statistics (mean/stddev) of the fitting inputs.

    The base atomic model has no fitting input parameters, so this is a
    no-op; subclasses that own a fitting net with ``fparam``/``aparam``
    override it.

    Parameters
    ----------
    sample_merged : Union[Callable[[], list[dict]], list[dict]]
        Either a pre-sampled list of per-system data dicts whose values are
        `paddle.Tensor`, or a zero-argument callable producing that list
        lazily, so the costly sampling step runs at most once.
    """
    # intentionally empty: overridden where fitting stats are meaningful
    return None
535+
518536
def _get_forward_wrapper_func(self) -> Callable[..., paddle.Tensor]:
519537
"""Get a forward wrapper of the atomic model for output bias calculation."""
520538

deepmd/pd/model/atomic_model/dp_atomic_model.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -397,12 +397,30 @@ def wrapped_sampler():
397397
return sampled
398398

399399
self.descriptor.compute_input_stats(wrapped_sampler, stat_file_path)
400-
self.fitting_net.compute_input_stats(
401-
wrapped_sampler, protection=self.data_stat_protect
402-
)
400+
self.compute_fitting_input_stat(wrapped_sampler)
403401
if compute_or_load_out_stat:
404402
self.compute_or_load_out_stat(wrapped_sampler, stat_file_path)
405403

404+
def compute_fitting_input_stat(
405+
self,
406+
sample_merged: Union[Callable[[], list[dict]], list[dict]],
407+
) -> None:
408+
"""Compute the input statistics (e.g. mean and stddev) for the fittings from packed data.
409+
410+
Parameters
411+
----------
412+
sample_merged : Union[Callable[[], list[dict]], list[dict]]
413+
- list[dict]: A list of data samples from various data systems.
414+
Each element, `merged[i]`, is a data dictionary containing `keys`: `paddle.Tensor`
415+
originating from the `i`-th data system.
416+
- Callable[[], list[dict]]: A lazy function that returns data samples in the above format
417+
only when needed. Since the sampling process can be slow and memory-intensive,
418+
the lazy function helps by only sampling once.
419+
"""
420+
self.fitting_net.compute_input_stats(
421+
sample_merged, protection=self.data_stat_protect
422+
)
423+
406424
def get_dim_fparam(self) -> int:
407425
"""Get the number (dimension) of frame parameters of this atomic model."""
408426
return self.fitting_net.get_dim_fparam()

deepmd/pd/model/model/make_model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,8 @@ def change_out_bias(
228228
merged,
229229
bias_adjust_mode=bias_adjust_mode,
230230
)
231+
if bias_adjust_mode == "set-by-statistic":
232+
self.atomic_model.compute_fitting_input_stat(merged)
231233

232234
def forward_common_lower(
233235
self,

deepmd/pt/model/atomic_model/base_atomic_model.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -493,6 +493,24 @@ def change_out_bias(
493493
else:
494494
raise RuntimeError("Unknown bias_adjust_mode mode: " + bias_adjust_mode)
495495

496+
def compute_fitting_input_stat(
    self,
    sample_merged: Union[Callable[[], list[dict]], list[dict]],
) -> None:
    """Hook for computing input statistics (mean/stddev) of the fitting inputs.

    The base atomic model has no fitting input parameters, so this is a
    no-op; subclasses that own a fitting net with ``fparam``/``aparam``
    override it.

    Parameters
    ----------
    sample_merged : Union[Callable[[], list[dict]], list[dict]]
        Either a pre-sampled list of per-system data dicts whose values are
        `torch.Tensor`, or a zero-argument callable producing that list
        lazily, so the costly sampling step runs at most once.
    """
    # intentionally empty: overridden where fitting stats are meaningful
    return None
513+
496514
def _get_forward_wrapper_func(self) -> Callable[..., torch.Tensor]:
497515
"""Get a forward wrapper of the atomic model for output bias calculation."""
498516

deepmd/pt/model/atomic_model/dp_atomic_model.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
Any,
66
Callable,
77
Optional,
8+
Union,
89
)
910

1011
import torch
@@ -328,12 +329,30 @@ def wrapped_sampler() -> list[dict]:
328329
return sampled
329330

330331
self.descriptor.compute_input_stats(wrapped_sampler, stat_file_path)
331-
self.fitting_net.compute_input_stats(
332-
wrapped_sampler, protection=self.data_stat_protect
333-
)
332+
self.compute_fitting_input_stat(wrapped_sampler)
334333
if compute_or_load_out_stat:
335334
self.compute_or_load_out_stat(wrapped_sampler, stat_file_path)
336335

336+
def compute_fitting_input_stat(
    self,
    sample_merged: Union[Callable[[], list[dict]], list[dict]],
) -> None:
    """Delegate fitting input statistics (mean/stddev) to the fitting net.

    Parameters
    ----------
    sample_merged : Union[Callable[[], list[dict]], list[dict]]
        Either a pre-sampled list of per-system data dicts whose values are
        `torch.Tensor`, or a zero-argument callable producing that list
        lazily, so the costly sampling step runs at most once.
    """
    # forward together with the model-level divide-by-zero protection
    self.fitting_net.compute_input_stats(
        sample_merged,
        protection=self.data_stat_protect,
    )
355+
337356
def get_dim_fparam(self) -> int:
338357
"""Get the number (dimension) of frame parameters of this atomic model."""
339358
return self.fitting_net.get_dim_fparam()

deepmd/pt/model/model/make_model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,8 @@ def change_out_bias(
232232
merged,
233233
bias_adjust_mode=bias_adjust_mode,
234234
)
235+
if bias_adjust_mode == "set-by-statistic":
236+
self.atomic_model.compute_fitting_input_stat(merged)
235237

236238
def forward_common_lower(
237239
self,
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
import unittest
3+
4+
import numpy as np
5+
6+
from deepmd.dpmodel.descriptor import (
7+
DescrptSeA,
8+
)
9+
from deepmd.dpmodel.fitting import (
10+
EnergyFittingNet,
11+
)
12+
13+
14+
def _make_fake_data_pt(sys_natoms, sys_nframes, avgs, stds):
15+
merged_output_stat = []
16+
nsys = len(sys_natoms)
17+
ndof = len(avgs)
18+
for ii in range(nsys):
19+
sys_dict = {}
20+
tmp_data_f = []
21+
tmp_data_a = []
22+
for jj in range(ndof):
23+
rng = np.random.default_rng(2025 * ii + 220 * jj)
24+
tmp_data_f.append(
25+
rng.normal(loc=avgs[jj], scale=stds[jj], size=(sys_nframes[ii], 1))
26+
)
27+
rng = np.random.default_rng(220 * ii + 1636 * jj)
28+
tmp_data_a.append(
29+
rng.normal(
30+
loc=avgs[jj], scale=stds[jj], size=(sys_nframes[ii], sys_natoms[ii])
31+
)
32+
)
33+
tmp_data_f = np.transpose(tmp_data_f, (1, 2, 0))
34+
tmp_data_a = np.transpose(tmp_data_a, (1, 2, 0))
35+
sys_dict["fparam"] = tmp_data_f
36+
sys_dict["aparam"] = tmp_data_a
37+
merged_output_stat.append(sys_dict)
38+
return merged_output_stat
39+
40+
41+
def _brute_fparam_pt(data, ndim):
42+
adata = [ii["fparam"] for ii in data]
43+
all_data = []
44+
for ii in adata:
45+
tmp = np.reshape(ii, [-1, ndim])
46+
if len(all_data) == 0:
47+
all_data = np.array(tmp)
48+
else:
49+
all_data = np.concatenate((all_data, tmp), axis=0)
50+
avg = np.average(all_data, axis=0)
51+
std = np.std(all_data, axis=0)
52+
return avg, std
53+
54+
55+
def _brute_aparam_pt(data, ndim):
56+
adata = [ii["aparam"] for ii in data]
57+
all_data = []
58+
for ii in adata:
59+
tmp = np.reshape(ii, [-1, ndim])
60+
if len(all_data) == 0:
61+
all_data = np.array(tmp)
62+
else:
63+
all_data = np.concatenate((all_data, tmp), axis=0)
64+
avg = np.average(all_data, axis=0)
65+
std = np.std(all_data, axis=0)
66+
return avg, std
67+
68+
69+
class TestEnerFittingStat(unittest.TestCase):
    def test(self) -> None:
        # statistics computed by the fitting must match the brute-force reference
        descriptor = DescrptSeA(6.0, 5.8, [46, 92], neuron=[25, 50, 100], axis_neuron=16)
        fitting = EnergyFittingNet(
            descriptor.get_ntypes(),
            descriptor.get_dim_out(),
            neuron=[240, 240, 240],
            resnet_dt=True,
            numb_fparam=3,
            numb_aparam=3,
        )
        means = [0, 10, 100]
        sigmas = [2, 0.4, 0.00001]
        natoms_per_sys = [10, 100]
        nframes_per_sys = [5, 2]
        samples = _make_fake_data_pt(natoms_per_sys, nframes_per_sys, means, sigmas)
        ref_favg, ref_fstd = _brute_fparam_pt(samples, len(means))
        ref_aavg, ref_astd = _brute_aparam_pt(samples, len(means))
        fitting.compute_input_stats(samples, protection=1e-2)
        # the fitting stores 1/std, clipped at 1/protection = 100
        expect_finv = np.minimum(1.0 / ref_fstd, 100.0)
        expect_ainv = np.minimum(1.0 / ref_astd, 100.0)
        np.testing.assert_almost_equal(ref_favg, fitting.fparam_avg)
        np.testing.assert_almost_equal(expect_finv, fitting.fparam_inv_std)
        np.testing.assert_almost_equal(ref_aavg, fitting.aparam_avg)
        np.testing.assert_almost_equal(expect_ainv, fitting.aparam_inv_std)

source/tests/pd/test_training.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,11 @@ def test_dp_train(self) -> None:
8989
state_dict_trained[state_key].numpy(),
9090
state_dict_finetuned_empty[state_key].numpy(),
9191
)
92-
if "fitting_net" not in state_key:
92+
if (
93+
("fitting_net" not in state_key)
94+
or ("fparam" in state_key)
95+
or ("aparam" in state_key)
96+
):
9397
np.testing.assert_allclose(
9498
state_dict_trained[state_key].numpy(),
9599
state_dict_finetuned_random[state_key].numpy(),
@@ -190,6 +194,7 @@ def setUp(self) -> None:
190194
self.config["training"]["save_freq"] = 1
191195
self.set_path = Path(__file__).parent / "water/data/data_0" / "set.000"
192196
shutil.copyfile(self.set_path / "energy.npy", self.set_path / "fparam.npy")
197+
self.config["model"]["data_stat_nbatch"] = 100
193198

194199
def tearDown(self) -> None:
195200
(self.set_path / "fparam.npy").unlink(missing_ok=True)

source/tests/pt/test_training.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,11 @@ def test_dp_train(self) -> None:
9292
state_dict_trained[state_key],
9393
state_dict_finetuned_empty[state_key],
9494
)
95-
if "fitting_net" not in state_key:
95+
if (
96+
("fitting_net" not in state_key)
97+
or ("fparam" in state_key)
98+
or ("aparam" in state_key)
99+
):
96100
torch.testing.assert_close(
97101
state_dict_trained[state_key],
98102
state_dict_finetuned_random[state_key],
@@ -256,6 +260,7 @@ def setUp(self) -> None:
256260
self.config["training"]["save_freq"] = 1
257261
self.set_path = Path(__file__).parent / "water/data/data_0" / "set.000"
258262
shutil.copyfile(self.set_path / "energy.npy", self.set_path / "fparam.npy")
263+
self.config["model"]["data_stat_nbatch"] = 100
259264

260265
def tearDown(self) -> None:
261266
(self.set_path / "fparam.npy").unlink(missing_ok=True)

0 commit comments

Comments
 (0)