Skip to content

Commit 512eeb6

Browse files
author
Han Wang
committed
feat(pt_expt): multi-task training support
1 parent 0a9b4b6 commit 512eeb6

35 files changed

+7705
-404
lines changed

deepmd/dpmodel/descriptor/repformers.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,14 @@ def get_rcut(self) -> float:
345345
"""Returns the cut-off radius."""
346346
return self.rcut
347347

348+
def get_rcut_smth(self) -> float:
    """Return the radius at which neighbor contributions begin their smooth decay toward zero."""
    return self.rcut_smth
351+
352+
def get_env_protection(self) -> float:
    """Return the protection value used when building the environment matrix."""
    return self.env_protection
355+
348356
def get_nsel(self) -> int:
    """Return the total number of selected neighbor atoms inside the cut-off radius."""
    return sum(self.sel)

deepmd/dpmodel/fitting/general_fitting.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,7 @@ def compute_input_stats(
255255
stat_file_path : Optional[DPPath]
256256
The path to the stat file.
257257
"""
258+
self._param_stats: dict[str, list[StatItem]] = {}
258259
if self.numb_fparam == 0 and self.numb_aparam == 0:
259260
# skip data statistics
260261
return
@@ -296,6 +297,7 @@ def compute_input_stats(
296297
self._save_param_stats_to_file(
297298
stat_file_path, "fparam", fparam_stats
298299
)
300+
self._param_stats["fparam"] = fparam_stats
299301
fparam_avg = np.array(
300302
[s.compute_avg() for s in fparam_stats], dtype=np.float64
301303
)
@@ -362,6 +364,7 @@ def compute_input_stats(
362364
self._save_param_stats_to_file(
363365
stat_file_path, "aparam", aparam_stats
364366
)
367+
self._param_stats["aparam"] = aparam_stats
365368
aparam_avg = np.array(
366369
[s.compute_avg() for s in aparam_stats], dtype=np.float64
367370
)
@@ -407,6 +410,10 @@ def _load_param_stats_from_file(
407410
for ii in range(numb)
408411
]
409412

413+
def get_param_stats(self) -> "dict[str, list[StatItem]]":
    """Return the fparam/aparam statistics collected by ``compute_input_stats``.

    An empty dict is returned when no statistics have been gathered yet.
    """
    try:
        return self._param_stats
    except AttributeError:
        return {}
416+
410417
@abstractmethod
411418
def _net_out_dim(self) -> int:
412419
"""Set the FittingNet output dim."""
@@ -666,11 +673,7 @@ def _call_common(
666673
# check fparam dim, concate to input descriptor
667674
if self.numb_fparam > 0:
668675
assert fparam is not None, "fparam should not be None"
669-
if fparam.shape[-1] != self.numb_fparam:
670-
raise ValueError(
671-
f"get an input fparam of dim {fparam.shape[-1]}, "
672-
f"which is not consistent with {self.numb_fparam}."
673-
)
676+
fparam = xp.reshape(fparam, (nf, self.numb_fparam))
674677
fparam = (fparam - self.fparam_avg[...]) * self.fparam_inv_std[...]
675678
fparam = xp.tile(
676679
xp.reshape(fparam, (nf, 1, self.numb_fparam)), (1, nloc, 1)
@@ -687,11 +690,6 @@ def _call_common(
687690
# check aparam dim, concate to input descriptor
688691
if self.numb_aparam > 0 and not self.use_aparam_as_mask:
689692
assert aparam is not None, "aparam should not be None"
690-
if aparam.shape[-1] != self.numb_aparam:
691-
raise ValueError(
692-
f"get an input aparam of dim {aparam.shape[-1]}, "
693-
f"which is not consistent with {self.numb_aparam}."
694-
)
695693
aparam = xp.reshape(aparam, (nf, nloc, self.numb_aparam))
696694
aparam = (aparam - self.aparam_avg[...]) * self.aparam_inv_std[...]
697695
xx = xp.concat(

deepmd/dpmodel/utils/env_mat_stat.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,75 @@
4040
)
4141

4242

43+
def merge_env_stat(
    base_obj: Union["Descriptor", "DescriptorBlock"],
    link_obj: Union["Descriptor", "DescriptorBlock"],
    model_prob: float = 1.0,
) -> None:
    """Fold the env-mat statistics of ``link_obj`` into ``base_obj``.

    The merge is probability weighted: for every stat key,
    ``merged = base + link * model_prob`` where
    ``model_prob = link_prob / base_prob``.  ``base_obj.stats`` is updated
    in place so that several models can be chained (3+ models), and the
    davg/dstd (or mean/stddev) buffers of ``base_obj`` are rewritten from
    the merged statistics.

    Parameters
    ----------
    base_obj : Descriptor or DescriptorBlock
        Receives the merged statistics and updated buffers.
    link_obj : Descriptor or DescriptorBlock
        Supplies the statistics merged in; it is not modified.
    model_prob : float
        Weight ratio (link_prob / base_prob) applied to ``link_obj``'s stats.
    """
    # Nothing to do unless both sides actually collected statistics.
    if getattr(base_obj, "stats", None) is None:
        return
    if getattr(link_obj, "stats", None) is None:
        return
    # Both buffers pinned: a merge could not change anything observable.
    if getattr(base_obj, "set_stddev_constant", False) and getattr(
        base_obj, "set_davg_zero", False
    ):
        return

    # Probability-weighted merge of the StatItem objects, key by key.
    merged = {
        key: base_obj.stats[key] + link_obj.stats[key] * model_prob
        for key in base_obj.stats
    }

    # Derive mean/stddev (numpy arrays) from the merged statistics.
    accumulator = EnvMatStatSe(base_obj)
    accumulator.stats = merged
    mean, stddev = accumulator()

    # Keep the merged stats on base_obj so further merges chain correctly.
    base_obj.stats = merged

    # Locate the avg/std buffers: simple descriptors expose davg/dstd,
    # descriptor blocks expose mean/stddev.
    if hasattr(base_obj, "davg"):
        avg_buf, std_buf = base_obj.davg, base_obj.dstd
    elif hasattr(base_obj, "mean"):
        avg_buf, std_buf = base_obj.mean, base_obj.stddev
    else:
        return
    # Convert the numpy results to the buffer's own backend/device in place.
    xp = array_api_compat.array_namespace(std_buf)
    device = array_api_compat.device(std_buf)
    if not getattr(base_obj, "set_davg_zero", False):
        avg_buf[...] = xp.asarray(mean, dtype=avg_buf.dtype, device=device)
    std_buf[...] = xp.asarray(stddev, dtype=std_buf.dtype, device=device)
110+
111+
43112
class EnvMatStat(BaseEnvMatStat):
44113
def compute_stat(self, env_mat: dict[str, Array]) -> dict[str, StatItem]:
45114
"""Compute the statistics of the environment matrix for a single system.

deepmd/pt/model/task/fitting.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -779,11 +779,6 @@ def _forward_common(
779779
assert fparam is not None, "fparam should not be None"
780780
assert self.fparam_avg is not None
781781
assert self.fparam_inv_std is not None
782-
if fparam.shape[-1] != self.numb_fparam:
783-
raise ValueError(
784-
"get an input fparam of dim {fparam.shape[-1]}, ",
785-
"which is not consistent with {self.numb_fparam}.",
786-
)
787782
fparam = fparam.view([nf, self.numb_fparam])
788783
nb, _ = fparam.shape
789784
t_fparam_avg = self._extend_f_avg_std(self.fparam_avg, nb)
@@ -804,11 +799,6 @@ def _forward_common(
804799
assert aparam is not None, "aparam should not be None"
805800
assert self.aparam_avg is not None
806801
assert self.aparam_inv_std is not None
807-
if aparam.shape[-1] != self.numb_aparam:
808-
raise ValueError(
809-
f"get an input aparam of dim {aparam.shape[-1]}, ",
810-
f"which is not consistent with {self.numb_aparam}.",
811-
)
812802
aparam = aparam.view([nf, -1, self.numb_aparam])
813803
nb, nloc, _ = aparam.shape
814804
t_aparam_avg = self._extend_a_avg_std(self.aparam_avg, nb, nloc)

deepmd/pt_expt/descriptor/dpa1.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99
cast_precision,
1010
)
1111
from deepmd.dpmodel.descriptor.dpa1 import DescrptDPA1 as DescrptDPA1DP
12+
from deepmd.dpmodel.utils.env_mat_stat import (
13+
merge_env_stat,
14+
)
1215
from deepmd.pt_expt.common import (
1316
torch_module,
1417
)
@@ -26,6 +29,31 @@
2629
class DescrptDPA1(DescrptDPA1DP):
2730
_update_sel_cls = UpdateSel
2831

32+
def share_params(
    self,
    base_class: Any,
    shared_level: int,
    model_prob: float = 1.0,
    resume: bool = False,
) -> None:
    """Share parameters with ``base_class`` for multi-task training.

    Level 0 shares the type embedding and the whole ``se_atten`` block
    (modules and buffers); level 1 shares only the type embedding.
    """
    assert self.__class__ == base_class.__class__, (
        "Only descriptors of the same type can share params!"
    )
    if shared_level not in (0, 1):
        raise NotImplementedError
    # The type embedding is shared at every supported level.
    self._modules["type_embedding"] = base_class._modules["type_embedding"]
    if shared_level == 0:
        if not resume:
            # On a fresh start, merge this model's env-mat stats into the base.
            merge_env_stat(base_class.se_atten, self.se_atten, model_prob)
        self._modules["se_atten"] = base_class._modules["se_atten"]
56+
2957
def enable_compression(
3058
self,
3159
min_nbor_dist: float,

deepmd/pt_expt/descriptor/dpa2.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
build_multiple_neighbor_list,
1515
get_multiple_nlist_key,
1616
)
17+
from deepmd.dpmodel.utils.env_mat_stat import (
18+
merge_env_stat,
19+
)
1720
from deepmd.pt_expt.common import (
1821
torch_module,
1922
)
@@ -30,6 +33,47 @@
3033
class DescrptDPA2(DescrptDPA2DP):
3134
_update_sel_cls = UpdateSel
3235

36+
def share_params(
    self,
    base_class: "DescrptDPA2",
    shared_level: int,
    model_prob: float = 1.0,
    resume: bool = False,
) -> None:
    """Share parameters with ``base_class`` for multi-task training.

    Level 0 shares the type embedding, ``repinit``, the optional
    three-body repinit, ``g1_shape_tranform``, and ``repformers``;
    level 1 shares only the type embedding.
    """
    assert self.__class__ == base_class.__class__, (
        "Only descriptors of the same type can share params!"
    )
    if shared_level not in (0, 1):
        raise NotImplementedError
    # The type embedding is shared at every supported level.
    self._modules["type_embedding"] = base_class._modules["type_embedding"]
    if shared_level == 1:
        return
    has_three_body = (
        self.use_three_body and "repinit_three_body" in base_class._modules
    )
    if not resume:
        # On a fresh start, merge this model's env-mat stats into the base.
        merge_env_stat(base_class.repinit, self.repinit, model_prob)
        if has_three_body:
            merge_env_stat(
                base_class.repinit_three_body,
                self.repinit_three_body,
                model_prob,
            )
        merge_env_stat(base_class.repformers, self.repformers, model_prob)
    self._modules["repinit"] = base_class._modules["repinit"]
    if has_three_body:
        self._modules["repinit_three_body"] = base_class._modules[
            "repinit_three_body"
        ]
    self._modules["g1_shape_tranform"] = base_class._modules["g1_shape_tranform"]
    self._modules["repformers"] = base_class._modules["repformers"]
76+
3377
def enable_compression(
3478
self,
3579
min_nbor_dist: float,

deepmd/pt_expt/descriptor/dpa3.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
22

33
from deepmd.dpmodel.descriptor.dpa3 import DescrptDPA3 as DescrptDPA3DP
4+
from deepmd.dpmodel.utils.env_mat_stat import (
5+
merge_env_stat,
6+
)
47
from deepmd.pt_expt.common import (
58
torch_module,
69
)
@@ -16,3 +19,28 @@
1619
@torch_module
1720
class DescrptDPA3(DescrptDPA3DP):
1821
_update_sel_cls = UpdateSel
22+
23+
def share_params(
    self,
    base_class: "DescrptDPA3",
    shared_level: int,
    model_prob: float = 1.0,
    resume: bool = False,
) -> None:
    """Share parameters with ``base_class`` for multi-task training.

    Level 0 shares the type embedding and the ``repflows`` block;
    level 1 shares only the type embedding.
    """
    assert self.__class__ == base_class.__class__, (
        "Only descriptors of the same type can share params!"
    )
    if shared_level not in (0, 1):
        raise NotImplementedError
    # The type embedding is shared at every supported level.
    self._modules["type_embedding"] = base_class._modules["type_embedding"]
    if shared_level == 0:
        if not resume:
            # On a fresh start, merge this model's env-mat stats into the base.
            merge_env_stat(base_class.repflows, self.repflows, model_prob)
        self._modules["repflows"] = base_class._modules["repflows"]

deepmd/pt_expt/descriptor/hybrid.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
from typing import (
3+
Any,
4+
)
25

36
from deepmd.dpmodel.descriptor.hybrid import DescrptHybrid as DescrptHybridDP
47
from deepmd.pt_expt.common import (
@@ -12,4 +15,27 @@
1215
@BaseDescriptor.register("hybrid")
1316
@torch_module
1417
class DescrptHybrid(DescrptHybridDP):
15-
pass
18+
def share_params(
    self,
    base_class: Any,
    shared_level: int,
    model_prob: float = 1.0,
    resume: bool = False,
) -> None:
    """Share parameters with ``base_class`` for multi-task training.

    Level 0: delegate sharing to every sub-descriptor pairwise, in order.

    Parameters
    ----------
    base_class : DescrptHybrid
        The base hybrid descriptor whose sub-descriptors own the shared params.
    shared_level : int
        Sharing level forwarded to each sub-descriptor (only 0 is supported here).
    model_prob : float
        Probability weight ratio forwarded to each sub-descriptor.
    resume : bool
        Whether training is resumed (skips stat merging downstream).
    """
    assert self.__class__ == base_class.__class__, (
        "Only descriptors of the same type can share params!"
    )
    if shared_level == 0:
        # Pair each sub-descriptor with its counterpart; the original
        # enumerate-and-subscript loop left its loop variable unused.
        for des, base_des in zip(self.descrpt_list, base_class.descrpt_list):
            des.share_params(
                base_des,
                shared_level,
                model_prob=model_prob,
                resume=resume,
            )
    else:
        raise NotImplementedError

deepmd/pt_expt/descriptor/se_atten_v2.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,13 @@ class DescrptSeAttenV2(DescrptSeAttenV2DP):
2222

2323
_update_sel_cls = UpdateSel
2424

25+
def share_params(self, *args: Any, **kwargs: Any) -> None:
    """Forward parameter sharing to :class:`DescrptDPA1`.

    se_atten_v2 reuses the DPA-1 sharing logic unchanged.  The import is
    local, presumably to avoid an import cycle at module load — confirm.
    """
    from deepmd.pt_expt.descriptor.dpa1 import (
        DescrptDPA1,
    )

    return DescrptDPA1.share_params(self, *args, **kwargs)
31+
2532
def enable_compression(self, *args: Any, **kwargs: Any) -> None:
2633
from deepmd.pt_expt.descriptor.dpa1 import (
2734
DescrptDPA1,

0 commit comments

Comments
 (0)