Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions deepmd/dpmodel/descriptor/repformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,14 @@ def get_rcut(self) -> float:
"""Returns the cut-off radius."""
return self.rcut

def get_rcut_smth(self) -> float:
    """Return the radius at which neighbor contributions begin to smoothly decay toward zero."""
    smoothing_radius: float = self.rcut_smth
    return smoothing_radius

def get_env_protection(self) -> float:
    """Return the protection value applied when building the environment matrix."""
    protection: float = self.env_protection
    return protection

def get_nsel(self) -> int:
    """Return the total number of selected neighbor atoms within the cut-off radius.

    ``self.sel`` holds one selection count per atom type; the total is
    their sum.
    """
    total = 0
    for per_type_count in self.sel:
        total += per_type_count
    return total
Expand Down
31 changes: 21 additions & 10 deletions deepmd/dpmodel/fitting/general_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,7 @@ def compute_input_stats(
stat_file_path : Optional[DPPath]
The path to the stat file.
"""
self._param_stats: dict[str, list[StatItem]] = {}
if self.numb_fparam == 0 and self.numb_aparam == 0:
# skip data statistics
return
Expand Down Expand Up @@ -296,6 +297,7 @@ def compute_input_stats(
self._save_param_stats_to_file(
stat_file_path, "fparam", fparam_stats
)
self._param_stats["fparam"] = fparam_stats
fparam_avg = np.array(
[s.compute_avg() for s in fparam_stats], dtype=np.float64
)
Expand Down Expand Up @@ -362,6 +364,7 @@ def compute_input_stats(
self._save_param_stats_to_file(
stat_file_path, "aparam", aparam_stats
)
self._param_stats["aparam"] = aparam_stats
aparam_avg = np.array(
[s.compute_avg() for s in aparam_stats], dtype=np.float64
)
Expand Down Expand Up @@ -407,6 +410,10 @@ def _load_param_stats_from_file(
for ii in range(numb)
]

def get_param_stats(self) -> dict[str, list[StatItem]]:
    """Return the stored fparam/aparam statistics.

    The mapping is populated by ``compute_input_stats``; an empty dict is
    returned when that method has not been called yet.
    """
    if hasattr(self, "_param_stats"):
        return self._param_stats
    return {}

@abstractmethod
def _net_out_dim(self) -> int:
"""Set the FittingNet output dim."""
Expand Down Expand Up @@ -666,11 +673,13 @@ def _call_common(
# check fparam dim, concate to input descriptor
if self.numb_fparam > 0:
assert fparam is not None, "fparam should not be None"
if fparam.shape[-1] != self.numb_fparam:
try:
fparam = xp.reshape(fparam, (nf, self.numb_fparam))
except (ValueError, RuntimeError) as e:
raise ValueError(
f"get an input fparam of dim {fparam.shape[-1]}, "
f"which is not consistent with {self.numb_fparam}."
)
f"input fparam: cannot reshape {fparam.shape} "
f"into ({nf}, {self.numb_fparam})."
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was valid at the time of review — the code previously used (-1, self.numb_fparam). Fixed in 80c714c which changed the reshape to (nf, self.numb_fparam), now matching the error message.

) from e
fparam = (fparam - self.fparam_avg[...]) * self.fparam_inv_std[...]
fparam = xp.tile(
xp.reshape(fparam, (nf, 1, self.numb_fparam)), (1, nloc, 1)
Expand All @@ -687,12 +696,13 @@ def _call_common(
# check aparam dim, concate to input descriptor
if self.numb_aparam > 0 and not self.use_aparam_as_mask:
assert aparam is not None, "aparam should not be None"
if aparam.shape[-1] != self.numb_aparam:
try:
aparam = xp.reshape(aparam, (nf, nloc, self.numb_aparam))
except (ValueError, RuntimeError) as e:
raise ValueError(
f"get an input aparam of dim {aparam.shape[-1]}, "
f"which is not consistent with {self.numb_aparam}."
)
aparam = xp.reshape(aparam, (nf, nloc, self.numb_aparam))
f"input aparam: cannot reshape {aparam.shape} "
f"into ({nf}, {nloc}, {self.numb_aparam})."
) from e
aparam = (aparam - self.aparam_avg[...]) * self.aparam_inv_std[...]
xx = xp.concat(
[xx, aparam],
Expand Down Expand Up @@ -735,7 +745,8 @@ def _call_common(
)
for type_i in range(self.ntypes):
mask = xp.tile(
xp.reshape((atype == type_i), (nf, nloc, 1)), (1, 1, net_dim_out)
xp.reshape((atype == type_i), (nf, nloc, 1)),
(1, 1, net_dim_out),
)
atom_property = self.nets[(type_i,)](xx)
if self.remove_vaccum_contribution is not None and not (
Expand Down
8 changes: 6 additions & 2 deletions deepmd/dpmodel/utils/env_mat.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,11 @@ def _make_env_mat(
xp = array_api_compat.array_namespace(nlist)
nf, nloc, nnei = nlist.shape
# nf x nall x 3
coord = xp.reshape(coord, (nf, -1, 3))
# Callers may pass either (nf, nall*3) or (nf, nall, 3); normalise
# both to (nf, nall, 3) using shape-based inference so the concrete nf
# value is not baked into the reshape.
if coord.ndim == 2:
coord = xp.reshape(coord, (-1, coord.shape[1] // 3, 3))
mask = nlist >= 0
nlist = nlist * xp.astype(mask, nlist.dtype)
# nf x (nloc x nnei) x 3
Expand All @@ -77,7 +81,7 @@ def _make_env_mat(
# nf x nloc x nnei x 3
coord_r = xp.reshape(coord_r, (nf, nloc, nnei, 3))
# nf x nloc x 1 x 3
coord_l = xp.reshape(xp_take_first_n(coord, 1, nloc), (nf, -1, 1, 3))
coord_l = xp.reshape(xp_take_first_n(coord, 1, nloc), (nf, nloc, 1, 3))
# nf x nloc x nnei x 3
diff = coord_r - coord_l
# nf x nloc x nnei
Expand Down
69 changes: 69 additions & 0 deletions deepmd/dpmodel/utils/env_mat_stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,75 @@
)


def merge_env_stat(
    base_obj: Union["Descriptor", "DescriptorBlock"],
    link_obj: Union["Descriptor", "DescriptorBlock"],
    model_prob: float = 1.0,
) -> None:
    """Merge descriptor env mat stats from link_obj into base_obj.

    Uses probability-weighted merging: merged = base_stats + link_stats * model_prob,
    where model_prob = link_prob / base_prob.
    Mutates base_obj.stats for chaining (3+ models).

    Parameters
    ----------
    base_obj : Descriptor or DescriptorBlock
        The base descriptor whose stats will be updated.
    link_obj : Descriptor or DescriptorBlock
        The linked descriptor whose stats will be merged in.
    model_prob : float
        The probability weight ratio (link_prob / base_prob).
    """
    # Nothing to merge if either side never accumulated statistics
    # (e.g. stats were never computed, or the attribute does not exist).
    if (
        getattr(base_obj, "stats", None) is None
        or getattr(link_obj, "stats", None) is None
    ):
        return
    # Skip when both flags pin the normalization buffers — presumably the
    # stddev is held constant and davg forced to zero, so a merge would
    # have no observable effect. TODO(review): confirm this invariant.
    if getattr(base_obj, "set_stddev_constant", False) and getattr(
        base_obj, "set_davg_zero", False
    ):
        return

    # Weighted merge of StatItem objects: StatItem supports + and scalar *,
    # so each key's accumulator is combined as base + link * model_prob.
    # NOTE(review): assumes link_stats has every key of base_stats —
    # a missing key would raise KeyError here.
    base_stats = base_obj.stats
    link_stats = link_obj.stats
    merged_stats = {}
    for kk in base_stats:
        merged_stats[kk] = base_stats[kk] + link_stats[kk] * model_prob

    # Compute mean/stddev from merged stats. EnvMatStatSe is called with the
    # merged accumulators installed; its __call__ presumably reduces them to
    # per-type mean/stddev arrays (numpy) — confirm against EnvMatStatSe.
    base_env = EnvMatStatSe(base_obj)
    base_env.stats = merged_stats
    mean, stddev = base_env()

    # Update base_obj stats for chaining: a later merge (3+ models) starts
    # from the already-merged accumulators.
    base_obj.stats = merged_stats

    # Update buffers in-place: davg/dstd (simple) or mean/stddev (blocks).
    # mean/stddev are numpy arrays; convert to match the buffer's backend
    # (array namespace + device are taken from the existing std buffer).
    if hasattr(base_obj, "davg"):
        xp = array_api_compat.array_namespace(base_obj.dstd)
        device = array_api_compat.device(base_obj.dstd)
        # When davg is pinned to zero, only the stddev buffer is refreshed.
        if not getattr(base_obj, "set_davg_zero", False):
            base_obj.davg[...] = xp.asarray(
                mean, dtype=base_obj.davg.dtype, device=device
            )
        base_obj.dstd[...] = xp.asarray(
            stddev, dtype=base_obj.dstd.dtype, device=device
        )
    elif hasattr(base_obj, "mean"):
        xp = array_api_compat.array_namespace(base_obj.stddev)
        device = array_api_compat.device(base_obj.stddev)
        if not getattr(base_obj, "set_davg_zero", False):
            base_obj.mean[...] = xp.asarray(
                mean, dtype=base_obj.mean.dtype, device=device
            )
        base_obj.stddev[...] = xp.asarray(
            stddev, dtype=base_obj.stddev.dtype, device=device
        )


class EnvMatStat(BaseEnvMatStat):
def compute_stat(self, env_mat: dict[str, Array]) -> dict[str, StatItem]:
"""Compute the statistics of the environment matrix for a single system.
Expand Down
12 changes: 6 additions & 6 deletions deepmd/pt/model/task/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -779,10 +779,10 @@ def _forward_common(
assert fparam is not None, "fparam should not be None"
assert self.fparam_avg is not None
assert self.fparam_inv_std is not None
if fparam.shape[-1] != self.numb_fparam:
if fparam.numel() != nf * self.numb_fparam:
raise ValueError(
"get an input fparam of dim {fparam.shape[-1]}, ",
"which is not consistent with {self.numb_fparam}.",
f"input fparam: cannot reshape {list(fparam.shape)} "
f"into ({nf}, {self.numb_fparam})."
)
fparam = fparam.view([nf, self.numb_fparam])
nb, _ = fparam.shape
Expand All @@ -804,10 +804,10 @@ def _forward_common(
assert aparam is not None, "aparam should not be None"
assert self.aparam_avg is not None
assert self.aparam_inv_std is not None
if aparam.shape[-1] != self.numb_aparam:
if aparam.numel() % (nf * self.numb_aparam) != 0:
raise ValueError(
f"get an input aparam of dim {aparam.shape[-1]}, ",
f"which is not consistent with {self.numb_aparam}.",
f"input aparam: cannot reshape {list(aparam.shape)} "
f"into ({nf}, nloc, {self.numb_aparam})."
)
aparam = aparam.view([nf, -1, self.numb_aparam])
nb, nloc, _ = aparam.shape
Comment on lines 779 to 813
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch. The removed checks were buggy (non-f-string + tuple ValueError), but the user-friendly error is worth keeping. Fixed in 6ae50db — added try/except wrapping to match the dpmodel pattern.

Expand Down
28 changes: 28 additions & 0 deletions deepmd/pt_expt/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
cast_precision,
)
from deepmd.dpmodel.descriptor.dpa1 import DescrptDPA1 as DescrptDPA1DP
from deepmd.dpmodel.utils.env_mat_stat import (
merge_env_stat,
)
from deepmd.pt_expt.common import (
torch_module,
)
Expand All @@ -26,6 +29,31 @@
class DescrptDPA1(DescrptDPA1DP):
_update_sel_cls = UpdateSel

def share_params(
    self,
    base_class: Any,
    shared_level: int,
    model_prob: float = 1.0,
    resume: bool = False,
) -> None:
    """Link this descriptor's parameters to those of ``base_class``.

    Used by multi-task training so several models reuse one set of weights.

    Sharing levels:
    * 0 -- share the type embedding and the whole ``se_atten`` block
      (env-mat statistics are merged first unless resuming).
    * 1 -- share only the type embedding.
    """
    assert self.__class__ == base_class.__class__, (
        "Only descriptors of the same type can share params!"
    )
    if shared_level not in (0, 1):
        raise NotImplementedError
    # The type embedding is shared at every supported level.
    self._modules["type_embedding"] = base_class._modules["type_embedding"]
    if shared_level == 0:
        if not resume:
            # Fold this model's env-mat statistics into the base model's
            # before the module is replaced by the shared one.
            merge_env_stat(base_class.se_atten, self.se_atten, model_prob)
        self._modules["se_atten"] = base_class._modules["se_atten"]

def enable_compression(
self,
min_nbor_dist: float,
Expand Down
44 changes: 44 additions & 0 deletions deepmd/pt_expt/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
build_multiple_neighbor_list,
get_multiple_nlist_key,
)
from deepmd.dpmodel.utils.env_mat_stat import (
merge_env_stat,
)
from deepmd.pt_expt.common import (
torch_module,
)
Expand All @@ -30,6 +33,47 @@
class DescrptDPA2(DescrptDPA2DP):
_update_sel_cls = UpdateSel

def share_params(
    self,
    base_class: "DescrptDPA2",
    shared_level: int,
    model_prob: float = 1.0,
    resume: bool = False,
) -> None:
    """Link this descriptor's parameters to those of ``base_class``.

    Used by multi-task training so several models reuse one set of weights.

    Sharing levels:
    * 0 -- share the type embedding, ``repinit``, ``repinit_three_body``
      (when present), ``g1_shape_tranform`` and ``repformers``; env-mat
      statistics are merged first unless resuming.
    * 1 -- share only the type embedding.
    """
    assert self.__class__ == base_class.__class__, (
        "Only descriptors of the same type can share params!"
    )
    if shared_level not in (0, 1):
        raise NotImplementedError
    # The type embedding is shared at every supported level.
    self._modules["type_embedding"] = base_class._modules["type_embedding"]
    if shared_level != 0:
        return
    # Whether the optional three-body repinit branch participates in sharing.
    share_three_body = (
        self.use_three_body and "repinit_three_body" in base_class._modules
    )
    if not resume:
        # Fold this model's env-mat statistics into the base model's
        # before the modules are replaced by the shared ones.
        merge_env_stat(base_class.repinit, self.repinit, model_prob)
        if share_three_body:
            merge_env_stat(
                base_class.repinit_three_body,
                self.repinit_three_body,
                model_prob,
            )
        merge_env_stat(base_class.repformers, self.repformers, model_prob)
    self._modules["repinit"] = base_class._modules["repinit"]
    if share_three_body:
        self._modules["repinit_three_body"] = base_class._modules[
            "repinit_three_body"
        ]
    # NOTE: the key "g1_shape_tranform" (sic) must match the module name
    # registered elsewhere in the project; do not "fix" the spelling here.
    self._modules["g1_shape_tranform"] = base_class._modules[
        "g1_shape_tranform"
    ]
    self._modules["repformers"] = base_class._modules["repformers"]

def enable_compression(
self,
min_nbor_dist: float,
Expand Down
28 changes: 28 additions & 0 deletions deepmd/pt_expt/descriptor/dpa3.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# SPDX-License-Identifier: LGPL-3.0-or-later

from deepmd.dpmodel.descriptor.dpa3 import DescrptDPA3 as DescrptDPA3DP
from deepmd.dpmodel.utils.env_mat_stat import (
merge_env_stat,
)
from deepmd.pt_expt.common import (
torch_module,
)
Expand All @@ -16,3 +19,28 @@
@torch_module
class DescrptDPA3(DescrptDPA3DP):
_update_sel_cls = UpdateSel

def share_params(
    self,
    base_class: "DescrptDPA3",
    shared_level: int,
    model_prob: float = 1.0,
    resume: bool = False,
) -> None:
    """Link this descriptor's parameters to those of ``base_class``.

    Used by multi-task training so several models reuse one set of weights.

    Sharing levels:
    * 0 -- share the type embedding and the ``repflows`` block
      (env-mat statistics are merged first unless resuming).
    * 1 -- share only the type embedding.
    """
    assert self.__class__ == base_class.__class__, (
        "Only descriptors of the same type can share params!"
    )
    if shared_level not in (0, 1):
        raise NotImplementedError
    # The type embedding is shared at every supported level.
    self._modules["type_embedding"] = base_class._modules["type_embedding"]
    if shared_level == 0:
        if not resume:
            # Fold this model's env-mat statistics into the base model's
            # before the module is replaced by the shared one.
            merge_env_stat(base_class.repflows, self.repflows, model_prob)
        self._modules["repflows"] = base_class._modules["repflows"]
28 changes: 27 additions & 1 deletion deepmd/pt_expt/descriptor/hybrid.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Any,
)

from deepmd.dpmodel.descriptor.hybrid import DescrptHybrid as DescrptHybridDP
from deepmd.pt_expt.common import (
Expand All @@ -12,4 +15,27 @@
@BaseDescriptor.register("hybrid")
@torch_module
class DescrptHybrid(DescrptHybridDP):
pass
def share_params(
    self,
    base_class: Any,
    shared_level: int,
    model_prob: float = 1.0,
    resume: bool = False,
) -> None:
    """Link this hybrid descriptor's parameters to those of ``base_class``.

    Sharing levels:
    * 0 -- delegate sharing to every sub-descriptor, pairing them by
      position in ``descrpt_list``.
    """
    assert self.__class__ == base_class.__class__, (
        "Only descriptors of the same type can share params!"
    )
    if shared_level != 0:
        raise NotImplementedError
    # Index-based pairing (rather than zip) so a mismatched base list
    # raises IndexError instead of being silently truncated.
    for idx, sub_descrpt in enumerate(self.descrpt_list):
        sub_descrpt.share_params(
            base_class.descrpt_list[idx],
            shared_level,
            model_prob=model_prob,
            resume=resume,
        )
Loading
Loading