diff --git a/deepmd/dpmodel/descriptor/repformers.py b/deepmd/dpmodel/descriptor/repformers.py index 774cf33d72..5881b3a0b3 100644 --- a/deepmd/dpmodel/descriptor/repformers.py +++ b/deepmd/dpmodel/descriptor/repformers.py @@ -345,6 +345,14 @@ def get_rcut(self) -> float: """Returns the cut-off radius.""" return self.rcut + def get_rcut_smth(self) -> float: + """Returns the radius where the neighbor information starts to smoothly decay to 0.""" + return self.rcut_smth + + def get_env_protection(self) -> float: + """Returns the protection of building environment matrix.""" + return self.env_protection + def get_nsel(self) -> int: """Returns the number of selected atoms in the cut-off radius.""" return sum(self.sel) diff --git a/deepmd/dpmodel/fitting/general_fitting.py b/deepmd/dpmodel/fitting/general_fitting.py index 3a3012440c..b9129a4364 100644 --- a/deepmd/dpmodel/fitting/general_fitting.py +++ b/deepmd/dpmodel/fitting/general_fitting.py @@ -255,6 +255,7 @@ def compute_input_stats( stat_file_path : Optional[DPPath] The path to the stat file. 
""" + self._param_stats: dict[str, list[StatItem]] = {} if self.numb_fparam == 0 and self.numb_aparam == 0: # skip data statistics return @@ -296,6 +297,7 @@ def compute_input_stats( self._save_param_stats_to_file( stat_file_path, "fparam", fparam_stats ) + self._param_stats["fparam"] = fparam_stats fparam_avg = np.array( [s.compute_avg() for s in fparam_stats], dtype=np.float64 ) @@ -362,6 +364,7 @@ def compute_input_stats( self._save_param_stats_to_file( stat_file_path, "aparam", aparam_stats ) + self._param_stats["aparam"] = aparam_stats aparam_avg = np.array( [s.compute_avg() for s in aparam_stats], dtype=np.float64 ) @@ -407,6 +410,10 @@ def _load_param_stats_from_file( for ii in range(numb) ] + def get_param_stats(self) -> dict[str, list[StatItem]]: + """Get the stored fparam/aparam statistics (populated by compute_input_stats).""" + return getattr(self, "_param_stats", {}) + @abstractmethod def _net_out_dim(self) -> int: """Set the FittingNet output dim.""" @@ -666,11 +673,13 @@ def _call_common( # check fparam dim, concate to input descriptor if self.numb_fparam > 0: assert fparam is not None, "fparam should not be None" - if fparam.shape[-1] != self.numb_fparam: + try: + fparam = xp.reshape(fparam, (nf, self.numb_fparam)) + except (ValueError, RuntimeError) as e: raise ValueError( - f"get an input fparam of dim {fparam.shape[-1]}, " - f"which is not consistent with {self.numb_fparam}." - ) + f"input fparam: cannot reshape {fparam.shape} " + f"into ({nf}, {self.numb_fparam})." + ) from e fparam = (fparam - self.fparam_avg[...]) * self.fparam_inv_std[...] 
fparam = xp.tile( xp.reshape(fparam, (nf, 1, self.numb_fparam)), (1, nloc, 1) @@ -687,12 +696,13 @@ def _call_common( # check aparam dim, concate to input descriptor if self.numb_aparam > 0 and not self.use_aparam_as_mask: assert aparam is not None, "aparam should not be None" - if aparam.shape[-1] != self.numb_aparam: + try: + aparam = xp.reshape(aparam, (nf, nloc, self.numb_aparam)) + except (ValueError, RuntimeError) as e: raise ValueError( - f"get an input aparam of dim {aparam.shape[-1]}, " - f"which is not consistent with {self.numb_aparam}." - ) - aparam = xp.reshape(aparam, (nf, nloc, self.numb_aparam)) + f"input aparam: cannot reshape {aparam.shape} " + f"into ({nf}, {nloc}, {self.numb_aparam})." + ) from e aparam = (aparam - self.aparam_avg[...]) * self.aparam_inv_std[...] xx = xp.concat( [xx, aparam], @@ -735,7 +745,8 @@ def _call_common( ) for type_i in range(self.ntypes): mask = xp.tile( - xp.reshape((atype == type_i), (nf, nloc, 1)), (1, 1, net_dim_out) + xp.reshape((atype == type_i), (nf, nloc, 1)), + (1, 1, net_dim_out), ) atom_property = self.nets[(type_i,)](xx) if self.remove_vaccum_contribution is not None and not ( diff --git a/deepmd/dpmodel/utils/env_mat.py b/deepmd/dpmodel/utils/env_mat.py index 9856741317..0b0bd18c35 100644 --- a/deepmd/dpmodel/utils/env_mat.py +++ b/deepmd/dpmodel/utils/env_mat.py @@ -68,7 +68,11 @@ def _make_env_mat( xp = array_api_compat.array_namespace(nlist) nf, nloc, nnei = nlist.shape # nf x nall x 3 - coord = xp.reshape(coord, (nf, -1, 3)) + # Callers may pass either (nf, nall*3) or (nf, nall, 3); normalise + # both to (nf, nall, 3) using shape-based inference so the concrete nf + # value is not baked into the reshape. 
+ if coord.ndim == 2: + coord = xp.reshape(coord, (-1, coord.shape[1] // 3, 3)) mask = nlist >= 0 nlist = nlist * xp.astype(mask, nlist.dtype) # nf x (nloc x nnei) x 3 @@ -77,7 +81,7 @@ def _make_env_mat( # nf x nloc x nnei x 3 coord_r = xp.reshape(coord_r, (nf, nloc, nnei, 3)) # nf x nloc x 1 x 3 - coord_l = xp.reshape(xp_take_first_n(coord, 1, nloc), (nf, -1, 1, 3)) + coord_l = xp.reshape(xp_take_first_n(coord, 1, nloc), (nf, nloc, 1, 3)) # nf x nloc x nnei x 3 diff = coord_r - coord_l # nf x nloc x nnei diff --git a/deepmd/dpmodel/utils/env_mat_stat.py b/deepmd/dpmodel/utils/env_mat_stat.py index 721723821e..8d53602b18 100644 --- a/deepmd/dpmodel/utils/env_mat_stat.py +++ b/deepmd/dpmodel/utils/env_mat_stat.py @@ -40,6 +40,75 @@ ) +def merge_env_stat( + base_obj: Union["Descriptor", "DescriptorBlock"], + link_obj: Union["Descriptor", "DescriptorBlock"], + model_prob: float = 1.0, +) -> None: + """Merge descriptor env mat stats from link_obj into base_obj. + + Uses probability-weighted merging: merged = base_stats + link_stats * model_prob, + where model_prob = link_prob / base_prob. + Mutates base_obj.stats for chaining (3+ models). + + Parameters + ---------- + base_obj : Descriptor or DescriptorBlock + The base descriptor whose stats will be updated. + link_obj : Descriptor or DescriptorBlock + The linked descriptor whose stats will be merged in. + model_prob : float + The probability weight ratio (link_prob / base_prob). 
+ """ + if ( + getattr(base_obj, "stats", None) is None + or getattr(link_obj, "stats", None) is None + ): + return + if getattr(base_obj, "set_stddev_constant", False) and getattr( + base_obj, "set_davg_zero", False + ): + return + + # Weighted merge of StatItem objects + base_stats = base_obj.stats + link_stats = link_obj.stats + merged_stats = {} + for kk in base_stats: + merged_stats[kk] = base_stats[kk] + link_stats[kk] * model_prob + + # Compute mean/stddev from merged stats + base_env = EnvMatStatSe(base_obj) + base_env.stats = merged_stats + mean, stddev = base_env() + + # Update base_obj stats for chaining + base_obj.stats = merged_stats + + # Update buffers in-place: davg/dstd (simple) or mean/stddev (blocks) + # mean/stddev are numpy arrays; convert to match the buffer's backend + if hasattr(base_obj, "davg"): + xp = array_api_compat.array_namespace(base_obj.dstd) + device = array_api_compat.device(base_obj.dstd) + if not getattr(base_obj, "set_davg_zero", False): + base_obj.davg[...] = xp.asarray( + mean, dtype=base_obj.davg.dtype, device=device + ) + base_obj.dstd[...] = xp.asarray( + stddev, dtype=base_obj.dstd.dtype, device=device + ) + elif hasattr(base_obj, "mean"): + xp = array_api_compat.array_namespace(base_obj.stddev) + device = array_api_compat.device(base_obj.stddev) + if not getattr(base_obj, "set_davg_zero", False): + base_obj.mean[...] = xp.asarray( + mean, dtype=base_obj.mean.dtype, device=device + ) + base_obj.stddev[...] = xp.asarray( + stddev, dtype=base_obj.stddev.dtype, device=device + ) + + class EnvMatStat(BaseEnvMatStat): def compute_stat(self, env_mat: dict[str, Array]) -> dict[str, StatItem]: """Compute the statistics of the environment matrix for a single system. 
diff --git a/deepmd/pt/model/task/fitting.py b/deepmd/pt/model/task/fitting.py index 7aac7b9a29..439d3d11d9 100644 --- a/deepmd/pt/model/task/fitting.py +++ b/deepmd/pt/model/task/fitting.py @@ -779,10 +779,10 @@ def _forward_common( assert fparam is not None, "fparam should not be None" assert self.fparam_avg is not None assert self.fparam_inv_std is not None - if fparam.shape[-1] != self.numb_fparam: + if fparam.numel() != nf * self.numb_fparam: raise ValueError( - "get an input fparam of dim {fparam.shape[-1]}, ", - "which is not consistent with {self.numb_fparam}.", + f"input fparam: cannot reshape {list(fparam.shape)} " + f"into ({nf}, {self.numb_fparam})." ) fparam = fparam.view([nf, self.numb_fparam]) nb, _ = fparam.shape @@ -804,10 +804,10 @@ def _forward_common( assert aparam is not None, "aparam should not be None" assert self.aparam_avg is not None assert self.aparam_inv_std is not None - if aparam.shape[-1] != self.numb_aparam: + if aparam.numel() % (nf * self.numb_aparam) != 0: raise ValueError( - f"get an input aparam of dim {aparam.shape[-1]}, ", - f"which is not consistent with {self.numb_aparam}.", + f"input aparam: cannot reshape {list(aparam.shape)} " + f"into ({nf}, nloc, {self.numb_aparam})." 
) aparam = aparam.view([nf, -1, self.numb_aparam]) nb, nloc, _ = aparam.shape diff --git a/deepmd/pt_expt/descriptor/dpa1.py b/deepmd/pt_expt/descriptor/dpa1.py index d72a12267a..01df91abd6 100644 --- a/deepmd/pt_expt/descriptor/dpa1.py +++ b/deepmd/pt_expt/descriptor/dpa1.py @@ -9,6 +9,9 @@ cast_precision, ) from deepmd.dpmodel.descriptor.dpa1 import DescrptDPA1 as DescrptDPA1DP +from deepmd.dpmodel.utils.env_mat_stat import ( + merge_env_stat, +) from deepmd.pt_expt.common import ( torch_module, ) @@ -26,6 +29,31 @@ class DescrptDPA1(DescrptDPA1DP): _update_sel_cls = UpdateSel + def share_params( + self, + base_class: Any, + shared_level: int, + model_prob: float = 1.0, + resume: bool = False, + ) -> None: + """Share parameters with base_class for multi-task training. + + Level 0: share type_embedding and se_atten (all modules and buffers). + Level 1: share type_embedding only. + """ + assert self.__class__ == base_class.__class__, ( + "Only descriptors of the same type can share params!" 
+ ) + if shared_level == 0: + self._modules["type_embedding"] = base_class._modules["type_embedding"] + if not resume: + merge_env_stat(base_class.se_atten, self.se_atten, model_prob) + self._modules["se_atten"] = base_class._modules["se_atten"] + elif shared_level == 1: + self._modules["type_embedding"] = base_class._modules["type_embedding"] + else: + raise NotImplementedError + def enable_compression( self, min_nbor_dist: float, diff --git a/deepmd/pt_expt/descriptor/dpa2.py b/deepmd/pt_expt/descriptor/dpa2.py index 0d389af070..1723df5a30 100644 --- a/deepmd/pt_expt/descriptor/dpa2.py +++ b/deepmd/pt_expt/descriptor/dpa2.py @@ -14,6 +14,9 @@ build_multiple_neighbor_list, get_multiple_nlist_key, ) +from deepmd.dpmodel.utils.env_mat_stat import ( + merge_env_stat, +) from deepmd.pt_expt.common import ( torch_module, ) @@ -30,6 +33,47 @@ class DescrptDPA2(DescrptDPA2DP): _update_sel_cls = UpdateSel + def share_params( + self, + base_class: "DescrptDPA2", + shared_level: int, + model_prob: float = 1.0, + resume: bool = False, + ) -> None: + """Share parameters with base_class for multi-task training. + + Level 0: share type_embedding, repinit, repinit_three_body, + g1_shape_tranform, and repformers. + Level 1: share type_embedding only. + """ + assert self.__class__ == base_class.__class__, ( + "Only descriptors of the same type can share params!" 
+ ) + if shared_level == 0: + self._modules["type_embedding"] = base_class._modules["type_embedding"] + if not resume: + merge_env_stat(base_class.repinit, self.repinit, model_prob) + if self.use_three_body and "repinit_three_body" in base_class._modules: + merge_env_stat( + base_class.repinit_three_body, + self.repinit_three_body, + model_prob, + ) + merge_env_stat(base_class.repformers, self.repformers, model_prob) + self._modules["repinit"] = base_class._modules["repinit"] + if self.use_three_body and "repinit_three_body" in base_class._modules: + self._modules["repinit_three_body"] = base_class._modules[ + "repinit_three_body" + ] + self._modules["g1_shape_tranform"] = base_class._modules[ + "g1_shape_tranform" + ] + self._modules["repformers"] = base_class._modules["repformers"] + elif shared_level == 1: + self._modules["type_embedding"] = base_class._modules["type_embedding"] + else: + raise NotImplementedError + def enable_compression( self, min_nbor_dist: float, diff --git a/deepmd/pt_expt/descriptor/dpa3.py b/deepmd/pt_expt/descriptor/dpa3.py index 7119f043bd..fec047e3fd 100644 --- a/deepmd/pt_expt/descriptor/dpa3.py +++ b/deepmd/pt_expt/descriptor/dpa3.py @@ -1,6 +1,9 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from deepmd.dpmodel.descriptor.dpa3 import DescrptDPA3 as DescrptDPA3DP +from deepmd.dpmodel.utils.env_mat_stat import ( + merge_env_stat, +) from deepmd.pt_expt.common import ( torch_module, ) @@ -16,3 +19,28 @@ @torch_module class DescrptDPA3(DescrptDPA3DP): _update_sel_cls = UpdateSel + + def share_params( + self, + base_class: "DescrptDPA3", + shared_level: int, + model_prob: float = 1.0, + resume: bool = False, + ) -> None: + """Share parameters with base_class for multi-task training. + + Level 0: share type_embedding and repflows. + Level 1: share type_embedding only. + """ + assert self.__class__ == base_class.__class__, ( + "Only descriptors of the same type can share params!" 
+ ) + if shared_level == 0: + self._modules["type_embedding"] = base_class._modules["type_embedding"] + if not resume: + merge_env_stat(base_class.repflows, self.repflows, model_prob) + self._modules["repflows"] = base_class._modules["repflows"] + elif shared_level == 1: + self._modules["type_embedding"] = base_class._modules["type_embedding"] + else: + raise NotImplementedError diff --git a/deepmd/pt_expt/descriptor/hybrid.py b/deepmd/pt_expt/descriptor/hybrid.py index 9ec5570c7c..07eddd2e01 100644 --- a/deepmd/pt_expt/descriptor/hybrid.py +++ b/deepmd/pt_expt/descriptor/hybrid.py @@ -1,4 +1,7 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) from deepmd.dpmodel.descriptor.hybrid import DescrptHybrid as DescrptHybridDP from deepmd.pt_expt.common import ( @@ -12,4 +15,27 @@ @BaseDescriptor.register("hybrid") @torch_module class DescrptHybrid(DescrptHybridDP): - pass + def share_params( + self, + base_class: Any, + shared_level: int, + model_prob: float = 1.0, + resume: bool = False, + ) -> None: + """Share parameters with base_class for multi-task training. + + Level 0: share all sub-descriptors. + """ + assert self.__class__ == base_class.__class__, ( + "Only descriptors of the same type can share params!" 
+ ) + if shared_level == 0: + for ii, des in enumerate(self.descrpt_list): + self.descrpt_list[ii].share_params( + base_class.descrpt_list[ii], + shared_level, + model_prob=model_prob, + resume=resume, + ) + else: + raise NotImplementedError diff --git a/deepmd/pt_expt/descriptor/se_atten_v2.py b/deepmd/pt_expt/descriptor/se_atten_v2.py index 2c4be7d3ae..e0eb3acac3 100644 --- a/deepmd/pt_expt/descriptor/se_atten_v2.py +++ b/deepmd/pt_expt/descriptor/se_atten_v2.py @@ -22,6 +22,13 @@ class DescrptSeAttenV2(DescrptSeAttenV2DP): _update_sel_cls = UpdateSel + def share_params(self, *args: Any, **kwargs: Any) -> None: + from deepmd.pt_expt.descriptor.dpa1 import ( + DescrptDPA1, + ) + + return DescrptDPA1.share_params(self, *args, **kwargs) + def enable_compression(self, *args: Any, **kwargs: Any) -> None: from deepmd.pt_expt.descriptor.dpa1 import ( DescrptDPA1, diff --git a/deepmd/pt_expt/descriptor/se_e2_a.py b/deepmd/pt_expt/descriptor/se_e2_a.py index 38be83c46c..61d611036e 100644 --- a/deepmd/pt_expt/descriptor/se_e2_a.py +++ b/deepmd/pt_expt/descriptor/se_e2_a.py @@ -9,6 +9,9 @@ cast_precision, ) from deepmd.dpmodel.descriptor.se_e2_a import DescrptSeA as DescrptSeADP +from deepmd.dpmodel.utils.env_mat_stat import ( + merge_env_stat, +) from deepmd.pt_expt.common import ( torch_module, ) @@ -26,6 +29,30 @@ class DescrptSeA(DescrptSeADP): _update_sel_cls = UpdateSel + def share_params( + self, + base_class: Any, + shared_level: int, + model_prob: float = 1.0, + resume: bool = False, + ) -> None: + """Share parameters with base_class for multi-task training. + + Level 0: share all modules and buffers. + """ + assert self.__class__ == base_class.__class__, ( + "Only descriptors of the same type can share params!" 
+ ) + if shared_level == 0: + if not resume: + merge_env_stat(base_class, self, model_prob) + for item in self._modules: + self._modules[item] = base_class._modules[item] + for item in self._buffers: + self._buffers[item] = base_class._buffers[item] + else: + raise NotImplementedError + def enable_compression( self, min_nbor_dist: float, diff --git a/deepmd/pt_expt/descriptor/se_r.py b/deepmd/pt_expt/descriptor/se_r.py index c2fd34e6b5..22302f54e6 100644 --- a/deepmd/pt_expt/descriptor/se_r.py +++ b/deepmd/pt_expt/descriptor/se_r.py @@ -9,6 +9,9 @@ cast_precision, ) from deepmd.dpmodel.descriptor.se_r import DescrptSeR as DescrptSeRDP +from deepmd.dpmodel.utils.env_mat_stat import ( + merge_env_stat, +) from deepmd.pt_expt.common import ( torch_module, ) @@ -26,6 +29,30 @@ class DescrptSeR(DescrptSeRDP): _update_sel_cls = UpdateSel + def share_params( + self, + base_class: Any, + shared_level: int, + model_prob: float = 1.0, + resume: bool = False, + ) -> None: + """Share parameters with base_class for multi-task training. + + Level 0: share all modules and buffers. + """ + assert self.__class__ == base_class.__class__, ( + "Only descriptors of the same type can share params!" 
+ ) + if shared_level == 0: + if not resume: + merge_env_stat(base_class, self, model_prob) + for item in self._modules: + self._modules[item] = base_class._modules[item] + for item in self._buffers: + self._buffers[item] = base_class._buffers[item] + else: + raise NotImplementedError + def enable_compression( self, min_nbor_dist: float, diff --git a/deepmd/pt_expt/descriptor/se_t.py b/deepmd/pt_expt/descriptor/se_t.py index 806d5eca7a..061306f281 100644 --- a/deepmd/pt_expt/descriptor/se_t.py +++ b/deepmd/pt_expt/descriptor/se_t.py @@ -9,6 +9,9 @@ cast_precision, ) from deepmd.dpmodel.descriptor.se_t import DescrptSeT as DescrptSeTDP +from deepmd.dpmodel.utils.env_mat_stat import ( + merge_env_stat, +) from deepmd.pt_expt.common import ( torch_module, ) @@ -27,6 +30,30 @@ class DescrptSeT(DescrptSeTDP): _update_sel_cls = UpdateSel + def share_params( + self, + base_class: Any, + shared_level: int, + model_prob: float = 1.0, + resume: bool = False, + ) -> None: + """Share parameters with base_class for multi-task training. + + Level 0: share all modules and buffers. + """ + assert self.__class__ == base_class.__class__, ( + "Only descriptors of the same type can share params!" 
+ ) + if shared_level == 0: + if not resume: + merge_env_stat(base_class, self, model_prob) + for item in self._modules: + self._modules[item] = base_class._modules[item] + for item in self._buffers: + self._buffers[item] = base_class._buffers[item] + else: + raise NotImplementedError + def enable_compression( self, min_nbor_dist: float, diff --git a/deepmd/pt_expt/descriptor/se_t_tebd.py b/deepmd/pt_expt/descriptor/se_t_tebd.py index 385bf0dfb6..c0ae308971 100644 --- a/deepmd/pt_expt/descriptor/se_t_tebd.py +++ b/deepmd/pt_expt/descriptor/se_t_tebd.py @@ -9,6 +9,9 @@ cast_precision, ) from deepmd.dpmodel.descriptor.se_t_tebd import DescrptSeTTebd as DescrptSeTTebdDP +from deepmd.dpmodel.utils.env_mat_stat import ( + merge_env_stat, +) from deepmd.pt_expt.common import ( torch_module, ) @@ -25,6 +28,31 @@ class DescrptSeTTebd(DescrptSeTTebdDP): _update_sel_cls = UpdateSel + def share_params( + self, + base_class: "DescrptSeTTebd", + shared_level: int, + model_prob: float = 1.0, + resume: bool = False, + ) -> None: + """Share parameters with base_class for multi-task training. + + Level 0: share type_embedding and se_ttebd. + Level 1: share type_embedding only. + """ + assert self.__class__ == base_class.__class__, ( + "Only descriptors of the same type can share params!" 
+ ) + if shared_level == 0: + self._modules["type_embedding"] = base_class._modules["type_embedding"] + if not resume: + merge_env_stat(base_class.se_ttebd, self.se_ttebd, model_prob) + self._modules["se_ttebd"] = base_class._modules["se_ttebd"] + elif shared_level == 1: + self._modules["type_embedding"] = base_class._modules["type_embedding"] + else: + raise NotImplementedError + def enable_compression( self, min_nbor_dist: float, diff --git a/deepmd/pt_expt/entrypoints/main.py b/deepmd/pt_expt/entrypoints/main.py index 3c82ff13aa..40302ee7b3 100644 --- a/deepmd/pt_expt/entrypoints/main.py +++ b/deepmd/pt_expt/entrypoints/main.py @@ -4,6 +4,7 @@ import argparse import json import logging +import os from pathlib import ( Path, ) @@ -40,54 +41,112 @@ def get_trainer( restart_model: str | None = None, finetune_model: str | None = None, finetune_links: dict | None = None, + shared_links: dict | None = None, ) -> training.Trainer: """Build a :class:`training.Trainer` from a normalised config.""" model_params = config["model"] training_params = config["training"] - type_map = model_params["type_map"] + multi_task = "model_dict" in model_params - # ----- training data ------------------------------------------------ - training_dataset_params = training_params["training_data"] - training_systems = process_systems( - training_dataset_params["systems"], - patterns=training_dataset_params.get("rglob_patterns", None), - ) - train_data = DeepmdDataSystem( - systems=training_systems, - batch_size=training_dataset_params["batch_size"], - test_size=1, - type_map=type_map, - trn_all_set=True, - sys_probs=training_dataset_params.get("sys_probs", None), - auto_prob_style=training_dataset_params.get("auto_prob", "prob_sys_size"), - ) + if not multi_task: + type_map = model_params["type_map"] - # ----- validation data ---------------------------------------------- - validation_data = None - validation_dataset_params = training_params.get("validation_data", None) - if 
validation_dataset_params is not None: - val_systems = process_systems( - validation_dataset_params["systems"], - patterns=validation_dataset_params.get("rglob_patterns", None), + # ----- training data ------------------------------------------------ + training_dataset_params = training_params["training_data"] + training_systems = process_systems( + training_dataset_params["systems"], + patterns=training_dataset_params.get("rglob_patterns", None), ) - validation_data = DeepmdDataSystem( - systems=val_systems, - batch_size=validation_dataset_params["batch_size"], + train_data = DeepmdDataSystem( + systems=training_systems, + batch_size=training_dataset_params["batch_size"], test_size=1, type_map=type_map, trn_all_set=True, + sys_probs=training_dataset_params.get("sys_probs", None), + auto_prob_style=training_dataset_params.get("auto_prob", "prob_sys_size"), ) - # ----- stat file path ----------------------------------------------- - stat_file_path = training_params.get("stat_file", None) - if stat_file_path is not None: - if not Path(stat_file_path).exists(): - if stat_file_path.endswith((".h5", ".hdf5")): - with h5py.File(stat_file_path, "w"): - pass + # ----- validation data ---------------------------------------------- + validation_data = None + validation_dataset_params = training_params.get("validation_data", None) + if validation_dataset_params is not None: + val_systems = process_systems( + validation_dataset_params["systems"], + patterns=validation_dataset_params.get("rglob_patterns", None), + ) + validation_data = DeepmdDataSystem( + systems=val_systems, + batch_size=validation_dataset_params["batch_size"], + test_size=1, + type_map=type_map, + trn_all_set=True, + ) + + # ----- stat file path ----------------------------------------------- + stat_file_path = training_params.get("stat_file", None) + if stat_file_path is not None: + if not Path(stat_file_path).exists(): + if stat_file_path.endswith((".h5", ".hdf5")): + with h5py.File(stat_file_path, "w"): + 
pass + else: + Path(stat_file_path).mkdir(parents=True, exist_ok=True) + stat_file_path = DPPath(stat_file_path, "a") + else: + # Multi-task: build per-task data systems + train_data = {} + validation_data = {} + stat_file_path = {} + for model_key in model_params["model_dict"]: + type_map = model_params["model_dict"][model_key]["type_map"] + data_params = training_params["data_dict"][model_key] + + # training data + td_params = data_params["training_data"] + training_systems = process_systems( + td_params["systems"], + patterns=td_params.get("rglob_patterns", None), + ) + train_data[model_key] = DeepmdDataSystem( + systems=training_systems, + batch_size=td_params["batch_size"], + test_size=1, + type_map=type_map, + trn_all_set=True, + sys_probs=td_params.get("sys_probs", None), + auto_prob_style=td_params.get("auto_prob", "prob_sys_size"), + ) + + # validation data + vd_params = data_params.get("validation_data", None) + if vd_params is not None: + val_systems = process_systems( + vd_params["systems"], + patterns=vd_params.get("rglob_patterns", None), + ) + validation_data[model_key] = DeepmdDataSystem( + systems=val_systems, + batch_size=vd_params["batch_size"], + test_size=1, + type_map=type_map, + trn_all_set=True, + ) + else: + validation_data[model_key] = None + + # stat file + _sf = data_params.get("stat_file", None) + if _sf is not None: + if not Path(_sf).exists(): + if _sf.endswith((".h5", ".hdf5")): + with h5py.File(_sf, "w"): + pass + else: + Path(_sf).mkdir(parents=True, exist_ok=True) + stat_file_path[model_key] = DPPath(_sf, "a") else: - Path(stat_file_path).mkdir() - stat_file_path = DPPath(stat_file_path, "a") + stat_file_path[model_key] = None trainer = training.Trainer( config, @@ -98,6 +157,7 @@ def get_trainer( restart_model=restart_model, finetune_model=finetune_model, finetune_links=finetune_links, + shared_links=shared_links, ) return trainer @@ -151,6 +211,19 @@ def train( if restart is not None and not restart.endswith(".pt"): restart += 
".pt" + # Multi-task detection and shared params preprocessing + multi_task = "model_dict" in config.get("model", {}) + shared_links = None + if multi_task: + from deepmd.pt_expt.utils.multi_task import ( + preprocess_shared_params, + ) + + config["model"], shared_links = preprocess_shared_params(config["model"]) + assert "RANDOM" not in config["model"]["model_dict"], ( + "Model name can not be 'RANDOM' in multi-task mode!" + ) + # update fine-tuning config finetune_links = None if finetune is not None: @@ -174,7 +247,7 @@ def train( # argcheck config = update_deepmd_input(config, warning=True, dump="input_v2_compat.json") - config = normalize(config) + config = normalize(config, multi_task=multi_task) # neighbour stat if not skip_neighbor_stat: @@ -182,27 +255,54 @@ def train( "Calculate neighbor statistics... " "(add --skip-neighbor-stat to skip this step)" ) - type_map = config["model"].get("type_map") - train_data = get_data(config["training"]["training_data"], 0, type_map, None) from deepmd.pt_expt.model import ( BaseModel, ) - config["model"], _min_nbor_dist = BaseModel.update_sel( - train_data, type_map, config["model"] - ) + if not multi_task: + type_map = config["model"].get("type_map") + train_data = get_data( + config["training"]["training_data"], 0, type_map, None + ) + config["model"], _ = BaseModel.update_sel( + train_data, type_map, config["model"] + ) + else: + for model_key in config["model"]["model_dict"]: + type_map = config["model"]["model_dict"][model_key]["type_map"] + train_data = get_data( + config["training"]["data_dict"][model_key]["training_data"], + 0, + type_map, + None, + ) + config["model"]["model_dict"][model_key], _ = BaseModel.update_sel( + train_data, + type_map, + config["model"]["model_dict"][model_key], + ) with open(output, "w") as fp: json.dump(config, fp, indent=4) - trainer = get_trainer( - config, - init_model, - restart, - finetune_model=finetune, - finetune_links=finetune_links, - ) - trainer.run() + import 
torch.distributed as dist + + if os.environ.get("LOCAL_RANK") is not None: + dist.init_process_group(backend="cuda:nccl,cpu:gloo") + + try: + trainer = get_trainer( + config, + init_model, + restart, + finetune_model=finetune, + finetune_links=finetune_links, + shared_links=shared_links, + ) + trainer.run() + finally: + if dist.is_available() and dist.is_initialized(): + dist.destroy_process_group() def freeze( @@ -219,7 +319,7 @@ def freeze( output : str Path for the output .pte file. head : str or None - Head to freeze in multi-task mode (not yet supported). + Head to freeze in multi-task mode. """ import torch @@ -248,18 +348,43 @@ def freeze( ) model_params = extra_state["model_params"] - if head is not None and "model_dict" in model_params: - raise NotImplementedError( - "Multi-task freeze is not yet supported for the pt_expt backend." - ) + multi_task = "model_dict" in model_params + if multi_task: + if head is None: + raise ValueError( + "Multi-task model requires --head to specify which model to freeze. " + f"Available heads: {list(model_params['model_dict'].keys())}" + ) + if head not in model_params["model_dict"]: + raise ValueError( + f"Head '{head}' not found. 
" + f"Available: {list(model_params['model_dict'].keys())}" + ) + # Build full multi-task wrapper, load weights, extract single head + model_dict = {} + for key in model_params["model_dict"]: + from copy import ( + deepcopy, + ) - m = get_model(model_params) - wrapper = ModelWrapper(m) - wrapper.load_state_dict(state_dict) - m.eval() + model_dict[key] = get_model(deepcopy(model_params["model_dict"][key])) + wrapper = ModelWrapper(model_dict) + wrapper.load_state_dict(state_dict) + + m = wrapper.model[head] + single_model_params = model_params["model_dict"][head] + else: + m = get_model(model_params) + wrapper = ModelWrapper(m) + wrapper.load_state_dict(state_dict) + single_model_params = model_params - model_dict = m.serialize() - deserialize_to_file(output, {"model": model_dict, "model_def_script": model_params}) + m.eval() + model_dict_serialized = m.serialize() + deserialize_to_file( + output, + {"model": model_dict_serialized, "model_def_script": single_model_params}, + ) log.info("Saved frozen model to %s", output) diff --git a/deepmd/pt_expt/fitting/ener_fitting.py b/deepmd/pt_expt/fitting/ener_fitting.py index f778af8fec..e931c72b5c 100644 --- a/deepmd/pt_expt/fitting/ener_fitting.py +++ b/deepmd/pt_expt/fitting/ener_fitting.py @@ -1,4 +1,7 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) from deepmd.dpmodel.fitting.ener_fitting import EnergyFittingNet as EnergyFittingNetDP from deepmd.pt_expt.common import ( @@ -13,4 +16,9 @@ @BaseFitting.register("ener") @torch_module class EnergyFittingNet(EnergyFittingNetDP): - pass + def share_params(self, *args: Any, **kwargs: Any) -> None: + from deepmd.pt_expt.fitting.invar_fitting import ( + InvarFitting, + ) + + return InvarFitting.share_params(self, *args, **kwargs) diff --git a/deepmd/pt_expt/fitting/invar_fitting.py b/deepmd/pt_expt/fitting/invar_fitting.py index f13fe2afbb..35eed09bc3 100644 --- a/deepmd/pt_expt/fitting/invar_fitting.py +++ 
b/deepmd/pt_expt/fitting/invar_fitting.py @@ -1,4 +1,10 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +import numpy as np +import torch from deepmd.dpmodel.fitting.invar_fitting import InvarFitting as InvarFittingDP from deepmd.pt_expt.common import ( @@ -7,9 +13,122 @@ from deepmd.pt_expt.fitting.base_fitting import ( BaseFitting, ) +from deepmd.pt_expt.utils.env import ( + DEVICE, +) @BaseFitting.register("invar") @torch_module class InvarFitting(InvarFittingDP): - pass + def share_params( + self, + base_class: Any, + shared_level: int, + model_prob: float = 1.0, + protection: float = 1e-2, + resume: bool = False, + ) -> None: + """Share parameters with base_class for multi-task training. + + Level 0: share all sub-modules and buffers except bias_atom_e + and case_embd. When not resuming, fparam/aparam statistics are + merged using probability-weighted averaging (matching PT). + """ + assert self.__class__ == base_class.__class__, ( + "Only fitting nets of the same type can share params!" 
+ ) + if shared_level == 0: + # --- weighted fparam stat merging --- + if self.numb_fparam > 0: + if not resume: + base_stats = base_class.get_param_stats().get("fparam", []) + self_stats = self.get_param_stats().get("fparam", []) + if base_stats and self_stats: + assert len(base_stats) == self.numb_fparam + merged = [ + base_stats[ii] + self_stats[ii] * model_prob + for ii in range(self.numb_fparam) + ] + fparam_avg = np.array( + [s.compute_avg() for s in merged], dtype=np.float64 + ) + fparam_std = np.array( + [s.compute_std(protection=protection) for s in merged], + dtype=np.float64, + ) + fparam_inv_std = 1.0 / fparam_std + base_class.fparam_avg.copy_( + torch.tensor( + fparam_avg, + device=DEVICE, + dtype=base_class.fparam_avg.dtype, + ) + ) + base_class.fparam_inv_std.copy_( + torch.tensor( + fparam_inv_std, + device=DEVICE, + dtype=base_class.fparam_inv_std.dtype, + ) + ) + # update stored stats so chained share_params works + base_class._param_stats["fparam"] = merged + self._buffers["fparam_avg"] = base_class._buffers["fparam_avg"] + self._buffers["fparam_inv_std"] = base_class._buffers["fparam_inv_std"] + + # --- weighted aparam stat merging --- + if self.numb_aparam > 0: + if not resume: + base_stats = base_class.get_param_stats().get("aparam", []) + self_stats = self.get_param_stats().get("aparam", []) + if base_stats and self_stats: + assert len(base_stats) == self.numb_aparam + merged = [ + base_stats[ii] + self_stats[ii] * model_prob + for ii in range(self.numb_aparam) + ] + aparam_avg = np.array( + [s.compute_avg() for s in merged], dtype=np.float64 + ) + aparam_std = np.array( + [s.compute_std(protection=protection) for s in merged], + dtype=np.float64, + ) + aparam_inv_std = 1.0 / aparam_std + base_class.aparam_avg.copy_( + torch.tensor( + aparam_avg, + device=DEVICE, + dtype=base_class.aparam_avg.dtype, + ) + ) + base_class.aparam_inv_std.copy_( + torch.tensor( + aparam_inv_std, + device=DEVICE, + dtype=base_class.aparam_inv_std.dtype, + ) + ) + 
def _remove_detach_nodes(gm: torch.fx.GraphModule) -> None:
    """Strip every ``aten.detach.default`` call out of an FX graph, in place.

    ``make_fx`` records such nodes for the saved tensors of a traced
    ``autograd.grad(..., create_graph=True)`` backward pass.  Each detach
    severs the link between a saved activation and the model parameters, so
    second-order derivatives come out wrong (e.g. bias gradients are zero
    when training on forces).  Re-routing the consumers of each detach node
    straight to its input restores that gradient path.

    Parameters
    ----------
    gm : torch.fx.GraphModule
        The traced module to rewrite.  Its graph is mutated, linted and the
        module's generated code is recompiled before returning.
    """
    fx_graph = gm.graph
    detach_op = torch.ops.aten.detach.default

    # Snapshot the matches first so the node list is not mutated while it is
    # being walked.
    detach_nodes = [
        candidate
        for candidate in fx_graph.nodes
        if candidate.op == "call_function" and candidate.target == detach_op
    ]

    for detach_node in detach_nodes:
        # The tensor being detached is the sole positional argument; read it
        # at rewrite time so chained detaches resolve to the real source.
        source = detach_node.args[0]
        detach_node.replace_all_uses_with(source)
        fx_graph.erase_node(detach_node)

    fx_graph.lint()
    gm.recompile()
+ _TRACE_NFRAMES = 7 + cur_nframes = ext_coord.shape[0] + if cur_nframes != _TRACE_NFRAMES: + + def _expand(t: torch.Tensor | None) -> torch.Tensor | None: + if t is None: + return None + # Repeat rows so total nframes == _TRACE_NFRAMES. Use index + # gather (mod) so we don't require divisibility. + idx = ( + torch.arange(_TRACE_NFRAMES, dtype=torch.long, device=t.device) + % cur_nframes + ) + return t.index_select(0, idx) + + ext_coord = _expand(ext_coord) + ext_atype = _expand(ext_atype) + nlist = _expand(nlist) + mapping = _expand(mapping) + fparam = _expand(fparam) + aparam = _expand(aparam) + + # Decompose silu_backward into primitive ops (sigmoid + mul + ...) + # so that inductor can compile the graph without requiring a + # higher-order derivative that PyTorch does not register for + # the fused silu backward kernel. + from torch._decomp import ( + get_decompositions, + ) + + decomp_table = get_decompositions([torch.ops.aten.silu_backward.default]) + + traced_lower = make_fx( + fn, + tracing_mode="symbolic", + _allow_non_fake_inputs=True, + decomposition_table=decomp_table, + )(ext_coord, ext_atype, nlist, mapping, fparam, aparam) + + # make_fx inserts aten.detach.default for saved tensors used in the + # decomposed autograd.grad backward ops. These detach nodes break + # second-order gradient flow (d(force)/d(params) for force training). + # Removing them restores correct higher-order derivatives. + _remove_detach_nodes(traced_lower) if not was_training: model.eval() - # The inductor backend does not propagate gradients through the - # make_fx-decomposed autograd.grad ops (second-order gradients for - # force training). Use "aot_eager" which correctly preserves the - # gradient chain while still benefiting from make_fx decomposition. 
- if "backend" not in compile_opts: - compile_opts["backend"] = "aot_eager" - compiled_lower = torch.compile(traced_lower, dynamic=False, **compile_opts) - return compiled_lower + # Work on a copy; ignore caller-supplied dynamic/backend. + compile_opts = { + k: v for k, v in compile_opts.items() if k not in ("dynamic", "backend") + } + opts = compile_opts.setdefault("options", {}) + opts.setdefault("max_autotune", False) + opts.setdefault("epilogue_fusion", False) + opts.setdefault("triton.cudagraphs", False) + opts.setdefault("shape_padding", True) + opts.setdefault("max_fusion_size", 8) + + return torch.compile( + traced_lower, + backend="inductor", + dynamic=True, + **compile_opts, + ) class _CompiledModel(torch.nn.Module): - """Coord extension (eager) -> pad nall -> compiled forward_lower. - - If a batch's ``nall`` exceeds the current ``max_nall``, the model is - automatically re-traced and recompiled with a larger pad size. - """ + """Coord extension (eager) -> compiled forward_lower (dynamic shapes).""" def __init__( self, original_model: torch.nn.Module, compiled_forward_lower: torch.nn.Module, - max_nall: int, - compile_opts: dict[str, Any], ) -> None: super().__init__() self.original_model = original_model self.compiled_forward_lower = compiled_forward_lower - self._max_nall = max_nall - self._compile_opts = compile_opts - - def _recompile( - self, - ext_coord: torch.Tensor, - ext_atype: torch.Tensor, - nlist: torch.Tensor, - mapping: torch.Tensor, - fparam: torch.Tensor | None, - aparam: torch.Tensor | None, - new_max_nall: int, - ) -> None: - """Re-trace and recompile for the given inputs. - - If *new_max_nall* differs from the current ``_max_nall``, the - inputs are padded (or already padded by the caller). 
- """ - # Pad if the caller provides unpadded tensors (nall growth case) - actual_nall = ext_coord.shape[1] - pad_n = new_max_nall - actual_nall - if pad_n > 0: - ext_coord = torch.nn.functional.pad(ext_coord, (0, 0, 0, pad_n)) - ext_atype = torch.nn.functional.pad(ext_atype, (0, pad_n)) - mapping = torch.nn.functional.pad(mapping, (0, pad_n)) - - ext_coord = ext_coord.detach() - - self.compiled_forward_lower = _trace_and_compile( - self.original_model, - ext_coord, - ext_atype, - nlist, - mapping, - fparam, - aparam, - self._compile_opts, - ) - self._max_nall = new_max_nall - log.info( - "Recompiled model with max_nall=%d.", - new_max_nall, - ) def forward( self, @@ -318,27 +349,6 @@ def forward( distinguish_types=False, ) ext_coord = ext_coord.reshape(nframes, -1, 3) - - # Grow max_nall if needed (retrace + recompile) - actual_nall = ext_coord.shape[1] - if actual_nall > self._max_nall: - new_max_nall = ((int(actual_nall * 1.2) + 7) // 8) * 8 - log.info( - "nall=%d exceeds max_nall=%d; recompiling with max_nall=%d.", - actual_nall, - self._max_nall, - new_max_nall, - ) - self._recompile( - ext_coord, ext_atype, nlist, mapping, fparam, aparam, new_max_nall - ) - - # Pad to max_nall so compiled graph sees a fixed shape - pad_n = self._max_nall - actual_nall - if pad_n > 0: - ext_coord = torch.nn.functional.pad(ext_coord, (0, 0, 0, pad_n)) - ext_atype = torch.nn.functional.pad(ext_atype, (0, pad_n)) - mapping = torch.nn.functional.pad(mapping, (0, pad_n)) ext_coord = ext_coord.detach().requires_grad_(True) result = self.compiled_forward_lower( @@ -354,18 +364,12 @@ def forward( out["atom_energy"] = result["atom_energy"] out["energy"] = result["energy"] if "extended_force" in result: - ext_force = result["extended_force"] # (nf, nall_padded, 3) - # mapping may be padded; only use actual_nall entries - map_actual = mapping[:, :actual_nall] # (nf, actual_nall) - ext_force_actual = ext_force[:, :actual_nall, :] # (nf, actual_nall, 3) - # scatter-sum extended forces onto 
local atoms - idx = map_actual.unsqueeze(-1).expand_as( - ext_force_actual - ) # (nf, actual_nall, 3) + ext_force = result["extended_force"] # (nf, nall, 3) + idx = mapping.unsqueeze(-1).expand_as(ext_force) # (nf, nall, 3) force = torch.zeros( nframes, nloc, 3, dtype=ext_force.dtype, device=ext_force.device ) - force.scatter_add_(1, idx, ext_force_actual) + force.scatter_add_(1, idx, ext_force) out["force"] = force if "virial" in result: out["virial"] = result["virial"] @@ -387,34 +391,38 @@ class Trainer: """Training driver for the pt_expt backend. Uses ``DeepmdDataSystem`` for data loading (numpy batches converted - to torch tensors at the boundary). Single-task, single-GPU only. + to torch tensors at the boundary). Supports single-task and multi-task + training. Single-GPU only. Parameters ---------- config : dict Full training configuration. - training_data : DeepmdDataSystem - Training data. - stat_file_path : DPPath or None + training_data : DeepmdDataSystem or dict + Training data. Dict of ``{model_key: DeepmdDataSystem}`` for multi-task. + stat_file_path : DPPath or dict or None Path for saving / loading statistics. - validation_data : DeepmdDataSystem or None + validation_data : DeepmdDataSystem or dict or None Validation data. init_model : str or None Path to a checkpoint to initialise weights from. restart_model : str or None Path to a checkpoint to *restart* training from (restores step + optimiser). + shared_links : dict or None + Parameter sharing rules for multi-task training. 
""" def __init__( self, config: dict[str, Any], - training_data: DeepmdDataSystem, - stat_file_path: DPPath | None = None, - validation_data: DeepmdDataSystem | None = None, + training_data: DeepmdDataSystem | dict, + stat_file_path: DPPath | dict | None = None, + validation_data: DeepmdDataSystem | dict | None = None, init_model: str | None = None, restart_model: str | None = None, finetune_model: str | None = None, finetune_links: dict | None = None, + shared_links: dict | None = None, ) -> None: if finetune_model is not None and ( init_model is not None or restart_model is not None @@ -429,6 +437,18 @@ def __init__( model_params = config["model"] training_params = config["training"] + # Multi-task detection + self.multi_task = "model_dict" in model_params + self.model_keys = ( + list(model_params["model_dict"]) if self.multi_task else ["Default"] + ) + self.num_model = len(self.model_keys) + + # Distributed training detection + self.is_distributed = dist.is_available() and dist.is_initialized() + self.rank = dist.get_rank() if self.is_distributed else 0 + self.world_size = dist.get_world_size() if self.is_distributed else 1 + # Iteration config self.num_steps = training_params["numb_steps"] self.disp_file = training_params.get("disp_file", "lcurve.out") @@ -440,47 +460,137 @@ def __init__( self.lcurve_should_print_header = True # Model --------------------------------------------------------------- - self.model = get_model(deepcopy(model_params)).to(DEVICE) + if not self.multi_task: + self.model = get_model(deepcopy(model_params)).to(DEVICE) + else: + self.model = {} + do_case_embd, case_embd_index = _get_case_embd_config(model_params) + for model_key in self.model_keys: + self.model[model_key] = get_model( + deepcopy(model_params["model_dict"][model_key]) + ).to(DEVICE) + if do_case_embd and not resuming: + self.model[model_key].set_case_embd(case_embd_index[model_key]) # Loss ---------------------------------------------------------------- - self.loss = 
get_loss( - config.get("loss", {}), - config["learning_rate"]["start_lr"], - len(model_params["type_map"]), - self.model, - ) + if not self.multi_task: + self.loss = get_loss( + config.get("loss", {}), + config["learning_rate"]["start_lr"], + len(model_params["type_map"]), + self.model, + ) + else: + self.loss = {} + for model_key in self.model_keys: + loss_param = config["loss_dict"][model_key] + lr_param = config["learning_rate"]["start_lr"] + ntypes = len(model_params["model_dict"][model_key]["type_map"]) + self.loss[model_key] = get_loss( + loss_param, lr_param, ntypes, self.model[model_key] + ) # Data requirements --------------------------------------------------- - data_requirement = self.loss.label_requirement - data_requirement += get_additional_data_requirement(self.model) - training_data.add_data_requirements(data_requirement) - if validation_data is not None: - validation_data.add_data_requirements(data_requirement) - - self.training_data = training_data - self.validation_data = validation_data - self.valid_numb_batch = training_params.get("validation_data", {}).get( - "numb_btch", 1 - ) + if not self.multi_task: + data_requirement = self.loss.label_requirement + data_requirement += get_additional_data_requirement(self.model) + training_data.add_data_requirements(data_requirement) + if validation_data is not None: + validation_data.add_data_requirements(data_requirement) + + self.training_data = training_data + self.validation_data = validation_data + self.valid_numb_batch = training_params.get("validation_data", {}).get( + "numb_btch", 1 + ) + else: + self.training_data = {} + self.validation_data = {} + self.valid_numb_batch = {} + for model_key in self.model_keys: + data_requirement = self.loss[model_key].label_requirement + data_requirement += get_additional_data_requirement( + self.model[model_key] + ) + training_data[model_key].add_data_requirements(data_requirement) + if validation_data[model_key] is not None: + 
validation_data[model_key].add_data_requirements(data_requirement) + self.training_data[model_key] = training_data[model_key] + self.validation_data[model_key] = validation_data[model_key] + self.valid_numb_batch[model_key] = ( + training_params["data_dict"][model_key] + .get("validation_data", {}) + .get("numb_btch", 1) + ) # Statistics ---------------------------------------------------------- - data_stat_nbatch = model_params.get("data_stat_nbatch", 10) + if not self.multi_task: + data_stat_nbatch = model_params.get("data_stat_nbatch", 10) - @functools.lru_cache - def get_sample() -> list[dict[str, np.ndarray]]: - return make_stat_input(training_data, data_stat_nbatch) + @functools.lru_cache + def get_sample() -> list[dict[str, np.ndarray]]: + return make_stat_input(training_data, data_stat_nbatch) - finetune_has_new_type = ( - finetune_model is not None - and finetune_links is not None - and finetune_links["Default"].get_has_new_type() - ) - if not resuming or finetune_has_new_type: - self.model.compute_or_load_stat( - sampled_func=get_sample, - stat_file_path=stat_file_path, + finetune_has_new_type = ( + finetune_model is not None + and finetune_links is not None + and finetune_links["Default"].get_has_new_type() + ) + if (not resuming or finetune_has_new_type) and self.rank == 0: + self.model.compute_or_load_stat( + sampled_func=get_sample, + stat_file_path=stat_file_path, + ) + if self.is_distributed: + self._broadcast_model_stat(self.model) + else: + self._finetune_update_stat = False + self._sample_funcs: dict[str, Any] = {} + for model_key in self.model_keys: + _nbatch = model_params["model_dict"][model_key].get( + "data_stat_nbatch", 10 + ) + _data = training_data[model_key] + _stat_path = stat_file_path[model_key] if stat_file_path else None + + def _make_sample( + _d: DeepmdDataSystem = _data, _n: int = _nbatch + ) -> list[dict[str, np.ndarray]]: + return make_stat_input(_d, _n) + + self._sample_funcs[model_key] = _make_sample + + 
_finetune_has_new_type = ( + finetune_model is not None + and finetune_links is not None + and model_key in finetune_links + and finetune_links[model_key].get_has_new_type() + ) + if _finetune_has_new_type: + self._finetune_update_stat = True + if (not resuming or _finetune_has_new_type) and self.rank == 0: + self.model[model_key].compute_or_load_stat( + sampled_func=_make_sample, + stat_file_path=_stat_path, + ) + if self.is_distributed: + for model_key in self.model_keys: + self._broadcast_model_stat(self.model[model_key]) + + # Model probability (multi-task) -------------------------------------- + if self.multi_task: + from deepmd.dpmodel.utils.training_utils import ( + resolve_model_prob, ) + self.model_prob = resolve_model_prob( + self.model_keys, + training_params.get("model_prob"), + training_data, + ) + else: + self.model_prob = None + # Learning rate ------------------------------------------------------- lr_params = config["learning_rate"].copy() lr_params["num_steps"] = self.num_steps @@ -493,6 +603,51 @@ def get_sample() -> list[dict[str, np.ndarray]]: self.wrapper = ModelWrapper(self.model, self.loss, model_params=model_params) self.start_step = 0 + # Shared params (multi-task) ------------------------------------------ + if shared_links is not None: + _data_stat_protect = np.array( + [ + model_params["model_dict"][ii].get("data_stat_protect", 1e-2) + for ii in model_params["model_dict"] + ] + ) + if not np.allclose(_data_stat_protect, _data_stat_protect[0]): + raise ValueError( + "Model key 'data_stat_protect' must be the same in each branch when multitask!" 
+ ) + self.wrapper.share_params( + shared_links, + resume=(resuming and not self._finetune_update_stat) or self.rank != 0, + model_key_prob_map=dict( + zip(self.model_keys, self.model_prob, strict=True) + ), + data_stat_protect=_data_stat_protect[0], + ) + + # DDP wrapping -------------------------------------------------------- + if self.is_distributed: + # Multi-task uses only one fitting_net per step, so unused + # parameters exist in the graph. Single-task doesn't need this. + _find_unused = self.multi_task + if DEVICE.type == "cuda": + from deepmd.pt_expt.utils.env import ( + LOCAL_RANK, + ) + + torch.cuda.set_device(LOCAL_RANK) + self.wrapper = torch.nn.parallel.DistributedDataParallel( + self.wrapper, + device_ids=[LOCAL_RANK], + find_unused_parameters=_find_unused, + output_device=LOCAL_RANK, + ) + else: + # CPU (gloo backend) — no device_ids + self.wrapper = torch.nn.parallel.DistributedDataParallel( + self.wrapper, + find_unused_parameters=_find_unused, + ) + # Optimiser ----------------------------------------------------------- opt_type = training_params.get("opt_type", "Adam") initial_lr = float(self.lr_schedule.value(self.start_step)) @@ -545,9 +700,8 @@ def get_sample() -> list[dict[str, np.ndarray]]: if finetune_model is not None and finetune_links is not None: # --- Finetune: selective weight loading ----------------------- - finetune_rule = finetune_links["Default"] - # Build pretrained model and load weights + # Build pretrained model(s) and load weights if is_pte: from deepmd.pt_expt.model import ( BaseModel, @@ -557,58 +711,127 @@ def get_sample() -> list[dict[str, np.ndarray]]: ) data = serialize_from_file(finetune_model) + pretrained_model_params = data["model_def_script"] pretrained_model = BaseModel.deserialize(data["model"]).to(DEVICE) else: - pretrained_model = get_model( - deepcopy(state_dict["_extra_state"]["model_params"]) - ).to(DEVICE) - pretrained_wrapper = ModelWrapper(pretrained_model) + pretrained_model_params = 
state_dict["_extra_state"]["model_params"] + + # Build pretrained model (single-task or multi-task) + if "model_dict" not in pretrained_model_params: + # Single-task pretrained → wrap as {"Default": model} + if is_pte: + pretrained_models = pretrained_model + else: + pretrained_models = get_model( + deepcopy(pretrained_model_params) + ).to(DEVICE) + else: + pretrained_models = {} + for pk in pretrained_model_params["model_dict"]: + pretrained_models[pk] = get_model( + deepcopy(pretrained_model_params["model_dict"][pk]) + ).to(DEVICE) + pretrained_wrapper = ModelWrapper(pretrained_models) if not is_pte: pretrained_wrapper.load_state_dict(state_dict) - # Change type map if needed - if ( - finetune_rule.get_finetune_tmap() - != pretrained_wrapper.model.get_type_map() - ): - model_with_new_type_stat = ( - self.wrapper.model if finetune_rule.get_has_new_type() else None - ) - pretrained_wrapper.model.change_type_map( - finetune_rule.get_finetune_tmap(), - model_with_new_type_stat=model_with_new_type_stat, - ) + # Per-branch type map change + for model_key in self.model_keys: + finetune_rule = finetune_links[model_key] + _model_key_from = finetune_rule.get_model_branch() + if ( + finetune_rule.get_finetune_tmap() + != pretrained_wrapper.model[_model_key_from].get_type_map() + ): + model_with_new_type_stat = ( + self._unwrapped.model[model_key] + if finetune_rule.get_has_new_type() + else None + ) + pretrained_wrapper.model[_model_key_from].change_type_map( + finetune_rule.get_finetune_tmap(), + model_with_new_type_stat=model_with_new_type_stat, + ) - # Selectively copy weights: descriptor always from pretrained, - # fitting from pretrained unless random_fitting is True + # Selective weight copy (per-branch key remapping) pretrained_state = pretrained_wrapper.state_dict() - target_state = self.wrapper.state_dict() + target_state = self._unwrapped.state_dict() new_state = {} for key in target_state: if key == "_extra_state": new_state[key] = target_state[key] - elif ( - 
finetune_rule.get_random_fitting() and ".descriptor." not in key - ): - new_state[key] = target_state[key] # keep random init - elif key in pretrained_state: - new_state[key] = pretrained_state[key] # from pretrained - else: - new_state[key] = target_state[key] # fallback - self.wrapper.load_state_dict(new_state) - - # Adjust output bias - bias_mode = ( - "change-by-statistic" - if not finetune_rule.get_random_fitting() - else "set-by-statistic" - ) - self.model = model_change_out_bias( - self.model, get_sample, _bias_adjust_mode=bias_mode - ) + continue + # Find which model_key this key belongs to + matched = False + for model_key in self.model_keys: + if f".{model_key}." not in key: + continue + matched = True + finetune_rule = finetune_links[model_key] + _key_from = finetune_rule.get_model_branch() + pretrained_key = key.replace(f".{model_key}.", f".{_key_from}.") + use_random = ( + finetune_rule.get_random_fitting() + and ".descriptor." not in key + ) + if use_random: + new_state[key] = target_state[key] + elif pretrained_key in pretrained_state: + new_state[key] = pretrained_state[pretrained_key] + else: + new_state[key] = target_state[key] + break + if not matched: + new_state[key] = target_state[key] + self._unwrapped.load_state_dict(new_state) + + # Per-branch bias adjustment (rank 0 only, then broadcast) + if not self.multi_task: + finetune_rule = finetune_links["Default"] + bias_mode = ( + "change-by-statistic" + if not finetune_rule.get_random_fitting() + else "set-by-statistic" + ) + if self.rank == 0: + self.model = model_change_out_bias( + self.model, get_sample, _bias_adjust_mode=bias_mode + ) + if self.is_distributed: + self._broadcast_model_stat(self.model) + else: + for model_key in self.model_keys: + finetune_rule = finetune_links[model_key] + if finetune_rule.get_resuming(): + log.info(f"Model branch {model_key} will resume training.") + continue + log.info(f"Model branch {model_key} will be fine-tuned.") + bias_mode = ( + "change-by-statistic" 
+ if not finetune_rule.get_random_fitting() + else "set-by-statistic" + ) + if self.rank == 0: + self.model[model_key] = model_change_out_bias( + self.model[model_key], + self._sample_funcs[model_key], + _bias_adjust_mode=bias_mode, + ) + if self.is_distributed: + self._broadcast_model_stat(self.model[model_key]) else: # --- Normal resume (init_model / restart) -------------------- - self.wrapper.load_state_dict(state_dict) + self._unwrapped.load_state_dict(state_dict) + + if shared_links is not None: + # Re-apply sharing after loading checkpoint + self._unwrapped.share_params( + shared_links, + resume=True, + model_key_prob_map=dict( + zip(self.model_keys, self.model_prob, strict=True) + ), + ) if optimizer_state_dict is not None: self.optimizer.load_state_dict(optimizer_state_dict) @@ -622,13 +845,6 @@ def get_sample() -> list[dict[str, np.ndarray]]: ) # torch.compile ------------------------------------------------------- - # The model's forward uses torch.autograd.grad (for forces) with - # create_graph=True so the loss backward can differentiate through - # forces. torch.compile does not support this "double backward". - # - # Solution: use make_fx to trace the model forward, which decomposes - # torch.autograd.grad into primitive ops. The resulting traced - # module is then compiled by torch.compile — no double backward. self.enable_compile = training_params.get("enable_compile", False) if self.enable_compile: compile_opts = training_params.get("compile_options", {}) @@ -646,17 +862,13 @@ def _compile_model(self, compile_opts: dict[str, Any]) -> None: computation) with ``create_graph=True``, which creates a "double backward" that ``torch.compile`` cannot handle. - Solution: use ``make_fx`` to trace ``forward_lower``, decomposing - ``torch.autograd.grad`` into primitive ops. The coord extension + - nlist build (data-dependent control flow) are kept outside the - compiled region. 
- - To avoid the overhead of symbolic tracing and dynamic shapes, the - extended-atom dimension (nall) is padded to a fixed maximum - estimated from the training data. This allows concrete-shape - tracing and ``dynamic=False``. If a batch exceeds the current - max_nall at runtime, the model is automatically re-traced and - recompiled with a larger pad size. + Solution: use ``make_fx`` in ``tracing_mode="symbolic"`` to trace + ``forward_lower``, decomposing ``torch.autograd.grad`` into + primitive ops. The symbolic trace keeps the extended-atom + dimension (``nall``) and batch dimension (``nframes``) as + symbolic shapes, so no padding or recompile-on-growth logic is + needed. The coord extension + nlist build (data-dependent + control flow) are kept outside the compiled region. """ from deepmd.dpmodel.utils.nlist import ( build_neighbor_list, @@ -666,17 +878,18 @@ def _compile_model(self, compile_opts: dict[str, Any]) -> None: normalize_coord, ) - model = self.model + # Under DDP, self.wrapper is a DistributedDataParallel wrapper; + # access the underlying ModelWrapper via .module. 
+ wrapper_mod = ( + self.wrapper.module + if isinstance(self.wrapper, torch.nn.parallel.DistributedDataParallel) + else self.wrapper + ) - # --- Estimate max_nall by sampling multiple batches --- - n_sample = 20 - max_nall = 0 - best_sample: ( - tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, dict] | None - ) = None + for task_key in self.model_keys: + model = wrapper_mod.model[task_key] - for _ii in range(n_sample): - inp, _ = self.get_data(is_train=True) + inp, _ = self.get_data(is_train=True, task_key=task_key) coord = inp["coord"].detach() atype = inp["atype"].detach() box = inp.get("box") @@ -684,90 +897,47 @@ def _compile_model(self, compile_opts: dict[str, Any]) -> None: box = box.detach() nframes, nloc = atype.shape[:2] - coord_np = coord.cpu().numpy().reshape(nframes, nloc, 3) - atype_np = atype.cpu().numpy() - box_np = box.cpu().numpy().reshape(nframes, 9) if box is not None else None + coord_3d = coord.reshape(nframes, nloc, 3) + box_flat = box.reshape(nframes, 9) if box is not None else None - if box_np is not None: - coord_norm = normalize_coord(coord_np, box_np.reshape(nframes, 3, 3)) + if box_flat is not None: + coord_norm = normalize_coord(coord_3d, box_flat.reshape(nframes, 3, 3)) else: - coord_norm = coord_np + coord_norm = coord_3d - ext_coord_np, ext_atype_np, mapping_np = extend_coord_with_ghosts( - coord_norm, atype_np, box_np, model.get_rcut() + ext_coord, ext_atype, mapping = extend_coord_with_ghosts( + coord_norm, atype, box_flat, model.get_rcut() ) - nlist_np = build_neighbor_list( - ext_coord_np, - ext_atype_np, + nlist_t = build_neighbor_list( + ext_coord, + ext_atype, nloc, model.get_rcut(), model.get_sel(), distinguish_types=False, ) - ext_coord_np = ext_coord_np.reshape(nframes, -1, 3) - nall = ext_coord_np.shape[1] - if nall > max_nall: - max_nall = nall - best_sample = ( - ext_coord_np, - ext_atype_np, - mapping_np, - nlist_np, - nloc, - inp, - ) - - # Add 20 % margin and round up to a multiple of 8. 
- max_nall = ((int(max_nall * 1.2) + 7) // 8) * 8 - log.info( - "Estimated max_nall=%d for compiled model (sampled %d batches).", - max_nall, - n_sample, - ) - - # --- Pad the largest sample to max_nall and trace --- - assert best_sample is not None - ext_coord_np, ext_atype_np, mapping_np, nlist_np, nloc, sample_input = ( - best_sample - ) - nframes = ext_coord_np.shape[0] - actual_nall = ext_coord_np.shape[1] - pad_n = max_nall - actual_nall - - if pad_n > 0: - ext_coord_np = np.pad(ext_coord_np, ((0, 0), (0, pad_n), (0, 0))) - ext_atype_np = np.pad(ext_atype_np, ((0, 0), (0, pad_n))) - mapping_np = np.pad(mapping_np, ((0, 0), (0, pad_n))) - - ext_coord = torch.tensor( - ext_coord_np, dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE - ) - ext_atype = torch.tensor(ext_atype_np, dtype=torch.int64, device=DEVICE) - nlist_t = torch.tensor(nlist_np, dtype=torch.int64, device=DEVICE) - mapping_t = torch.tensor(mapping_np, dtype=torch.int64, device=DEVICE) - fparam = sample_input.get("fparam") - aparam = sample_input.get("aparam") - - compile_opts.pop("dynamic", None) # always False for padded approach - - compiled_lower = _trace_and_compile( - model, - ext_coord, - ext_atype, - nlist_t, - mapping_t, - fparam, - aparam, - compile_opts, - ) + ext_coord = ext_coord.reshape(nframes, -1, 3) + + fparam = inp.get("fparam") + aparam = inp.get("aparam") + + compiled_lower = _trace_and_compile( + model, + ext_coord, + ext_atype, + nlist_t, + mapping, + fparam, + aparam, + compile_opts, + ) - self.wrapper.model = _CompiledModel( - model, compiled_lower, max_nall, compile_opts - ) - log.info( - "Model compiled with padded nall=%d (tracing_mode=real, dynamic=False).", - max_nall, - ) + wrapper_mod.model[task_key] = _CompiledModel(model, compiled_lower) + log.info( + "Model compiled (task=%s, tracing_mode=symbolic, " + "dynamic=True, backend=inductor).", + task_key, + ) # ------------------------------------------------------------------ # Data helpers @@ -776,14 +946,29 @@ def 
@property
def _unwrapped(self) -> "ModelWrapper":
    """The object held in ``self.wrapper``, with one ``.module`` layer peeled off.

    When ``self.wrapper`` exposes a ``module`` attribute (as it does after
    being wrapped in ``DistributedDataParallel`` in ``__init__``), that
    attribute is returned; otherwise ``self.wrapper`` itself is returned.
    """
    outer = self.wrapper
    return getattr(outer, "module", outer)


@staticmethod
def _broadcast_model_stat(model: torch.nn.Module) -> None:
    """Overwrite this rank's copy of *model* state with rank 0's values.

    Every parameter's ``.data`` is broadcast first, followed by every
    buffer, each with ``src=0``.

    Parameters
    ----------
    model : torch.nn.Module
        Module whose parameters and buffers are synchronised in place.
    """
    # Parameters go through ``.data``; buffers are passed as they are.
    param_tensors = [param.data for param in model.parameters()]
    for tensor in (*param_tensors, *model.buffers()):
        dist.broadcast(tensor, src=0)
None: - fout = open( - self.disp_file, - mode="w" if not self.restart_training else "a", - buffering=1, + from deepmd.utils import random as dp_random + + fout = ( + open( + self.disp_file, + mode="w" if not self.restart_training else "a", + buffering=1, + ) + if self.rank == 0 + else None ) log.info("Start to train %d steps.", self.num_steps) @@ -860,16 +1070,28 @@ def run(self) -> None: for step_id in range(self.start_step, self.num_steps): cur_lr = float(self.lr_schedule.value(step_id)) + # --- task selection (multi-task) --- + task_key = "Default" + if self.multi_task: + model_index = dp_random.choice( + np.arange(self.num_model, dtype=np.int_), + p=self.model_prob, + ) + task_key = self.model_keys[model_index] + if self.timing_in_training: t_start = time.time() # --- forward / backward --- self.optimizer.zero_grad(set_to_none=True) - input_dict, label_dict = self.get_data(is_train=True) + input_dict, label_dict = self.get_data(is_train=True, task_key=task_key) cur_lr_sched = self.scheduler.get_last_lr()[0] model_pred, loss, more_loss = self.wrapper( - **input_dict, cur_lr=cur_lr_sched, label=label_dict + **input_dict, + cur_lr=cur_lr_sched, + label=label_dict, + task_key=task_key if self.multi_task else None, ) loss.backward() @@ -890,104 +1112,183 @@ def run(self) -> None: ): self.wrapper.eval() - train_results = {k: v for k, v in more_loss.items() if "l2_" not in k} - - # validation - valid_results: dict[str, Any] = {} - if self.validation_data is not None: - sum_natoms = 0 - for _ii in range(self.valid_numb_batch): - val_input, val_label = self.get_data(is_train=False) - if not val_input: - break - _, _vloss, _vmore = self.wrapper( - **val_input, cur_lr=cur_lr_sched, label=val_label - ) - natoms = int(val_input["atype"].shape[-1]) - sum_natoms += natoms - for k, v in _vmore.items(): - if "l2_" not in k: - valid_results[k] = ( - valid_results.get(k, 0.0) + v * natoms + if self.rank == 0: + if not self.multi_task: + train_results = { + k: v for k, v in 
more_loss.items() if "l2_" not in k + } + + # validation + valid_results: dict[str, Any] = {} + if self.validation_data is not None: + sum_natoms = 0 + for _ii in range(self.valid_numb_batch): + val_input, val_label = self.get_data(is_train=False) + if not val_input: + break + _, _vloss, _vmore = self._unwrapped( + **val_input, + cur_lr=cur_lr_sched, + label=val_label, ) - if sum_natoms > 0: - valid_results = { - k: v / sum_natoms for k, v in valid_results.items() + natoms = int(val_input["atype"].shape[-1]) + sum_natoms += natoms + for k, v in _vmore.items(): + if "l2_" not in k: + valid_results[k] = ( + valid_results.get(k, 0.0) + v * natoms + ) + if sum_natoms > 0: + valid_results = { + k: v / sum_natoms for k, v in valid_results.items() + } + else: + # Multi-task: compute loss for ALL tasks + train_results = {_key: {} for _key in self.model_keys} + valid_results = {_key: {} for _key in self.model_keys} + + # current task already has loss + train_results[task_key] = { + k: v for k, v in more_loss.items() if "l2_" not in k } - # wall-clock time - current_time = time.time() - wall_elapsed = current_time - wall_start - interval_wall_time = current_time - last_log_time - last_log_time = current_time - if self.timing_in_training: - step_time = t_end - t_start - steps_completed_since_restart = max( - 1, - display_step_id - self.start_step, - ) - eta = int( - (self.num_steps - display_step_id) - / steps_completed_since_restart - * wall_elapsed - ) - log.info( - format_training_message( - batch=display_step_id, - wall_time=interval_wall_time, - eta=eta, - current_time=datetime.datetime.fromtimestamp( - current_time, - tz=datetime.timezone.utc, - ).astimezone(), + # compute loss for other tasks + for _key in self.model_keys: + if _key != task_key: + self.optimizer.zero_grad() + _inp, _lab = self.get_data(is_train=True, task_key=_key) + _, _loss, _more = self._unwrapped( + **_inp, + cur_lr=cur_lr_sched, + label=_lab, + task_key=_key, + ) + train_results[_key] = { + k: v 
for k, v in _more.items() if "l2_" not in k + } + + # validation for each task + _vdata = self.validation_data[_key] + if _vdata is not None: + _sum_natoms = 0 + _vres: dict[str, Any] = {} + for _ii in range(self.valid_numb_batch[_key]): + _vi, _vl = self.get_data( + is_train=False, task_key=_key + ) + if not _vi: + break + _, _vloss, _vmore = self._unwrapped( + **_vi, + cur_lr=cur_lr_sched, + label=_vl, + task_key=_key, + ) + natoms = int(_vi["atype"].shape[-1]) + _sum_natoms += natoms + for k, v in _vmore.items(): + if "l2_" not in k: + _vres[k] = _vres.get(k, 0.0) + v * natoms + if _sum_natoms > 0: + _vres = { + k: v / _sum_natoms for k, v in _vres.items() + } + valid_results[_key] = _vres + # wall-clock time + current_time = time.time() + wall_elapsed = current_time - wall_start + interval_wall_time = current_time - last_log_time + last_log_time = current_time + if self.timing_in_training: + step_time = t_end - t_start + steps_completed_since_restart = max( + 1, + display_step_id - self.start_step, ) - ) - log.info("step=%d step_time=%.4fs", display_step_id, step_time) - else: - log.info( - format_training_message( - batch=display_step_id, - wall_time=interval_wall_time, + eta = int( + (self.num_steps - display_step_id) + / steps_completed_since_restart + * wall_elapsed + ) + log.info( + format_training_message( + batch=display_step_id, + wall_time=interval_wall_time, + eta=eta, + current_time=datetime.datetime.fromtimestamp( + current_time, + tz=datetime.timezone.utc, + ).astimezone(), + ) + ) + log.info("step=%d step_time=%.4fs", display_step_id, step_time) + else: + log.info( + format_training_message( + batch=display_step_id, + wall_time=interval_wall_time, + ) ) - ) - # log - log.info( - format_training_message_per_task( - batch=display_step_id, - task_name="trn", - rmse=train_results, - learning_rate=cur_lr, - ) - ) - if valid_results: - log.info( - format_training_message_per_task( - batch=display_step_id, - task_name="val", - rmse=valid_results, - 
learning_rate=None, + # log + if not self.multi_task: + log.info( + format_training_message_per_task( + batch=display_step_id, + task_name="trn", + rmse=train_results, + learning_rate=cur_lr, + ) ) - ) + if valid_results: + log.info( + format_training_message_per_task( + batch=display_step_id, + task_name="val", + rmse=valid_results, + learning_rate=None, + ) + ) + else: + for _key in self.model_keys: + log.info( + format_training_message_per_task( + batch=display_step_id, + task_name=_key + "_trn", + rmse=train_results[_key], + learning_rate=cur_lr, + ) + ) + if valid_results[_key]: + log.info( + format_training_message_per_task( + batch=display_step_id, + task_name=_key + "_val", + rmse=valid_results[_key], + learning_rate=None, + ) + ) - # lcurve file - if self.lcurve_should_print_header: - self.print_header(fout, train_results, valid_results) - self.lcurve_should_print_header = False - self.print_on_training( - fout, display_step_id, cur_lr, train_results, valid_results - ) + # lcurve file + if self.lcurve_should_print_header: + self.print_header(fout, train_results, valid_results) + self.lcurve_should_print_header = False + self.print_on_training( + fout, display_step_id, cur_lr, train_results, valid_results + ) self.wrapper.train() # --- checkpoint --- - if display_step_id % self.save_freq == 0: + if display_step_id % self.save_freq == 0 and self.rank == 0: self.save_checkpoint(display_step_id) # final save - self.save_checkpoint(self.num_steps) + if self.rank == 0: + self.save_checkpoint(self.num_steps) wall_total = time.time() - wall_start - fout.close() + if fout is not None: + fout.close() log.info("Training finished. 
Total wall time: %.2fs", wall_total) # ------------------------------------------------------------------ @@ -1000,14 +1301,23 @@ def print_header( train_results: dict[str, Any], valid_results: dict[str, Any], ) -> None: - train_keys = sorted(train_results.keys()) header = "# {:5s}".format("step") - if valid_results: - for k in train_keys: - header += f" {k + '_val':>11s} {k + '_trn':>11s}" + if not self.multi_task: + train_keys = sorted(train_results.keys()) + if valid_results: + for k in train_keys: + header += f" {k + '_val':>11s} {k + '_trn':>11s}" + else: + for k in train_keys: + header += f" {k + '_trn':>11s}" else: - for k in train_keys: - header += f" {k + '_trn':>11s}" + for model_key in self.model_keys: + if valid_results[model_key]: + for k in sorted(train_results[model_key].keys()): + header += f" {k + '_val_' + model_key:>11s} {k + '_trn_' + model_key:>11s}" + else: + for k in sorted(train_results[model_key].keys()): + header += f" {k + '_trn_' + model_key:>11s}" header += " {:8s}\n".format("lr") fout.write(header) fout.flush() @@ -1020,14 +1330,23 @@ def print_on_training( train_results: dict, valid_results: dict, ) -> None: - train_keys = sorted(train_results.keys()) line = f"{step_id:7d}" - if valid_results: - for k in train_keys: - line += f" {valid_results.get(k, float('nan')):11.2e} {train_results[k]:11.2e}" + if not self.multi_task: + train_keys = sorted(train_results.keys()) + if valid_results: + for k in train_keys: + line += f" {valid_results.get(k, float('nan')):11.2e} {train_results[k]:11.2e}" + else: + for k in train_keys: + line += f" {train_results[k]:11.2e}" else: - for k in train_keys: - line += f" {train_results[k]:11.2e}" + for model_key in self.model_keys: + if valid_results[model_key]: + for k in sorted(valid_results[model_key].keys()): + line += f" {valid_results[model_key][k]:11.2e} {train_results[model_key][k]:11.2e}" + else: + for k in sorted(train_results[model_key].keys()): + line += f" {train_results[model_key][k]:11.2e}" 
line += f" {cur_lr:8.1e}\n" fout.write(line) fout.flush() @@ -1074,3 +1393,40 @@ def model_change_out_bias( f"to {to_numpy_array(new_bias).reshape(-1)[: len(model_type_map)]!s}." ) return _model + + +def _get_case_embd_config( + model_params: dict[str, Any], +) -> tuple[bool, dict[str, int]]: + """Check whether case embedding is enabled and build the index map. + + Parameters + ---------- + model_params : dict + Model parameters containing ``model_dict``. + + Returns + ------- + do_case_embd : bool + Whether case embedding is enabled. + case_embd_index : dict + Mapping from model key to case index (sorted alphabetically). + """ + assert "model_dict" in model_params, ( + "Only support setting case embedding for multi-task model!" + ) + model_keys = list(model_params["model_dict"]) + sorted_model_keys = sorted(model_keys) + numb_case_embd_list = [ + model_params["model_dict"][mk].get("fitting_net", {}).get("dim_case_embd", 0) + for mk in sorted_model_keys + ] + if not all(item == numb_case_embd_list[0] for item in numb_case_embd_list): + raise ValueError( + "All models must have the same dimension of case embedding, " + f"while the settings are: {numb_case_embd_list}" + ) + if numb_case_embd_list[0] == 0: + return False, {} + case_embd_index = {mk: idx for idx, mk in enumerate(sorted_model_keys)} + return True, case_embd_index diff --git a/deepmd/pt_expt/train/wrapper.py b/deepmd/pt_expt/train/wrapper.py index 281168cdba..f67efe8a8e 100644 --- a/deepmd/pt_expt/train/wrapper.py +++ b/deepmd/pt_expt/train/wrapper.py @@ -10,24 +10,24 @@ class ModelWrapper(torch.nn.Module): - """Simplified model wrapper that bundles a model and a loss. + """Model wrapper that bundles model(s) and loss(es). - Single-task only for now (no multi-task support). + Supports both single-task and multi-task training. Parameters ---------- - model : torch.nn.Module - The model to train. - loss : torch.nn.Module - The loss module. 
+ model : torch.nn.Module or dict + Single model or dict of models keyed by task name. + loss : torch.nn.Module or dict or None + Single loss or dict of losses keyed by task name. model_params : dict, optional Model parameters to store as extra state. """ def __init__( self, - model: torch.nn.Module, - loss: torch.nn.Module | None = None, + model: torch.nn.Module | dict, + loss: torch.nn.Module | dict | None = None, model_params: dict[str, Any] | None = None, ) -> None: super().__init__() @@ -36,10 +36,133 @@ def __init__( "lr": 0, "step": 0, } - self.model = model - self.loss = loss + self.multi_task = False + self.model = torch.nn.ModuleDict() + # Model + if isinstance(model, torch.nn.Module): + self.model["Default"] = model + elif isinstance(model, dict): + self.multi_task = True + for task_key in model: + assert isinstance(model[task_key], torch.nn.Module), ( + f"{task_key} in model_dict is not a torch.nn.Module!" + ) + self.model[task_key] = model[task_key] + # Loss — dpmodel losses are not nn.Module, so store in a plain dict. + self.loss: dict[str, Any] | None = None + if loss is not None: + if isinstance(loss, dict): + self.loss = dict(loss) + else: + self.loss = {"Default": loss} self.inference_only = self.loss is None + def share_params( + self, + shared_links: dict[str, Any], + model_key_prob_map: dict, + data_stat_protect: float = 1e-2, + resume: bool = False, + ) -> None: + """Share parameters between models following rules in shared_links. + + Parameters + ---------- + shared_links : dict + Sharing rules from ``preprocess_shared_params``. + model_key_prob_map : dict + Probability map for each model key (for fitting_net stat weighting). + data_stat_protect : float + Protection value for standard deviation computation. + resume : bool + Whether resuming from checkpoint. 
+ """ + for shared_item in shared_links: + shared_base = shared_links[shared_item]["links"][0] + class_type_base = shared_base["shared_type"] + model_key_base = shared_base["model_key"] + shared_level_base = shared_base["shared_level"] + if "descriptor" in class_type_base: + if class_type_base == "descriptor": + base_class = self.model[model_key_base].get_descriptor() + elif "hybrid" in class_type_base: + hybrid_index = int(class_type_base.split("_")[-1]) + base_class = ( + self.model[model_key_base] + .get_descriptor() + .descrpt_list[hybrid_index] + ) + else: + raise RuntimeError(f"Unknown class_type {class_type_base}!") + for link_item in shared_links[shared_item]["links"][1:]: + class_type_link = link_item["shared_type"] + model_key_link = link_item["model_key"] + shared_level_link = int(link_item["shared_level"]) + assert shared_level_link >= shared_level_base, ( + "The shared_links must be sorted by shared_level!" + ) + assert "descriptor" in class_type_link, ( + f"Class type mismatched: {class_type_base} vs {class_type_link}!" + ) + if class_type_link == "descriptor": + link_class = self.model[model_key_link].get_descriptor() + elif "hybrid" in class_type_link: + hybrid_index = int(class_type_link.split("_")[-1]) + link_class = ( + self.model[model_key_link] + .get_descriptor() + .descrpt_list[hybrid_index] + ) + else: + raise RuntimeError(f"Unknown class_type {class_type_link}!") + frac_prob = ( + model_key_prob_map[model_key_link] + / model_key_prob_map[model_key_base] + ) + link_class.share_params( + base_class, + shared_level_link, + model_prob=frac_prob, + resume=resume, + ) + log.warning( + f"Shared params of {model_key_base}.{class_type_base} " + f"and {model_key_link}.{class_type_link}!" 
+ ) + else: + if hasattr(self.model[model_key_base].atomic_model, class_type_base): + base_class = self.model[model_key_base].atomic_model.__getattr__( + class_type_base + ) + for link_item in shared_links[shared_item]["links"][1:]: + class_type_link = link_item["shared_type"] + model_key_link = link_item["model_key"] + shared_level_link = int(link_item["shared_level"]) + assert shared_level_link >= shared_level_base, ( + "The shared_links must be sorted by shared_level!" + ) + assert class_type_base == class_type_link, ( + f"Class type mismatched: {class_type_base} vs {class_type_link}!" + ) + link_class = self.model[ + model_key_link + ].atomic_model.__getattr__(class_type_link) + frac_prob = ( + model_key_prob_map[model_key_link] + / model_key_prob_map[model_key_base] + ) + link_class.share_params( + base_class, + shared_level_link, + model_prob=frac_prob, + protection=data_stat_protect, + resume=resume, + ) + log.warning( + f"Shared params of {model_key_base}.{class_type_base} " + f"and {model_key_link}.{class_type_link}!" + ) + def forward( self, coord: torch.Tensor, @@ -49,8 +172,16 @@ def forward( aparam: torch.Tensor | None = None, cur_lr: float | torch.Tensor | None = None, label: dict[str, torch.Tensor] | None = None, + task_key: str | None = None, do_atomic_virial: bool = False, ) -> tuple[dict[str, torch.Tensor], torch.Tensor | None, dict | None]: + if not self.multi_task: + task_key = "Default" + else: + assert task_key is not None, ( + f"Multitask model must specify the inference task! " + f"Supported tasks are {list(self.model.keys())}." 
+ ) input_dict = { "coord": coord, "atype": atype, @@ -60,13 +191,13 @@ def forward( "aparam": aparam, } - model_pred = self.model(**input_dict) + model_pred = self.model[task_key](**input_dict) if self.inference_only or label is None: return model_pred, None, None else: natoms = atype.shape[-1] - loss, more_loss = self.loss( + loss, more_loss = self.loss[task_key]( cur_lr, natoms, model_pred, diff --git a/deepmd/pt_expt/utils/finetune.py b/deepmd/pt_expt/utils/finetune.py index 5e49d8738b..473bb43710 100644 --- a/deepmd/pt_expt/utils/finetune.py +++ b/deepmd/pt_expt/utils/finetune.py @@ -47,7 +47,7 @@ def get_finetune_rules( model_branch: str = "", change_model_params: bool = True, ) -> tuple[dict[str, Any], dict[str, FinetuneRuleItem]]: - """Get fine-tuning rules for a single-task pt_expt model. + """Get fine-tuning rules for a single-task or multi-task pt_expt model. Loads a pretrained ``.pt`` checkpoint or ``.pte`` frozen model and builds ``FinetuneRuleItem`` objects describing how to map types and @@ -70,29 +70,83 @@ def get_finetune_rules( model_config : dict Possibly updated model config. finetune_links : dict[str, FinetuneRuleItem] - Fine-tuning rules keyed by ``"Default"``. + Fine-tuning rules keyed by model branch name (``"Default"`` for + single-task, or per-branch keys for multi-task). """ last_model_params = _load_model_params(finetune_model) if change_model_params and "descriptor" not in last_model_params: - raise ValueError( - "Cannot use --use-pretrain-script: the pretrained model does not " - "contain full model params. If finetuning from a .pte file, " - "re-freeze it with the latest code so that model_def_script is embedded." - ) - + # For multi-task pretrained, check inside model_dict + if "model_dict" not in last_model_params or "descriptor" not in next( + iter(last_model_params["model_dict"].values()) + ): + raise ValueError( + "Cannot use --use-pretrain-script: the pretrained model does not " + "contain full model params. 
If finetuning from a .pte file, " + "re-freeze it with the latest code so that model_def_script is embedded." + ) + + multi_task = "model_dict" in model_config finetune_from_multi_task = "model_dict" in last_model_params - - # pt_expt is single-task only - if model_branch == "" and "finetune_head" in model_config: - model_branch = model_config["finetune_head"] - model_config, finetune_rule = get_finetune_rule_single( - model_config, - last_model_params, - from_multitask=finetune_from_multi_task, - model_branch="Default", - model_branch_from=model_branch, - change_model_params=change_model_params, - ) - finetune_links: dict[str, FinetuneRuleItem] = {"Default": finetune_rule} + finetune_links: dict[str, FinetuneRuleItem] = {} + + if not multi_task: + # Single-task target + if model_branch == "" and "finetune_head" in model_config: + model_branch = model_config["finetune_head"] + model_config, finetune_rule = get_finetune_rule_single( + model_config, + last_model_params, + from_multitask=finetune_from_multi_task, + model_branch="Default", + model_branch_from=model_branch, + change_model_params=change_model_params, + ) + finetune_links["Default"] = finetune_rule + else: + # Multi-task target — mirrors PT's logic + if model_branch != "": + raise ValueError( + "Multi-task fine-tuning does not support command-line branches chosen! " + "Please define the 'finetune_head' in each model params!" + ) + if not finetune_from_multi_task: + pretrained_keys = ["Default"] + else: + pretrained_keys = list(last_model_params["model_dict"].keys()) + for model_key in model_config["model_dict"]: + resuming = False + if ( + "finetune_head" in model_config["model_dict"][model_key] + and model_config["model_dict"][model_key]["finetune_head"] != "RANDOM" + ): + pretrained_key = model_config["model_dict"][model_key]["finetune_head"] + if pretrained_key not in pretrained_keys: + raise ValueError( + f"'{pretrained_key}' head chosen to finetune not exist in the pretrained model! 
" + f"Available heads are: {list(pretrained_keys)}" + ) + model_branch_from = pretrained_key + elif ( + "finetune_head" not in model_config["model_dict"][model_key] + and model_key in pretrained_keys + ): + # resume — no finetune + model_branch_from = model_key + resuming = True + else: + # new branch or RANDOM → random fitting + model_branch_from = "RANDOM" + model_config["model_dict"][model_key], finetune_rule = ( + get_finetune_rule_single( + model_config["model_dict"][model_key], + last_model_params, + from_multitask=finetune_from_multi_task, + model_branch=model_key, + model_branch_from=model_branch_from, + change_model_params=change_model_params, + ) + ) + finetune_links[model_key] = finetune_rule + finetune_links[model_key].resuming = resuming return model_config, finetune_links diff --git a/deepmd/pt_expt/utils/multi_task.py b/deepmd/pt_expt/utils/multi_task.py new file mode 100644 index 0000000000..a4600d5ebb --- /dev/null +++ b/deepmd/pt_expt/utils/multi_task.py @@ -0,0 +1,116 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from copy import ( + deepcopy, +) +from typing import ( + Any, +) + +from deepmd.pt_expt.descriptor.base_descriptor import ( + BaseDescriptor, +) +from deepmd.pt_expt.fitting import ( + BaseFitting, +) + + +def preprocess_shared_params( + model_config: dict[str, Any], +) -> tuple[dict[str, Any], dict[str, Any]]: + """Preprocess the model params for multitask model, and generate the links dict for further sharing. + + Args: + model_config: Model params of multitask model. + + Returns + ------- + model_config: Preprocessed model params of multitask model. + Those string names are replaced with real params in `shared_dict` of model params. + shared_links: Dict of link infos for further sharing. + Each item, whose key must be in `shared_dict`, is a dict with following keys: + - "type": The real class type of this item. 
+ - "links": List of shared settings, each sub-item is a dict with following keys: + - "model_key": Model key in the `model_dict` to share this item. + - "shared_type": Type of this shard item. + - "shared_level": Shared level (int) of this item in this model. + Lower for more params to share, 0 means to share all params in this item. + This list are sorted by "shared_level". + """ + assert "model_dict" in model_config, "only multi-task model can use this method!" + supported_types = ["type_map", "descriptor", "fitting_net"] + shared_dict = model_config.get("shared_dict", {}) + shared_links = {} + type_map_keys = [] + + def replace_one_item( + params_dict: dict[str, Any], + key_type: str, + key_in_dict: str, + suffix: str = "", + index: int | None = None, + ) -> None: + shared_type = key_type + shared_key = key_in_dict + shared_level = 0 + if ":" in key_in_dict: + shared_key = key_in_dict.split(":")[0] + shared_level = int(key_in_dict.split(":")[1]) + assert shared_key in shared_dict, ( + f"Appointed {shared_type} {shared_key} are not in the shared_dict! Please check the input params." 
+ ) + if index is None: + params_dict[shared_type] = deepcopy(shared_dict[shared_key]) + else: + params_dict[index] = deepcopy(shared_dict[shared_key]) + if shared_type == "type_map": + if key_in_dict not in type_map_keys: + type_map_keys.append(key_in_dict) + else: + if shared_key not in shared_links: + class_name = get_class_name(shared_type, shared_dict[shared_key]) + shared_links[shared_key] = {"type": class_name, "links": []} + link_item = { + "model_key": model_key, + "shared_type": shared_type + suffix, + "shared_level": shared_level, + } + shared_links[shared_key]["links"].append(link_item) + + for model_key in model_config["model_dict"]: + model_params_item = model_config["model_dict"][model_key] + for item_key in model_params_item: + if item_key in supported_types: + item_params = model_params_item[item_key] + if isinstance(item_params, str): + replace_one_item(model_params_item, item_key, item_params) + elif item_params.get("type", "") == "hybrid": + for ii, hybrid_item in enumerate(item_params["list"]): + if isinstance(hybrid_item, str): + replace_one_item( + model_params_item[item_key]["list"], + item_key, + hybrid_item, + suffix=f"_hybrid_{ii}", + index=ii, + ) + for shared_key in shared_links: + shared_links[shared_key]["links"] = sorted( + shared_links[shared_key]["links"], + key=lambda x: ( + x["shared_level"] + - ("spin" in model_config["model_dict"][x["model_key"]]) * 100 + ), + ) + # little trick to make spin models in the front to be the base models, + # because its type embeddings are more general. + assert len(type_map_keys) == 1, "Multitask model must have only one type_map!" 
+ return model_config, shared_links + + +def get_class_name(item_key: str, item_params: dict[str, Any]) -> type: + if item_key == "descriptor": + return BaseDescriptor.get_class_by_type(item_params.get("type", "se_e2_a")) + elif item_key == "fitting_net": + return BaseFitting.get_class_by_type(item_params.get("type", "ener")) + else: + raise RuntimeError(f"Unknown class_name type {item_key}") diff --git a/deepmd/pt_expt/utils/network.py b/deepmd/pt_expt/utils/network.py index 1629ecb83a..adef443de9 100644 --- a/deepmd/pt_expt/utils/network.py +++ b/deepmd/pt_expt/utils/network.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import math from typing import ( Any, ClassVar, @@ -182,6 +183,14 @@ def _torch_activation(x: torch.Tensor, name: str) -> torch.Tensor: return torch.sigmoid(x) elif name == "silu": return torch.nn.functional.silu(x) + elif name.startswith("silut") or name.startswith("custom_silu"): + threshold = float(name.split(":")[-1]) if ":" in name else 3.0 + sig_t = 1.0 / (1.0 + math.exp(-threshold)) + slope = sig_t + threshold * sig_t * (1.0 - sig_t) + const = threshold * sig_t + silu = x * torch.sigmoid(x) + tanh_branch = torch.tanh(slope * (x - threshold)) + const + return torch.where(x < threshold, silu, tanh_branch) elif name in ("none", "linear"): return x else: diff --git a/source/tests/common/dpmodel/test_descriptor_dpa2.py b/source/tests/common/dpmodel/test_descriptor_dpa2.py index 7867fee874..af58d12790 100644 --- a/source/tests/common/dpmodel/test_descriptor_dpa2.py +++ b/source/tests/common/dpmodel/test_descriptor_dpa2.py @@ -10,6 +10,9 @@ RepformerArgs, RepinitArgs, ) +from deepmd.dpmodel.descriptor.repformers import ( + DescrptBlockRepformers, +) from ...seed import ( GLOBAL_SEED, @@ -69,3 +72,36 @@ def test_self_consistency( for ii in [0, 1, 2, 3, 4]: np.testing.assert_equal(mm0[ii].shape, desired_shape[ii]) np.testing.assert_allclose(mm0[ii], mm1[ii]) + + +class TestDescrptBlockRepformersAccessors(unittest.TestCase): + def 
test_get_rcut_smth(self) -> None: + block = DescrptBlockRepformers( + rcut=6.0, + rcut_smth=5.0, + sel=40, + ntypes=2, + nlayers=3, + ) + self.assertEqual(block.get_rcut_smth(), 5.0) + + def test_get_env_protection(self) -> None: + block = DescrptBlockRepformers( + rcut=6.0, + rcut_smth=5.0, + sel=40, + ntypes=2, + nlayers=3, + env_protection=1.0, + ) + self.assertEqual(block.get_env_protection(), 1.0) + + def test_get_env_protection_default(self) -> None: + block = DescrptBlockRepformers( + rcut=6.0, + rcut_smth=5.0, + sel=40, + ntypes=2, + nlayers=3, + ) + self.assertEqual(block.get_env_protection(), 0.0) diff --git a/source/tests/consistent/test_activation.py b/source/tests/consistent/test_activation.py index 31351d4a9d..b46319e338 100644 --- a/source/tests/consistent/test_activation.py +++ b/source/tests/consistent/test_activation.py @@ -19,6 +19,7 @@ INSTALLED_JAX, INSTALLED_PD, INSTALLED_PT, + INSTALLED_PT_EXPT, INSTALLED_TF, parameterized, ) @@ -29,6 +30,13 @@ from deepmd.pt.utils.utils import ( to_torch_tensor, ) +if INSTALLED_PT_EXPT: + import torch + + from deepmd.pt_expt.utils.env import DEVICE as PT_EXPT_DEVICE + from deepmd.pt_expt.utils.network import ( + _torch_activation, + ) if INSTALLED_TF: from deepmd.tf.common import get_activation_func as get_activation_fn_tf from deepmd.tf.env import ( @@ -98,3 +106,54 @@ def test_pd_consistent_with_ref(self): ActivationFn_pd(self.activation)(to_paddle_tensor(self.random_input)) ) np.testing.assert_allclose(self.ref, test, atol=1e-10) + + @unittest.skipUnless(INSTALLED_PT_EXPT, "PyTorch Exportable is not installed") + def test_pt_expt_consistent_with_ref(self) -> None: + if INSTALLED_PT_EXPT: + x = torch.tensor( + self.random_input, dtype=torch.float64, device=PT_EXPT_DEVICE + ) + test = _torch_activation(x, self.activation).detach().numpy() + np.testing.assert_allclose(self.ref, test, atol=1e-10) + + +@parameterized( + ( + "silut", # default threshold 3.0 + "silut:3.0", # explicit threshold 3.0 + 
"silut:10.0", # large threshold + "custom_silu:5.0", # alias + ), +) +class TestSilutVariantsConsistent(unittest.TestCase): + """Cross-backend consistency for silut with different thresholds.""" + + def setUp(self) -> None: + (self.activation,) = self.param + # Parse threshold to build input that covers both branches + threshold = ( + float(self.activation.split(":")[-1]) if ":" in self.activation else 3.0 + ) + rng = np.random.default_rng(GLOBAL_SEED) + # Values below threshold (silu branch) and above threshold (tanh branch) + below = rng.uniform(-threshold - 5, threshold - 0.1, size=(5, 10)) + above = rng.uniform(threshold + 0.1, threshold + 20, size=(5, 10)) + self.random_input = np.concatenate([below, above], axis=0) + self.ref = get_activation_fn_dp(self.activation)(self.random_input) + + @unittest.skipUnless(INSTALLED_PT, "PyTorch is not installed") + def test_pt_consistent_with_ref(self) -> None: + if INSTALLED_PT: + test = torch_to_numpy( + ActivationFn_pt(self.activation)(to_torch_tensor(self.random_input)) + ) + np.testing.assert_allclose(self.ref, test, atol=1e-10) + + @unittest.skipUnless(INSTALLED_PT_EXPT, "PyTorch Exportable is not installed") + def test_pt_expt_consistent_with_ref(self) -> None: + if INSTALLED_PT_EXPT: + x = torch.tensor( + self.random_input, dtype=torch.float64, device=PT_EXPT_DEVICE + ) + test = _torch_activation(x, self.activation).detach().numpy() + np.testing.assert_allclose(self.ref, test, atol=1e-10) diff --git a/source/tests/pt/test_fitting_stat.py b/source/tests/pt/test_fitting_stat.py index 7807523221..80d213bcad 100644 --- a/source/tests/pt/test_fitting_stat.py +++ b/source/tests/pt/test_fitting_stat.py @@ -280,7 +280,7 @@ def test_sharefitting_with_fparam(self): self.config["training"]["data_dict"]["model_2"]["validation_data"][ "systems" ] = self.data_file_single - self.config["model"]["model_dict"]["model_1"]["data_stat_nbatch"] = 100 + self.config["model"]["model_dict"]["model_1"]["data_stat_nbatch"] = 80 
self.config["model"], self.shared_links = preprocess_shared_params( self.config["model"] @@ -391,7 +391,7 @@ def test_sharefitting_using_default_fparam(self): ] = self.data_file data_stat_protect = 5e-3 self.config["model"]["model_dict"]["model_1"]["data_stat_nbatch"] = 3 - self.config["model"]["model_dict"]["model_3"]["data_stat_nbatch"] = 100 + self.config["model"]["model_dict"]["model_3"]["data_stat_nbatch"] = 80 self.config["model"]["model_dict"]["model_1"]["data_stat_protect"] = ( data_stat_protect ) diff --git a/source/tests/pt_expt/descriptor/test_descrpt_stat_merge.py b/source/tests/pt_expt/descriptor/test_descrpt_stat_merge.py new file mode 100644 index 0000000000..a8f420b2db --- /dev/null +++ b/source/tests/pt_expt/descriptor/test_descrpt_stat_merge.py @@ -0,0 +1,1328 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Tests for probability-weighted stat merging in descriptor share_params.""" + +from typing import ( + ClassVar, +) + +import numpy as np +import pytest +import torch + +from deepmd.dpmodel.descriptor.dpa2 import ( + RepformerArgs, + RepinitArgs, +) +from deepmd.dpmodel.descriptor.dpa3 import ( + RepFlowArgs, +) +from deepmd.dpmodel.utils.env_mat_stat import ( + EnvMatStatSe, + merge_env_stat, +) +from deepmd.pt_expt.descriptor.dpa1 import ( + DescrptDPA1, +) +from deepmd.pt_expt.descriptor.dpa2 import ( + DescrptDPA2, +) +from deepmd.pt_expt.descriptor.dpa3 import ( + DescrptDPA3, +) +from deepmd.pt_expt.descriptor.hybrid import ( + DescrptHybrid, +) +from deepmd.pt_expt.descriptor.se_e2_a import ( + DescrptSeA, +) +from deepmd.pt_expt.descriptor.se_r import ( + DescrptSeR, +) +from deepmd.pt_expt.descriptor.se_t import ( + DescrptSeT, +) +from deepmd.pt_expt.descriptor.se_t_tebd import ( + DescrptSeTTebd, +) +from deepmd.pt_expt.utils import ( + env, +) +from deepmd.utils.env_mat_stat import ( + StatItem, +) + +from ...seed import ( + GLOBAL_SEED, +) + + +def _make_stats(ntypes: int, last_dim: int, rng: np.random.Generator) -> dict: + 
"""Create synthetic StatItem stats for an env mat descriptor. + + The stats dict has keys "r_{i}" and optionally "a_{i}" for each type, + matching the EnvMatStatSe convention. + """ + stats = {} + for ti in range(ntypes): + # Use moderate values to avoid zero-division + n = rng.uniform(100, 500) + s = rng.uniform(-10, 10) + sq = s**2 / n + rng.uniform(0.01, 1.0) # ensure variance > 0 + stats[f"r_{ti}"] = StatItem(number=n, sum=s, squared_sum=sq * n) + if last_dim == 4: + n_a = rng.uniform(100, 500) + s_a = rng.uniform(-10, 10) + sq_a = s_a**2 / n_a + rng.uniform(0.01, 1.0) + stats[f"a_{ti}"] = StatItem(number=n_a, sum=s_a, squared_sum=sq_a * n_a) + return stats + + +def _compute_expected_buffers(descriptor, merged_stats, last_dim): + """Compute expected mean/stddev from merged stats using EnvMatStatSe.""" + env_stat = EnvMatStatSe(descriptor) + env_stat.stats = merged_stats + mean, stddev = env_stat() + return mean, stddev + + +def _merge_stats(base_stats, link_stats, model_prob): + """Manually merge stats dicts.""" + merged = {} + for kk in base_stats: + merged[kk] = base_stats[kk] + link_stats[kk] * model_prob + return merged + + +class TestStatMergeSeA: + """Test stat merging for se_e2_a descriptor.""" + + rcut = 2.2 + rcut_smth = 0.4 + sel: ClassVar = [5, 2] + + def setup_method(self) -> None: + self.device = env.DEVICE + self.rng = np.random.default_rng(GLOBAL_SEED) + self.ntypes = 2 + self.nnei = sum(self.sel) + self.last_dim = 4 + + def _make_descriptor(self, seed): + return DescrptSeA(self.rcut, self.rcut_smth, self.sel, seed=seed).to( + self.device + ) + + @pytest.mark.parametrize("model_prob", [0.6, 1.0, 0.1]) # probability weight + def test_stat_merge(self, model_prob) -> None: + """Verify merged davg/dstd match manually computed values.""" + dd_base = self._make_descriptor(GLOBAL_SEED) + dd_link = self._make_descriptor(GLOBAL_SEED + 1) + + base_stats = _make_stats(self.ntypes, self.last_dim, self.rng) + link_stats = _make_stats(self.ntypes, 
def test_buffers_aliased(self) -> None:
    """After share_params, every link buffer must be the base's buffer object."""
    base = self._make_descriptor(GLOBAL_SEED)
    link = self._make_descriptor(GLOBAL_SEED + 1)

    link.share_params(base, shared_level=0, model_prob=0.5, resume=False)

    # Identity (not equality) is checked: the two descriptors must share storage.
    for key, base_buffer in base._buffers.items():
        assert link._buffers[key] is base_buffer, f"Buffer {key} not aliased"
def test_none_stats_skips_merge(self) -> None:
    """When stats is None, merging should be silently skipped.

    Both ``davg`` and ``dstd`` of the base descriptor must keep the values
    assigned before ``share_params`` is called.
    """
    dd_base = self._make_descriptor(GLOBAL_SEED)
    dd_link = self._make_descriptor(GLOBAL_SEED + 1)

    # stats is not set (default None)
    assert getattr(dd_base, "stats", None) is None

    shape = (self.ntypes, self.nnei, self.last_dim)
    davg0 = self.rng.normal(size=shape)
    dstd0 = 0.1 + np.abs(self.rng.normal(size=shape))
    dd_base.davg = torch.tensor(davg0, dtype=torch.float64, device=self.device)
    dd_base.dstd = torch.tensor(dstd0, dtype=torch.float64, device=self.device)
    original_davg = dd_base.davg.clone()
    original_dstd = dd_base.dstd.clone()

    # Should not raise
    dd_link.share_params(dd_base, shared_level=0, model_prob=0.6, resume=False)

    # Neither buffer may change (merge was skipped); dstd is checked as well
    # so a partial merge cannot go unnoticed.
    np.testing.assert_allclose(
        dd_base.davg.detach().cpu().numpy(),
        original_davg.detach().cpu().numpy(),
    )
    np.testing.assert_allclose(
        dd_base.dstd.detach().cpu().numpy(),
        original_dstd.detach().cpu().numpy(),
    )
@pytest.mark.parametrize("model_prob", [0.6, 1.0])  # probability weight
def test_stat_merge(self, model_prob) -> None:
    """Merged davg/dstd must equal the values computed by hand from the stats."""
    base = self._make_descriptor(GLOBAL_SEED)
    link = self._make_descriptor(GLOBAL_SEED + 1)

    # Base stats are drawn before link stats to keep the RNG stream order.
    stats_base = _make_stats(self.ntypes, self.last_dim, self.rng)
    stats_link = _make_stats(self.ntypes, self.last_dim, self.rng)
    base.stats = stats_base
    link.stats = stats_link

    # Arbitrary starting buffers on the base descriptor.
    shape = (self.ntypes, self.nnei, self.last_dim)
    init_avg = self.rng.normal(size=shape)
    init_std = 0.1 + np.abs(self.rng.normal(size=shape))
    base.davg = torch.tensor(init_avg, dtype=torch.float64, device=self.device)
    base.dstd = torch.tensor(init_std, dtype=torch.float64, device=self.device)

    merged = _merge_stats(stats_base, stats_link, model_prob)
    want_mean, want_std = _compute_expected_buffers(base, merged, self.last_dim)

    link.share_params(base, shared_level=0, model_prob=model_prob, resume=False)

    for got, want in ((base.davg, want_mean), (base.dstd, want_std)):
        np.testing.assert_allclose(got.detach().cpu().numpy(), want, rtol=1e-10)
def _make_descriptor(self, seed):
    """Create a DescrptSeT with the class-level cutoffs/sel on the test device."""
    descriptor = DescrptSeT(self.rcut, self.rcut_smth, self.sel, seed=seed)
    return descriptor.to(self.device)
def test_resume_skips_merge(self) -> None:
    """With resume=True the base davg must be left exactly as it was set."""
    base = self._make_descriptor(GLOBAL_SEED)
    link = self._make_descriptor(GLOBAL_SEED + 1)

    # Base stats are drawn before link stats to keep the RNG stream order.
    stats_base = _make_stats(self.ntypes, self.last_dim, self.rng)
    stats_link = _make_stats(self.ntypes, self.last_dim, self.rng)
    base.stats = stats_base
    link.stats = stats_link

    shape = (self.ntypes, self.nnei, self.last_dim)
    init_avg = self.rng.normal(size=shape)
    init_std = 0.1 + np.abs(self.rng.normal(size=shape))
    base.davg = torch.tensor(init_avg, dtype=torch.float64, device=self.device)
    base.dstd = torch.tensor(init_std, dtype=torch.float64, device=self.device)
    davg_before = base.davg.clone()

    link.share_params(base, shared_level=0, model_prob=0.6, resume=True)

    np.testing.assert_allclose(
        base.davg.detach().cpu().numpy(),
        davg_before.detach().cpu().numpy(),
    )
def test_resume_skips_merge(self) -> None:
    """With resume=True the se_atten mean of the base must stay unchanged."""
    base = self._make_descriptor(GLOBAL_SEED)
    link = self._make_descriptor(GLOBAL_SEED + 1)

    # Base stats are drawn before link stats to keep the RNG stream order.
    stats_base = _make_stats(self.ntypes, self.last_dim, self.rng)
    stats_link = _make_stats(self.ntypes, self.last_dim, self.rng)
    base.se_atten.stats = stats_base
    link.se_atten.stats = stats_link

    shape = (self.ntypes, self.nnei, self.last_dim)
    init_mean = self.rng.normal(size=shape)
    init_std = 0.1 + np.abs(self.rng.normal(size=shape))
    tensor_kwargs = {"dtype": torch.float64, "device": self.device}
    base.se_atten.mean = torch.tensor(init_mean, **tensor_kwargs)
    base.se_atten.stddev = torch.tensor(init_std, **tensor_kwargs)
    mean_before = base.se_atten.mean.clone()

    link.share_params(base, shared_level=0, model_prob=0.6, resume=True)

    np.testing.assert_allclose(
        base.se_atten.mean.detach().cpu().numpy(),
        mean_before.detach().cpu().numpy(),
    )
def setup_method(self) -> None:
    """Prepare device, seeded RNG and size constants for each test."""
    self.ntypes = 2
    self.last_dim = 4
    self.nnei = sum(self.sel_mix)
    self.device = env.DEVICE
    # Class-specific seed offset (350).
    self.rng = np.random.default_rng(GLOBAL_SEED + 350)
@pytest.mark.parametrize("model_prob", [0.6, 1.0])  # probability weight
def test_stat_merge_repinit(self, model_prob) -> None:
    """Merged mean/stddev of the repinit block must match the expected values."""
    base = self._make_descriptor(GLOBAL_SEED)
    link = self._make_descriptor(GLOBAL_SEED + 1)

    # Base stats are drawn before link stats to keep the RNG stream order.
    stats_base = _make_stats(self.ntypes, self.last_dim, self.rng)
    stats_link = _make_stats(self.ntypes, self.last_dim, self.rng)
    base.repinit.stats = stats_base
    link.repinit.stats = stats_link

    # Arbitrary starting buffers on the base repinit block.
    shape = (self.ntypes, self.nnei, self.last_dim)
    init_mean = self.rng.normal(size=shape)
    init_std = 0.1 + np.abs(self.rng.normal(size=shape))
    tensor_kwargs = {"dtype": torch.float64, "device": self.device}
    base.repinit.mean = torch.tensor(init_mean, **tensor_kwargs)
    base.repinit.stddev = torch.tensor(init_std, **tensor_kwargs)

    merged = _merge_stats(stats_base, stats_link, model_prob)
    want_mean, want_std = _compute_expected_buffers(
        base.repinit, merged, self.last_dim
    )

    link.share_params(base, shared_level=0, model_prob=model_prob, resume=False)

    np.testing.assert_allclose(
        base.repinit.mean.detach().cpu().numpy(), want_mean, rtol=1e-10
    )
    np.testing.assert_allclose(
        base.repinit.stddev.detach().cpu().numpy(), want_std, rtol=1e-10
    )
def test_resume_skips_merge(self) -> None:
    """With resume=True no merge may happen; checked on the repinit mean."""
    base = self._make_descriptor(GLOBAL_SEED)
    link = self._make_descriptor(GLOBAL_SEED + 1)

    # Base stats are drawn before link stats to keep the RNG stream order.
    repinit_base_stats = _make_stats(self.ntypes, self.last_dim, self.rng)
    repinit_link_stats = _make_stats(self.ntypes, self.last_dim, self.rng)
    base.repinit.stats = repinit_base_stats
    link.repinit.stats = repinit_link_stats

    init_mean = self.rng.normal(size=(self.ntypes, self.nnei, self.last_dim))
    base.repinit.mean = torch.tensor(
        init_mean, dtype=torch.float64, device=self.device
    )
    mean_before = base.repinit.mean.clone()

    link.share_params(base, shared_level=0, model_prob=0.6, resume=True)

    np.testing.assert_allclose(
        base.repinit.mean.detach().cpu().numpy(),
        mean_before.detach().cpu().numpy(),
    )
def _make_descriptor(self, seed, fix_stat_std=0.0):
    """Create a DescrptDPA3 on the test device.

    When ``fix_stat_std`` is 0.0, ``set_davg_zero`` on the repflows block is
    switched off after construction so the tests can exercise the mean.
    """
    repflow_kwargs = {
        "n_dim": 20,
        "e_dim": 10,
        "a_dim": 8,
        "nlayers": 3,
        "e_rcut": self.rcut,
        "e_rcut_smth": self.rcut_smth,
        "e_sel": self.sel,
        "a_rcut": self.rcut - 0.1,
        "a_rcut_smth": self.rcut_smth,
        "a_sel": self.sel - 1,
        "axis_neuron": 4,
        "update_angle": True,
        "update_style": "res_residual",
        "smooth_edge_update": True,
        "fix_stat_std": fix_stat_std,
    }
    descriptor = DescrptDPA3(
        self.ntypes,
        repflow=RepFlowArgs(**repflow_kwargs),
        seed=seed,
    ).to(self.device)
    # Override set_davg_zero for testing (default True in repflows)
    if fix_stat_std == 0.0:
        descriptor.repflows.set_davg_zero = False
    return descriptor
model_prob) -> None: + """Verify merged mean/stddev on repflows block match expected values.""" + dd_base = self._make_descriptor(GLOBAL_SEED) + dd_link = self._make_descriptor(GLOBAL_SEED + 1) + + base_stats = _make_stats(self.ntypes, self.last_dim, self.rng) + link_stats = _make_stats(self.ntypes, self.last_dim, self.rng) + + dd_base.repflows.stats = base_stats + dd_link.repflows.stats = link_stats + + mean0 = self.rng.normal(size=(self.ntypes, self.nnei, self.last_dim)) + stddev0 = 0.1 + np.abs( + self.rng.normal(size=(self.ntypes, self.nnei, self.last_dim)) + ) + dd_base.repflows.mean = torch.tensor( + mean0, dtype=torch.float64, device=self.device + ) + dd_base.repflows.stddev = torch.tensor( + stddev0, dtype=torch.float64, device=self.device + ) + + merged_stats = _merge_stats(base_stats, link_stats, model_prob) + expected_mean, expected_stddev = _compute_expected_buffers( + dd_base.repflows, merged_stats, self.last_dim + ) + + dd_link.share_params( + dd_base, shared_level=0, model_prob=model_prob, resume=False + ) + + np.testing.assert_allclose( + dd_base.repflows.mean.detach().cpu().numpy(), expected_mean, rtol=1e-10 + ) + np.testing.assert_allclose( + dd_base.repflows.stddev.detach().cpu().numpy(), expected_stddev, rtol=1e-10 + ) + + def test_default_config_skips_merge(self) -> None: + """Default DPA3 has set_davg_zero=True and set_stddev_constant=True, so merge is no-op.""" + dd_base = self._make_descriptor(GLOBAL_SEED, fix_stat_std=0.3) + dd_link = self._make_descriptor(GLOBAL_SEED + 1, fix_stat_std=0.3) + # Restore defaults + dd_base.repflows.set_davg_zero = True + dd_link.repflows.set_davg_zero = True + + base_stats = _make_stats(self.ntypes, self.last_dim, self.rng) + link_stats = _make_stats(self.ntypes, self.last_dim, self.rng) + dd_base.repflows.stats = base_stats + dd_link.repflows.stats = link_stats + + original_mean = dd_base.repflows.mean.clone() + original_stddev = dd_base.repflows.stddev.clone() + + dd_link.share_params(dd_base, 
def _make_descriptor(self, seed):
    """Create a DescrptSeTTebd (set_davg_zero=False) on the test device."""
    descriptor = DescrptSeTTebd(
        self.rcut,
        self.rcut_smth,
        self.sel,
        self.ntypes,
        set_davg_zero=False,
        seed=seed,
    )
    return descriptor.to(self.device)
-> None: + """Verify merged mean/stddev on se_ttebd block match expected values.""" + dd_base = self._make_descriptor(GLOBAL_SEED) + dd_link = self._make_descriptor(GLOBAL_SEED + 1) + + base_stats = _make_stats(self.ntypes, self.last_dim, self.rng) + link_stats = _make_stats(self.ntypes, self.last_dim, self.rng) + + dd_base.se_ttebd.stats = base_stats + dd_link.se_ttebd.stats = link_stats + + mean0 = self.rng.normal(size=(self.ntypes, self.nnei, self.last_dim)) + stddev0 = 0.1 + np.abs( + self.rng.normal(size=(self.ntypes, self.nnei, self.last_dim)) + ) + dd_base.se_ttebd.mean = torch.tensor( + mean0, dtype=torch.float64, device=self.device + ) + dd_base.se_ttebd.stddev = torch.tensor( + stddev0, dtype=torch.float64, device=self.device + ) + + merged_stats = _merge_stats(base_stats, link_stats, model_prob) + expected_mean, expected_stddev = _compute_expected_buffers( + dd_base.se_ttebd, merged_stats, self.last_dim + ) + + dd_link.share_params( + dd_base, shared_level=0, model_prob=model_prob, resume=False + ) + + np.testing.assert_allclose( + dd_base.se_ttebd.mean.detach().cpu().numpy(), expected_mean, rtol=1e-10 + ) + np.testing.assert_allclose( + dd_base.se_ttebd.stddev.detach().cpu().numpy(), expected_stddev, rtol=1e-10 + ) + + def test_resume_skips_merge(self) -> None: + """resume=True should skip stat merging.""" + dd_base = self._make_descriptor(GLOBAL_SEED) + dd_link = self._make_descriptor(GLOBAL_SEED + 1) + + base_stats = _make_stats(self.ntypes, self.last_dim, self.rng) + link_stats = _make_stats(self.ntypes, self.last_dim, self.rng) + dd_base.se_ttebd.stats = base_stats + dd_link.se_ttebd.stats = link_stats + + mean0 = self.rng.normal(size=(self.ntypes, self.nnei, self.last_dim)) + stddev0 = 0.1 + np.abs( + self.rng.normal(size=(self.ntypes, self.nnei, self.last_dim)) + ) + dd_base.se_ttebd.mean = torch.tensor( + mean0, dtype=torch.float64, device=self.device + ) + dd_base.se_ttebd.stddev = torch.tensor( + stddev0, dtype=torch.float64, device=self.device 
def setup_method(self) -> None:
    """Prepare device, seeded RNG, sizes and three-body settings for each test."""
    self.ntypes = 2
    self.last_dim = 4
    self.nnei = sum(self.sel_mix)
    self.device = env.DEVICE
    # Class-specific seed offset (600).
    self.rng = np.random.default_rng(GLOBAL_SEED + 600)
    # The three-body block reuses the class-level cutoffs.
    self.three_body_sel = 5
    self.three_body_rcut = self.rcut
    self.three_body_rcut_smth = self.rcut_smth
def test_three_body_aliased(self) -> None:
    """After share_params, link and base must hold the same repinit_three_body."""
    base = self._make_descriptor(GLOBAL_SEED)
    link = self._make_descriptor(GLOBAL_SEED + 1)

    # Base stats are drawn before link stats to keep the RNG stream order.
    stats_base = _make_stats(self.ntypes, self.last_dim, self.rng)
    stats_link = _make_stats(self.ntypes, self.last_dim, self.rng)
    base.repinit_three_body.stats = stats_base
    link.repinit_three_body.stats = stats_link

    link.share_params(base, shared_level=0, model_prob=0.6, resume=False)

    shared_module = base._modules["repinit_three_body"]
    assert link._modules["repinit_three_body"] is shared_module
def _make_descriptor(self, seed):
    """Create a hybrid descriptor with se_e2_a (last_dim=4) + se_r (last_dim=1)."""
    sub_descriptors = [
        DescrptSeA(self.rcut, self.rcut_smth, self.sel, seed=seed),
        DescrptSeR(self.rcut, self.rcut_smth, self.sel, seed=seed + 10),
    ]
    return DescrptHybrid(list=sub_descriptors).to(self.device)
sea_link_stats = _make_stats(self.ntypes, 4, self.rng) + dd_base.descrpt_list[0].stats = sea_base_stats + dd_link.descrpt_list[0].stats = sea_link_stats + + davg0_sea = self.rng.normal(size=(self.ntypes, self.nnei, 4)) + dstd0_sea = 0.1 + np.abs(self.rng.normal(size=(self.ntypes, self.nnei, 4))) + dd_base.descrpt_list[0].davg = torch.tensor( + davg0_sea, dtype=torch.float64, device=self.device + ) + dd_base.descrpt_list[0].dstd = torch.tensor( + dstd0_sea, dtype=torch.float64, device=self.device + ) + + # SeR sub-descriptor (last_dim=1) + ser_base_stats = _make_stats(self.ntypes, 1, self.rng) + ser_link_stats = _make_stats(self.ntypes, 1, self.rng) + dd_base.descrpt_list[1].stats = ser_base_stats + dd_link.descrpt_list[1].stats = ser_link_stats + + davg0_ser = self.rng.normal(size=(self.ntypes, self.nnei, 1)) + dstd0_ser = 0.1 + np.abs(self.rng.normal(size=(self.ntypes, self.nnei, 1))) + dd_base.descrpt_list[1].davg = torch.tensor( + davg0_ser, dtype=torch.float64, device=self.device + ) + dd_base.descrpt_list[1].dstd = torch.tensor( + dstd0_ser, dtype=torch.float64, device=self.device + ) + + # Compute expected for SeA + merged_sea = _merge_stats(sea_base_stats, sea_link_stats, model_prob) + exp_mean_sea, exp_std_sea = _compute_expected_buffers( + dd_base.descrpt_list[0], merged_sea, 4 + ) + + # Compute expected for SeR + merged_ser = _merge_stats(ser_base_stats, ser_link_stats, model_prob) + exp_mean_ser, exp_std_ser = _compute_expected_buffers( + dd_base.descrpt_list[1], merged_ser, 1 + ) + + # share_params on hybrid passes model_prob to each sub-descriptor + dd_link.share_params( + dd_base, shared_level=0, model_prob=model_prob, resume=False + ) + + # Verify SeA sub-descriptor buffers + np.testing.assert_allclose( + dd_base.descrpt_list[0].davg.detach().cpu().numpy(), + exp_mean_sea, + rtol=1e-10, + ) + np.testing.assert_allclose( + dd_base.descrpt_list[0].dstd.detach().cpu().numpy(), + exp_std_sea, + rtol=1e-10, + ) + + # Verify SeR sub-descriptor buffers + 
np.testing.assert_allclose( + dd_base.descrpt_list[1].davg.detach().cpu().numpy(), + exp_mean_ser, + rtol=1e-10, + ) + np.testing.assert_allclose( + dd_base.descrpt_list[1].dstd.detach().cpu().numpy(), + exp_std_ser, + rtol=1e-10, + ) + + def test_sub_descriptors_aliased(self) -> None: + """After share_params, sub-descriptor modules should be aliased.""" + dd_base = self._make_descriptor(GLOBAL_SEED) + dd_link = self._make_descriptor(GLOBAL_SEED + 1) + + # Populate stats to avoid None-guard early return + for i in range(2): + last_dim = 4 if i == 0 else 1 + dd_base.descrpt_list[i].stats = _make_stats(self.ntypes, last_dim, self.rng) + dd_link.descrpt_list[i].stats = _make_stats(self.ntypes, last_dim, self.rng) + + dd_link.share_params(dd_base, shared_level=0, model_prob=0.6, resume=False) + + for i in range(2): + assert ( + dd_link.descrpt_list[i].davg.data_ptr() + == dd_base.descrpt_list[i].davg.data_ptr() + ) + + def test_resume_skips_merge(self) -> None: + """resume=True should skip stat merging on all sub-descriptors.""" + dd_base = self._make_descriptor(GLOBAL_SEED) + dd_link = self._make_descriptor(GLOBAL_SEED + 1) + + sea_base_stats = _make_stats(self.ntypes, 4, self.rng) + sea_link_stats = _make_stats(self.ntypes, 4, self.rng) + dd_base.descrpt_list[0].stats = sea_base_stats + dd_link.descrpt_list[0].stats = sea_link_stats + + davg0 = self.rng.normal(size=(self.ntypes, self.nnei, 4)) + dd_base.descrpt_list[0].davg = torch.tensor( + davg0, dtype=torch.float64, device=self.device + ) + original_davg = dd_base.descrpt_list[0].davg.clone() + + # Need stats on all sub-descriptors to avoid None guard + dd_base.descrpt_list[1].stats = _make_stats(self.ntypes, 1, self.rng) + dd_link.descrpt_list[1].stats = _make_stats(self.ntypes, 1, self.rng) + + dd_link.share_params(dd_base, shared_level=0, model_prob=0.6, resume=True) + + np.testing.assert_allclose( + dd_base.descrpt_list[0].davg.detach().cpu().numpy(), + original_davg.detach().cpu().numpy(), + ) + + +class 
TestMergeEnvStatUnit: + """Unit tests for the merge_env_stat function directly.""" + + rcut = 2.2 + rcut_smth = 0.4 + sel: ClassVar = [5, 2] + + def setup_method(self) -> None: + self.device = env.DEVICE + self.rng = np.random.default_rng(GLOBAL_SEED + 600) + self.ntypes = 2 + self.nnei = sum(self.sel) + self.last_dim = 4 + + def test_merge_produces_correct_stats(self) -> None: + """merge_env_stat should compute merged = base + link * model_prob.""" + dd_base = DescrptSeA(self.rcut, self.rcut_smth, self.sel, seed=GLOBAL_SEED).to( + self.device + ) + + base_stats = _make_stats(self.ntypes, self.last_dim, self.rng) + link_stats = _make_stats(self.ntypes, self.last_dim, self.rng) + dd_base.stats = base_stats + + dd_link = DescrptSeA( + self.rcut, self.rcut_smth, self.sel, seed=GLOBAL_SEED + 1 + ).to(self.device) + dd_link.stats = link_stats + + model_prob = 0.3 + merge_env_stat(dd_base, dd_link, model_prob) + + for kk in base_stats: + expected = base_stats[kk] + link_stats[kk] * model_prob + assert abs(dd_base.stats[kk].number - expected.number) < 1e-10 + assert abs(dd_base.stats[kk].sum - expected.sum) < 1e-10 + assert abs(dd_base.stats[kk].squared_sum - expected.squared_sum) < 1e-10 + + def test_chaining_three_models(self) -> None: + """Merging stats from 3 models should accumulate correctly.""" + dd_base = DescrptSeA(self.rcut, self.rcut_smth, self.sel, seed=GLOBAL_SEED).to( + self.device + ) + dd_link1 = DescrptSeA( + self.rcut, self.rcut_smth, self.sel, seed=GLOBAL_SEED + 1 + ).to(self.device) + dd_link2 = DescrptSeA( + self.rcut, self.rcut_smth, self.sel, seed=GLOBAL_SEED + 2 + ).to(self.device) + + stats_base = _make_stats(self.ntypes, self.last_dim, self.rng) + stats_link1 = _make_stats(self.ntypes, self.last_dim, self.rng) + stats_link2 = _make_stats(self.ntypes, self.last_dim, self.rng) + + dd_base.stats = stats_base + dd_link1.stats = stats_link1 + dd_link2.stats = stats_link2 + + prob1, prob2 = 0.5, 0.3 + + merge_env_stat(dd_base, dd_link1, prob1) + 
merge_env_stat(dd_base, dd_link2, prob2) + + for kk in stats_base: + expected = ( + stats_base[kk] + stats_link1[kk] * prob1 + stats_link2[kk] * prob2 + ) + assert abs(dd_base.stats[kk].number - expected.number) < 1e-10 + assert abs(dd_base.stats[kk].sum - expected.sum) < 1e-10 + + def test_set_davg_zero_respected(self) -> None: + """When set_davg_zero=True, davg should remain zero after merging.""" + dd_base = DescrptSeA( + self.rcut, self.rcut_smth, self.sel, seed=GLOBAL_SEED, set_davg_zero=True + ).to(self.device) + dd_link = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + seed=GLOBAL_SEED + 1, + set_davg_zero=True, + ).to(self.device) + + base_stats = _make_stats(self.ntypes, self.last_dim, self.rng) + link_stats = _make_stats(self.ntypes, self.last_dim, self.rng) + dd_base.stats = base_stats + dd_link.stats = link_stats + + original_davg = dd_base.davg.clone() + merge_env_stat(dd_base, dd_link, 0.6) + + # davg should stay zero + np.testing.assert_allclose( + dd_base.davg.detach().cpu().numpy(), + original_davg.detach().cpu().numpy(), + ) + # but dstd should be updated + assert dd_base.stats is not base_stats # stats dict replaced diff --git a/source/tests/pt_expt/descriptor/test_dpa1.py b/source/tests/pt_expt/descriptor/test_dpa1.py index e90c67bc82..2662997a87 100644 --- a/source/tests/pt_expt/descriptor/test_dpa1.py +++ b/source/tests/pt_expt/descriptor/test_dpa1.py @@ -239,3 +239,46 @@ def fn(coord_ext, atype_ext, nlist): rtol=rtol, atol=atol, ) + + @pytest.mark.parametrize("shared_level", [0, 1]) # sharing level + def test_share_params(self, shared_level) -> None: + """share_params level 0: share all; level 1: share type_embedding only.""" + rng = np.random.default_rng(GLOBAL_SEED) + _, _, nnei = self.nlist.shape + davg0 = rng.normal(size=(self.nt, nnei, 4)) + dstd0 = 0.1 + np.abs(rng.normal(size=(self.nt, nnei, 4))) + + dd0 = DescrptDPA1( + self.rcut, + self.rcut_smth, + self.sel_mix, + self.nt, + attn_layer=2, + seed=GLOBAL_SEED, + 
).to(self.device) + dd1 = DescrptDPA1( + self.rcut, + self.rcut_smth, + self.sel_mix, + self.nt, + attn_layer=2, + seed=GLOBAL_SEED + 1, + ).to(self.device) + dd0.se_atten.mean = torch.tensor(davg0, dtype=torch.float64, device=self.device) + dd0.se_atten.stddev = torch.tensor( + dstd0, dtype=torch.float64, device=self.device + ) + + dd1.share_params(dd0, shared_level=shared_level) + + # type_embedding is always shared + assert dd1._modules["type_embedding"] is dd0._modules["type_embedding"] + + if shared_level == 0: + assert dd1._modules["se_atten"] is dd0._modules["se_atten"] + elif shared_level == 1: + assert dd1._modules["se_atten"] is not dd0._modules["se_atten"] + + # invalid level raises + with pytest.raises(NotImplementedError): + dd1.share_params(dd0, shared_level=2) diff --git a/source/tests/pt_expt/descriptor/test_dpa3.py b/source/tests/pt_expt/descriptor/test_dpa3.py index ecc94d24f5..ef4b479724 100644 --- a/source/tests/pt_expt/descriptor/test_dpa3.py +++ b/source/tests/pt_expt/descriptor/test_dpa3.py @@ -260,3 +260,54 @@ def fn(coord_ext, atype_ext, nlist, mapping): rtol=rtol, atol=atol, ) + + @pytest.mark.parametrize("shared_level", [0, 1]) # sharing level + def test_share_params(self, shared_level) -> None: + """share_params level 0: share all; level 1: share type_embedding only.""" + rng = np.random.default_rng(GLOBAL_SEED) + nf, nloc, nnei = self.nlist.shape + davg0 = rng.normal(size=(self.nt, nnei, 4)) + dstd0 = 0.1 + np.abs(rng.normal(size=(self.nt, nnei, 4))) + + repflow = RepFlowArgs( + n_dim=20, + e_dim=10, + a_dim=8, + nlayers=3, + e_rcut=self.rcut, + e_rcut_smth=self.rcut_smth, + e_sel=nnei, + a_rcut=self.rcut - 0.1, + a_rcut_smth=self.rcut_smth, + a_sel=nnei - 1, + axis_neuron=4, + update_angle=True, + update_style="res_residual", + update_residual_init="const", + smooth_edge_update=True, + ) + + dd0 = DescrptDPA3( + self.nt, repflow=repflow, exclude_types=[], seed=GLOBAL_SEED + ).to(self.device) + dd1 = DescrptDPA3( + self.nt, 
repflow=repflow, exclude_types=[], seed=GLOBAL_SEED + 1 + ).to(self.device) + dd0.repflows.mean = torch.tensor(davg0, dtype=torch.float64, device=self.device) + dd0.repflows.stddev = torch.tensor( + dstd0, dtype=torch.float64, device=self.device + ) + + dd1.share_params(dd0, shared_level=shared_level) + + # type_embedding is always shared + assert dd1._modules["type_embedding"] is dd0._modules["type_embedding"] + + if shared_level == 0: + assert dd1._modules["repflows"] is dd0._modules["repflows"] + elif shared_level == 1: + assert dd1._modules["repflows"] is not dd0._modules["repflows"] + + # invalid level raises + with pytest.raises(NotImplementedError): + dd1.share_params(dd0, shared_level=2) diff --git a/source/tests/pt_expt/descriptor/test_hybrid.py b/source/tests/pt_expt/descriptor/test_hybrid.py index a3c673d774..5fa8970bf1 100644 --- a/source/tests/pt_expt/descriptor/test_hybrid.py +++ b/source/tests/pt_expt/descriptor/test_hybrid.py @@ -231,3 +231,56 @@ def fn(coord_ext, atype_ext, nlist): rtol=rtol, atol=atol, ) + + def test_share_params(self) -> None: + """share_params level 0: recursively shares all sub-descriptors.""" + rng = np.random.default_rng(GLOBAL_SEED) + _, _, nnei = self.nlist.shape + davg4 = rng.normal(size=(self.nt, nnei, 4)) + dstd4 = 0.1 + np.abs(rng.normal(size=(self.nt, nnei, 4))) + + dd0 = DescrptHybrid( + list=[ + DescrptSeA(self.rcut, self.rcut_smth, self.sel, seed=GLOBAL_SEED), + DescrptSeR(self.rcut, self.rcut_smth, self.sel, seed=GLOBAL_SEED), + ] + ).to(self.device) + dd1 = DescrptHybrid( + list=[ + DescrptSeA(self.rcut, self.rcut_smth, self.sel, seed=GLOBAL_SEED + 1), + DescrptSeR(self.rcut, self.rcut_smth, self.sel, seed=GLOBAL_SEED + 1), + ] + ).to(self.device) + + # set stats on dd0's sub-descriptors + dd0.descrpt_list[0].davg = torch.tensor( + davg4, dtype=torch.float64, device=self.device + ) + dd0.descrpt_list[0].dstd = torch.tensor( + dstd4, dtype=torch.float64, device=self.device + ) + dd0.descrpt_list[1].davg = 
torch.tensor( + davg4[..., :1], dtype=torch.float64, device=self.device + ) + dd0.descrpt_list[1].dstd = torch.tensor( + dstd4[..., :1], dtype=torch.float64, device=self.device + ) + + dd1.share_params(dd0, shared_level=0) + + # each sub-descriptor's modules/buffers are shared + for ii in range(len(dd0.descrpt_list)): + for key in dd0.descrpt_list[ii]._modules: + assert ( + dd1.descrpt_list[ii]._modules[key] + is dd0.descrpt_list[ii]._modules[key] + ) + for key in dd0.descrpt_list[ii]._buffers: + assert ( + dd1.descrpt_list[ii]._buffers[key] + is dd0.descrpt_list[ii]._buffers[key] + ) + + # invalid level raises + with pytest.raises(NotImplementedError): + dd1.share_params(dd0, shared_level=1) diff --git a/source/tests/pt_expt/descriptor/test_se_atten_v2.py b/source/tests/pt_expt/descriptor/test_se_atten_v2.py index 326a78acad..cc86c1600b 100644 --- a/source/tests/pt_expt/descriptor/test_se_atten_v2.py +++ b/source/tests/pt_expt/descriptor/test_se_atten_v2.py @@ -234,3 +234,46 @@ def fn(coord_ext, atype_ext, nlist): rtol=rtol, atol=atol, ) + + @pytest.mark.parametrize("shared_level", [0, 1]) # sharing level + def test_share_params(self, shared_level) -> None: + """share_params level 0: share all; level 1: share type_embedding only.""" + rng = np.random.default_rng(GLOBAL_SEED) + _, _, nnei = self.nlist.shape + davg0 = rng.normal(size=(self.nt, nnei, 4)) + dstd0 = 0.1 + np.abs(rng.normal(size=(self.nt, nnei, 4))) + + dd0 = DescrptSeAttenV2( + self.rcut, + self.rcut_smth, + self.sel_mix, + self.nt, + attn_layer=2, + seed=GLOBAL_SEED, + ).to(self.device) + dd1 = DescrptSeAttenV2( + self.rcut, + self.rcut_smth, + self.sel_mix, + self.nt, + attn_layer=2, + seed=GLOBAL_SEED + 1, + ).to(self.device) + dd0.se_atten.mean = torch.tensor(davg0, dtype=torch.float64, device=self.device) + dd0.se_atten.stddev = torch.tensor( + dstd0, dtype=torch.float64, device=self.device + ) + + dd1.share_params(dd0, shared_level=shared_level) + + # type_embedding is always shared + assert 
dd1._modules["type_embedding"] is dd0._modules["type_embedding"] + + if shared_level == 0: + assert dd1._modules["se_atten"] is dd0._modules["se_atten"] + elif shared_level == 1: + assert dd1._modules["se_atten"] is not dd0._modules["se_atten"] + + # invalid level raises + with pytest.raises(NotImplementedError): + dd1.share_params(dd0, shared_level=2) diff --git a/source/tests/pt_expt/descriptor/test_se_r.py b/source/tests/pt_expt/descriptor/test_se_r.py index cde3295e7a..9056c9f308 100644 --- a/source/tests/pt_expt/descriptor/test_se_r.py +++ b/source/tests/pt_expt/descriptor/test_se_r.py @@ -216,3 +216,43 @@ def fn(coord_ext, atype_ext, nlist): rtol=rtol, atol=atol, ) + + def test_share_params(self) -> None: + """share_params level 0: all modules and buffers are shared.""" + rng = np.random.default_rng(GLOBAL_SEED) + _, _, nnei = self.nlist.shape + davg0 = rng.normal(size=(self.nt, nnei, 1)) + dstd0 = 0.1 + np.abs(rng.normal(size=(self.nt, nnei, 1))) + + dd0 = DescrptSeR(self.rcut, self.rcut_smth, self.sel, seed=GLOBAL_SEED).to( + self.device + ) + dd1 = DescrptSeR(self.rcut, self.rcut_smth, self.sel, seed=GLOBAL_SEED + 1).to( + self.device + ) + dd0.davg = torch.tensor(davg0, dtype=torch.float64, device=self.device) + dd0.dstd = torch.tensor(dstd0, dtype=torch.float64, device=self.device) + + dd1.share_params(dd0, shared_level=0) + + # all modules and buffers are shared (same object) + for key in dd0._modules: + assert dd1._modules[key] is dd0._modules[key] + for key in dd0._buffers: + assert dd1._buffers[key] is dd0._buffers[key] + + # forward pass produces identical output + inputs = ( + torch.tensor(self.coord_ext, dtype=torch.float64, device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, device=self.device), + ) + rd0 = dd0(*inputs)[0] + rd1 = dd1(*inputs)[0] + np.testing.assert_allclose( + rd0.detach().cpu().numpy(), rd1.detach().cpu().numpy() + ) + + # invalid level raises + with 
pytest.raises(NotImplementedError): + dd1.share_params(dd0, shared_level=1) diff --git a/source/tests/pt_expt/descriptor/test_se_t.py b/source/tests/pt_expt/descriptor/test_se_t.py index bb1f9a4b3f..ed71f1e0ed 100644 --- a/source/tests/pt_expt/descriptor/test_se_t.py +++ b/source/tests/pt_expt/descriptor/test_se_t.py @@ -220,3 +220,43 @@ def fn(coord_ext, atype_ext, nlist): rtol=rtol, atol=atol, ) + + def test_share_params(self) -> None: + """share_params level 0: all modules and buffers are shared.""" + rng = np.random.default_rng(GLOBAL_SEED) + _, _, nnei = self.nlist.shape + davg0 = rng.normal(size=(self.nt, nnei, 4)) + dstd0 = 0.1 + np.abs(rng.normal(size=(self.nt, nnei, 4))) + + dd0 = DescrptSeT(self.rcut, self.rcut_smth, self.sel, seed=GLOBAL_SEED).to( + self.device + ) + dd1 = DescrptSeT(self.rcut, self.rcut_smth, self.sel, seed=GLOBAL_SEED + 1).to( + self.device + ) + dd0.davg = torch.tensor(davg0, dtype=torch.float64, device=self.device) + dd0.dstd = torch.tensor(dstd0, dtype=torch.float64, device=self.device) + + dd1.share_params(dd0, shared_level=0) + + # all modules and buffers are shared (same object) + for key in dd0._modules: + assert dd1._modules[key] is dd0._modules[key] + for key in dd0._buffers: + assert dd1._buffers[key] is dd0._buffers[key] + + # forward pass produces identical output + inputs = ( + torch.tensor(self.coord_ext, dtype=torch.float64, device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, device=self.device), + ) + rd0 = dd0(*inputs)[0] + rd1 = dd1(*inputs)[0] + np.testing.assert_allclose( + rd0.detach().cpu().numpy(), rd1.detach().cpu().numpy() + ) + + # invalid level raises + with pytest.raises(NotImplementedError): + dd1.share_params(dd0, shared_level=1) diff --git a/source/tests/pt_expt/descriptor/test_se_t_tebd.py b/source/tests/pt_expt/descriptor/test_se_t_tebd.py index 30808f5070..41643b41ef 100644 --- a/source/tests/pt_expt/descriptor/test_se_t_tebd.py +++ 
b/source/tests/pt_expt/descriptor/test_se_t_tebd.py @@ -251,3 +251,34 @@ def fn(coord_ext, atype_ext, nlist): rtol=rtol, atol=atol, ) + + @pytest.mark.parametrize("shared_level", [0, 1]) # sharing level + def test_share_params(self, shared_level) -> None: + """share_params level 0: share all; level 1: share type_embedding only.""" + rng = np.random.default_rng(GLOBAL_SEED) + _, _, nnei = self.nlist.shape + davg0 = rng.normal(size=(self.nt, nnei, 4)) + dstd0 = 0.1 + np.abs(rng.normal(size=(self.nt, nnei, 4))) + + dd0 = DescrptSeTTebd( + self.rcut, self.rcut_smth, self.sel, self.nt, seed=GLOBAL_SEED + ).to(self.device) + dd1 = DescrptSeTTebd( + self.rcut, self.rcut_smth, self.sel, self.nt, seed=GLOBAL_SEED + 1 + ).to(self.device) + dd0.davg = torch.tensor(davg0, dtype=torch.float64, device=self.device) + dd0.dstd = torch.tensor(dstd0, dtype=torch.float64, device=self.device) + + dd1.share_params(dd0, shared_level=shared_level) + + # type_embedding is always shared + assert dd1._modules["type_embedding"] is dd0._modules["type_embedding"] + + if shared_level == 0: + assert dd1._modules["se_ttebd"] is dd0._modules["se_ttebd"] + elif shared_level == 1: + assert dd1._modules["se_ttebd"] is not dd0._modules["se_ttebd"] + + # invalid level raises + with pytest.raises(NotImplementedError): + dd1.share_params(dd0, shared_level=2) diff --git a/source/tests/pt_expt/fitting/test_fitting_stat.py b/source/tests/pt_expt/fitting/test_fitting_stat.py index dcb99dd324..038e0dcf27 100644 --- a/source/tests/pt_expt/fitting/test_fitting_stat.py +++ b/source/tests/pt_expt/fitting/test_fitting_stat.py @@ -1,18 +1,42 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import os +import shutil +import tempfile import unittest +from copy import ( + deepcopy, +) +from pathlib import ( + Path, +) import numpy as np import torch +from deepmd.dpmodel.common import ( + to_numpy_array, +) from deepmd.dpmodel.descriptor import ( DescrptSeA, ) +from deepmd.pt_expt.entrypoints.main import ( + 
get_trainer, +) from deepmd.pt_expt.fitting import ( EnergyFittingNet, ) from deepmd.pt_expt.utils import ( env, ) +from deepmd.pt_expt.utils.multi_task import ( + preprocess_shared_params, +) +from deepmd.utils.argcheck import ( + normalize, +) +from deepmd.utils.compat import ( + update_deepmd_input, +) def _make_fake_data_pt(sys_natoms, sys_nframes, avgs, stds): @@ -74,6 +98,58 @@ def _brute_aparam_pt(data, ndim): return avg, std +def _get_weighted_fitting_stat( + model_prob: list, *stat_arrays, protection: float +) -> tuple[np.ndarray, np.ndarray]: + """Compute probability-weighted fparam avg and std (matching PT).""" + n_arrays = len(stat_arrays) + assert len(model_prob) == n_arrays + nframes = [stat.shape[0] for stat in stat_arrays] + sums = [stat.sum(axis=0) for stat in stat_arrays] + squared_sums = [(stat**2).sum(axis=0) for stat in stat_arrays] + weighted_sum = sum(model_prob[i] * sums[i] for i in range(n_arrays)) + total_weighted_frames = sum(model_prob[i] * nframes[i] for i in range(n_arrays)) + weighted_avg = weighted_sum / total_weighted_frames + weighted_square_sum = sum(model_prob[i] * squared_sums[i] for i in range(n_arrays)) + weighted_square_avg = weighted_square_sum / total_weighted_frames + weighted_std = np.sqrt(weighted_square_avg - weighted_avg**2) + weighted_std = np.where(weighted_std < protection, protection, weighted_std) + return weighted_avg, weighted_std + + +# Paths to the water data used by PT tests +_PT_DATA = str(Path(__file__).parent.parent.parent / "pt" / "water" / "data" / "data_0") +_PT_DATA_NO_FPARAM = str( + Path(__file__).parent.parent.parent / "pt" / "water" / "data" / "data_1" +) +_PT_DATA_SINGLE = str( + Path(__file__).parent.parent.parent / "pt" / "water" / "data" / "single" +) + +_descriptor_se_e2_a = { + "type": "se_e2_a", + "sel": [6, 12], + "rcut_smth": 0.50, + "rcut": 3.00, + "neuron": [8, 16], + "resnet_dt": False, + "axis_neuron": 4, + "type_one_side": True, + "seed": 1, +} + +_fitting_net = { + "neuron": [16, 
16], + "resnet_dt": True, + "seed": 1, +} + + +def _skip_if_no_data() -> None: + if not os.path.isdir(_PT_DATA): + raise unittest.SkipTest(f"Test data not found: {_PT_DATA}") + + class TestEnerFittingStat(unittest.TestCase): def setUp(self) -> None: self.device = env.DEVICE @@ -125,3 +201,468 @@ def test(self) -> None: np.testing.assert_almost_equal(frefs_inv, fparam_inv_std_np) np.testing.assert_almost_equal(arefa, aparam_avg_np) np.testing.assert_almost_equal(arefs_inv, aparam_inv_std_np) + + +class TestMultiTaskFittingStat(unittest.TestCase): + """Test shared fitting stat (fparam_avg/fparam_inv_std) in multi-task. + + Corresponds to PT's TestMultiTaskFittingStat in test_fitting_stat.py. + Verifies: + 1. fparam stats are shared between models (same tensor values) + 2. stat file contents match raw data (number, sum, squared_sum) + 3. weighted stat computation matches model values + 4. case_embd with default_fparam works correctly + """ + + @classmethod + def setUpClass(cls) -> None: + _skip_if_no_data() + if not os.path.isdir(_PT_DATA_SINGLE): + raise unittest.SkipTest(f"Test data not found: {_PT_DATA_SINGLE}") + + def setUp(self) -> None: + self.tmpdir = tempfile.mkdtemp(prefix="pt_expt_fitstat_") + self._old_cwd = os.getcwd() + os.chdir(self.tmpdir) + self.stat_files = "se_e2_a_share_fit" + os.makedirs(self.stat_files, exist_ok=True) + + def tearDown(self) -> None: + os.chdir(self._old_cwd) + shutil.rmtree(self.tmpdir, ignore_errors=True) + + def _make_sharefit_config( + self, + *, + numb_fparam: int = 2, + numb_aparam: int = 0, + default_fparam: list | None = None, + dim_case_embd: int = 2, + model_keys: list[str] | None = None, + data_dirs: dict[str, str] | None = None, + model_probs: dict[str, float] | None = None, + ) -> dict: + """Build a multi-task config with shared fitting + fparam.""" + if model_keys is None: + model_keys = ["model_1", "model_2"] + if data_dirs is None: + data_dirs = dict.fromkeys(model_keys, _PT_DATA) + if model_probs is None: + 
model_probs = {mk: 1.0 / len(model_keys) for mk in model_keys} + + shared_fitting: dict = deepcopy(_fitting_net) + shared_fitting["numb_fparam"] = numb_fparam + if numb_aparam > 0: + shared_fitting["numb_aparam"] = numb_aparam + shared_fitting["dim_case_embd"] = dim_case_embd + if default_fparam is not None: + shared_fitting["default_fparam"] = default_fparam + + shared_dict: dict = { + "my_type_map": ["O", "H"], + "my_descriptor": deepcopy(_descriptor_se_e2_a), + "my_fitting": shared_fitting, + } + + model_dict = {} + loss_dict = {} + data_dict = {} + for mk in model_keys: + model_dict[mk] = { + "type_map": "my_type_map", + "descriptor": "my_descriptor", + "fitting_net": "my_fitting", + "data_stat_nbatch": 1, + } + loss_dict[mk] = { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "start_pref_v": 0, + "limit_pref_v": 0, + } + data_dict[mk] = { + "stat_file": f"{self.stat_files}/{mk}", + "training_data": { + "systems": [data_dirs[mk]], + "batch_size": 1, + }, + "validation_data": { + "systems": [data_dirs[mk]], + "batch_size": 1, + "numb_btch": 1, + }, + } + + config = { + "model": { + "shared_dict": shared_dict, + "model_dict": model_dict, + }, + "learning_rate": { + "type": "exp", + "decay_steps": 500, + "start_lr": 0.001, + "stop_lr": 3.51e-8, + }, + "loss_dict": loss_dict, + "training": { + "model_prob": model_probs, + "data_dict": data_dict, + "numb_steps": 1, + "seed": 10, + "disp_file": "lcurve.out", + "disp_freq": 1, + "save_freq": 1, + }, + } + return config + + def test_sharefitting_with_fparam(self) -> None: + """Shared fitting with fparam data: weighted fparam stat merging.""" + model_prob = [0.3, 0.7] + config = self._make_sharefit_config( + numb_fparam=2, + default_fparam=[1.0, 0.0], + data_dirs={"model_1": _PT_DATA, "model_2": _PT_DATA_SINGLE}, + model_probs={"model_1": model_prob[0], "model_2": model_prob[1]}, + ) + # data_0 has 80 frames; use data_stat_nbatch=80 to cover all frames +
config["model"]["model_dict"]["model_1"]["data_stat_nbatch"] = 80 + config["model"], shared_links = preprocess_shared_params(config["model"]) + config = update_deepmd_input(config, warning=False) + config = normalize(config, multi_task=True) + trainer = get_trainer(deepcopy(config), shared_links=shared_links) + trainer.run() + + # fparam_avg and fparam_inv_std should be shared between models + multi_state_dict = trainer.wrapper.model.state_dict() + torch.testing.assert_close( + multi_state_dict["model_1.atomic_model.fitting_net.fparam_avg"], + multi_state_dict["model_2.atomic_model.fitting_net.fparam_avg"], + ) + torch.testing.assert_close( + multi_state_dict["model_1.atomic_model.fitting_net.fparam_inv_std"], + multi_state_dict["model_2.atomic_model.fitting_net.fparam_inv_std"], + ) + + # check fitting stat in stat_file is correct + fparam_stat_model1 = np.load(f"{self.stat_files}/model_1/O H/fparam") + fparam_stat_model2 = np.load(f"{self.stat_files}/model_2/O H/fparam") + fparam_data1 = np.load(os.path.join(_PT_DATA, "set.000", "fparam.npy")) + fparam_data2 = np.load(os.path.join(_PT_DATA_SINGLE, "set.000", "fparam.npy")) + np.testing.assert_almost_equal( + fparam_stat_model1[:, 0], [fparam_data1.shape[0]] * 2 + ) + np.testing.assert_almost_equal( + fparam_stat_model1[:, 1], fparam_data1.sum(axis=0) + ) + np.testing.assert_almost_equal( + fparam_stat_model1[:, 2], (fparam_data1**2).sum(axis=0) + ) + np.testing.assert_almost_equal( + fparam_stat_model2[:, 0], [fparam_data2.shape[0]] * 2 + ) + np.testing.assert_almost_equal( + fparam_stat_model2[:, 1], fparam_data2.sum(axis=0) + ) + np.testing.assert_almost_equal( + fparam_stat_model2[:, 2], (fparam_data2**2).sum(axis=0) + ) + + # check shared fitting stat is computed correctly + weighted_avg, weighted_std = _get_weighted_fitting_stat( + model_prob, fparam_data1, fparam_data2, protection=1e-2 + ) + np.testing.assert_almost_equal( + weighted_avg, + to_numpy_array( + 
multi_state_dict["model_1.atomic_model.fitting_net.fparam_avg"] + ), + ) + np.testing.assert_almost_equal( + 1.0 / weighted_std, + to_numpy_array( + multi_state_dict["model_1.atomic_model.fitting_net.fparam_inv_std"] + ), + ) + + def test_sharefitting_using_default_fparam(self) -> None: + """3 models with dim_case_embd=3, default fparam, no fparam in data.""" + default_fparam = [1.0, 0.0] + model_prob = [0.1, 0.3, 0.6] + data_stat_protect = 5e-3 + config = self._make_sharefit_config( + numb_fparam=2, + default_fparam=default_fparam, + dim_case_embd=3, + model_keys=["model_1", "model_2", "model_3"], + data_dirs={ + "model_1": _PT_DATA_NO_FPARAM, + "model_2": _PT_DATA_SINGLE, + "model_3": _PT_DATA, + }, + model_probs={ + "model_1": model_prob[0], + "model_2": model_prob[1], + "model_3": model_prob[2], + }, + ) + # model_1 uses data without fparam → default_fparam is used + config["model"]["model_dict"]["model_1"]["data_stat_nbatch"] = 3 + config["model"]["model_dict"]["model_3"]["data_stat_nbatch"] = 80 + config["model"]["model_dict"]["model_1"]["data_stat_protect"] = ( + data_stat_protect + ) + config["model"]["model_dict"]["model_2"]["data_stat_protect"] = ( + data_stat_protect + ) + config["model"]["model_dict"]["model_3"]["data_stat_protect"] = ( + data_stat_protect + ) + config["model"], shared_links = preprocess_shared_params(config["model"]) + config = update_deepmd_input(config, warning=False) + config = normalize(config, multi_task=True) + trainer = get_trainer(deepcopy(config), shared_links=shared_links) + trainer.run() + + # fparam_avg shared across all 3 models + multi_state_dict = trainer.wrapper.model.state_dict() + torch.testing.assert_close( + multi_state_dict["model_1.atomic_model.fitting_net.fparam_avg"], + multi_state_dict["model_2.atomic_model.fitting_net.fparam_avg"], + ) + torch.testing.assert_close( + multi_state_dict["model_1.atomic_model.fitting_net.fparam_avg"], + multi_state_dict["model_3.atomic_model.fitting_net.fparam_avg"], + ) + 
torch.testing.assert_close( + multi_state_dict["model_1.atomic_model.fitting_net.fparam_inv_std"], + multi_state_dict["model_2.atomic_model.fitting_net.fparam_inv_std"], + ) + torch.testing.assert_close( + multi_state_dict["model_1.atomic_model.fitting_net.fparam_inv_std"], + multi_state_dict["model_3.atomic_model.fitting_net.fparam_inv_std"], + ) + + # check fitting stat in stat_file is correct + fparam_stat_model1 = np.load(f"{self.stat_files}/model_1/O H/fparam") + fparam_stat_model2 = np.load(f"{self.stat_files}/model_2/O H/fparam") + fparam_stat_model3 = np.load(f"{self.stat_files}/model_3/O H/fparam") + fparam_data1 = np.array([default_fparam]).repeat(3, axis=0) + fparam_data2 = np.load(os.path.join(_PT_DATA_SINGLE, "set.000", "fparam.npy")) + fparam_data3 = np.load(os.path.join(_PT_DATA, "set.000", "fparam.npy")) + np.testing.assert_almost_equal( + fparam_stat_model1[:, 0], [fparam_data1.shape[0]] * 2 + ) + np.testing.assert_almost_equal( + fparam_stat_model1[:, 1], fparam_data1.sum(axis=0) + ) + np.testing.assert_almost_equal( + fparam_stat_model1[:, 2], (fparam_data1**2).sum(axis=0) + ) + np.testing.assert_almost_equal( + fparam_stat_model2[:, 0], [fparam_data2.shape[0]] * 2 + ) + np.testing.assert_almost_equal( + fparam_stat_model2[:, 1], fparam_data2.sum(axis=0) + ) + np.testing.assert_almost_equal( + fparam_stat_model2[:, 2], (fparam_data2**2).sum(axis=0) + ) + np.testing.assert_almost_equal( + fparam_stat_model3[:, 0], [fparam_data3.shape[0]] * 2 + ) + np.testing.assert_almost_equal( + fparam_stat_model3[:, 1], fparam_data3.sum(axis=0) + ) + np.testing.assert_almost_equal( + fparam_stat_model3[:, 2], (fparam_data3**2).sum(axis=0) + ) + + # check shared fitting stat is computed correctly + weighted_avg, weighted_std = _get_weighted_fitting_stat( + model_prob, + fparam_data1, + fparam_data2, + fparam_data3, + protection=data_stat_protect, + ) + np.testing.assert_almost_equal( + weighted_avg, + to_numpy_array( + 
multi_state_dict["model_1.atomic_model.fitting_net.fparam_avg"] + ), + ) + np.testing.assert_almost_equal( + 1.0 / weighted_std, + to_numpy_array( + multi_state_dict["model_1.atomic_model.fitting_net.fparam_inv_std"] + ), + ) + + # case_embd should be set on all 3 models + ce1 = trainer.wrapper.model["model_1"].atomic_model.fitting_net.case_embd + ce2 = trainer.wrapper.model["model_2"].atomic_model.fitting_net.case_embd + ce3 = trainer.wrapper.model["model_3"].atomic_model.fitting_net.case_embd + self.assertIsNotNone(ce1) + self.assertIsNotNone(ce2) + self.assertIsNotNone(ce3) + + # dim_case_embd=3 → each is a 3-element one-hot vector + self.assertEqual(ce1.shape[-1], 3) + self.assertEqual(ce2.shape[-1], 3) + self.assertEqual(ce3.shape[-1], 3) + + # Each should be one-hot + self.assertEqual(ce1.sum().item(), 1.0) + self.assertEqual(ce2.sum().item(), 1.0) + self.assertEqual(ce3.sum().item(), 1.0) + + # All three should be different + self.assertFalse(torch.equal(ce1, ce2)) + self.assertFalse(torch.equal(ce1, ce3)) + self.assertFalse(torch.equal(ce2, ce3)) + + # case_embd should NOT be shared in state_dict + for state_key in multi_state_dict: + if "case_embd" in state_key and "model_1" in state_key: + k2 = state_key.replace("model_1", "model_2") + k3 = state_key.replace("model_1", "model_3") + self.assertFalse( + torch.equal(multi_state_dict[state_key], multi_state_dict[k2]), + ) + self.assertFalse( + torch.equal(multi_state_dict[state_key], multi_state_dict[k3]), + ) + + def test_sharefitting_with_aparam(self) -> None: + """Weighted aparam stat merging in share_params (unit test). + + Directly tests the aparam branch in InvarFitting.share_params by + creating two fittings with different aparam stats and verifying that + share_params produces the correct probability-weighted merged result. 
+ """ + descrpt = DescrptSeA(6.0, 5.8, [46, 92], neuron=[25, 50, 100], axis_neuron=16) + ntypes = descrpt.get_ntypes() + dim_out = descrpt.get_dim_out() + + fit_base = EnergyFittingNet( + ntypes, dim_out, neuron=[16, 16], numb_aparam=3, seed=1 + ).to(env.DEVICE) + fit_link = EnergyFittingNet( + ntypes, dim_out, neuron=[16, 16], numb_aparam=3, seed=2 + ).to(env.DEVICE) + + # give both fittings different aparam stats + data_base = _make_fake_data_pt( + [10, 100], [5, 2], [0, 10, 100], [2, 0.4, 0.00001] + ) + data_link = _make_fake_data_pt([50], [8], [5, 20, 50], [1, 0.5, 0.01]) + fit_base.compute_input_stats(data_base, protection=1e-2) + fit_link.compute_input_stats(data_link, protection=1e-2) + + # record base's aparam_avg before share_params + orig_base_avg = fit_base.aparam_avg.clone() + + # share_params with model_prob=0.6 — should do weighted merging + model_prob = 0.6 + fit_link.share_params( + fit_base, shared_level=0, model_prob=model_prob, protection=1e-2 + ) + + # base's aparam_avg was UPDATED (weighted merging happened) + self.assertFalse( + torch.equal(fit_base.aparam_avg, orig_base_avg), + "aparam_avg should have changed after weighted merging", + ) + + # buffers are shared (same data_ptr) + self.assertEqual(fit_link.aparam_avg.data_ptr(), fit_base.aparam_avg.data_ptr()) + self.assertEqual( + fit_link.aparam_inv_std.data_ptr(), fit_base.aparam_inv_std.data_ptr() + ) + + # verify the merged stats match manual computation + # reconstruct raw aparam data from each fitting's stats + base_aparam_stats = fit_base.get_param_stats().get("aparam", []) + # the merged stats should have 3 StatItem objects + self.assertEqual(len(base_aparam_stats), 3) + + # manually compute the weighted average from raw data + # data_base has two systems: [10 natoms, 5 frames] + [100 natoms, 2 frames] + # data_link has one system: [50 natoms, 8 frames] + # aparam per system: reshape to (nframes * natoms, numb_aparam) + all_base = np.concatenate( + [d["aparam"].reshape(-1, 3) for d in 
data_base], axis=0 + ) + all_link = np.concatenate( + [d["aparam"].reshape(-1, 3) for d in data_link], axis=0 + ) + # weighted stat: base contributes with weight 1.0, link with model_prob + total_n = all_base.shape[0] + model_prob * all_link.shape[0] + weighted_sum = all_base.sum(axis=0) + model_prob * all_link.sum(axis=0) + weighted_avg = weighted_sum / total_n + weighted_sq_sum = (all_base**2).sum(axis=0) + model_prob * (all_link**2).sum( + axis=0 + ) + weighted_sq_avg = weighted_sq_sum / total_n + weighted_std = np.sqrt(weighted_sq_avg - weighted_avg**2) + weighted_std = np.where(weighted_std < 1e-2, 1e-2, weighted_std) + + aparam_avg_np = to_numpy_array(fit_base.aparam_avg) + aparam_inv_std_np = to_numpy_array(fit_base.aparam_inv_std) + np.testing.assert_almost_equal(aparam_avg_np, weighted_avg) + np.testing.assert_almost_equal(aparam_inv_std_np, 1.0 / weighted_std) + + def test_sharefitting_resume_preserves_stats(self) -> None: + """resume=True in share_params skips stat merging, preserves buffers.""" + descrpt = DescrptSeA(6.0, 5.8, [46, 92], neuron=[25, 50, 100], axis_neuron=16) + ntypes = descrpt.get_ntypes() + dim_out = descrpt.get_dim_out() + + fit_base = EnergyFittingNet( + ntypes, dim_out, neuron=[16, 16], numb_fparam=2, seed=1 + ).to(env.DEVICE) + fit_link = EnergyFittingNet( + ntypes, dim_out, neuron=[16, 16], numb_fparam=2, seed=2 + ).to(env.DEVICE) + + # give both fittings different stats + data_base = _make_fake_data_pt([10], [5], [0, 10], [2, 0.4]) + data_link = _make_fake_data_pt([100], [2], [100, 0], [0.001, 3]) + fit_base.compute_input_stats(data_base, protection=1e-2) + fit_link.compute_input_stats(data_link, protection=1e-2) + + # record base's fparam_avg BEFORE sharing + orig_avg = fit_base.fparam_avg.clone() + orig_inv_std = fit_base.fparam_inv_std.clone() + + # share_params with resume=True: should NOT re-merge stats + fit_link.share_params(fit_base, shared_level=0, resume=True) + + # base's fparam_avg unchanged (no weighted merging 
happened) + torch.testing.assert_close(fit_base.fparam_avg, orig_avg) + torch.testing.assert_close(fit_base.fparam_inv_std, orig_inv_std) + + # buffers are shared (same data_ptr) + self.assertEqual(fit_link.fparam_avg.data_ptr(), fit_base.fparam_avg.data_ptr()) + self.assertEqual( + fit_link.fparam_inv_std.data_ptr(), fit_base.fparam_inv_std.data_ptr() + ) + + def test_case_embd_mismatched_dim_raises(self) -> None: + """dim_case_embd must be the same across all models.""" + config = self._make_sharefit_config(dim_case_embd=2) + # Override model_2 to have a different dim_case_embd + config["model"]["model_dict"]["model_2"]["fitting_net"] = deepcopy(_fitting_net) + config["model"]["model_dict"]["model_2"]["fitting_net"]["dim_case_embd"] = 3 + config["model"], shared_links = preprocess_shared_params(config["model"]) + config = update_deepmd_input(config, warning=False) + config = normalize(config, multi_task=True) + with self.assertRaises( + ValueError, msg="Should reject mismatched dim_case_embd" + ): + get_trainer(config, shared_links=shared_links) diff --git a/source/tests/pt_expt/test_change_bias.py b/source/tests/pt_expt/test_change_bias.py index 03329642e9..e3749671aa 100644 --- a/source/tests/pt_expt/test_change_bias.py +++ b/source/tests/pt_expt/test_change_bias.py @@ -145,7 +145,7 @@ def setUpClass(cls) -> None: cls.model_path = os.path.join(cls.tmpdir, "model.ckpt.pt") # Record original bias - cls.original_bias = to_numpy(trainer.wrapper.model.get_out_bias()) + cls.original_bias = to_numpy(trainer.wrapper.model["Default"].get_out_bias()) # Pre-freeze shared .pte and .pt2 files so individual tests don't # each pay the AOTInductor compilation cost (~82s per .pt2). 
diff --git a/source/tests/pt_expt/test_finetune.py b/source/tests/pt_expt/test_finetune.py index 063bb85f71..b000c313a3 100644 --- a/source/tests/pt_expt/test_finetune.py +++ b/source/tests/pt_expt/test_finetune.py @@ -371,8 +371,10 @@ def test_finetune_change_type(self) -> None: wrapper_new = ModelWrapper(model_new) _, has_new_type = get_index_between_two_maps(old_type_map, new_type_map) - model_with_new_type_stat = wrapper_new.model if has_new_type else None - pretrained_wrapper.model.change_type_map( + model_with_new_type_stat = ( + wrapper_new.model["Default"] if has_new_type else None + ) + pretrained_wrapper.model["Default"].change_type_map( new_type_map, model_with_new_type_stat=model_with_new_type_stat, ) diff --git a/source/tests/pt_expt/test_multitask.py b/source/tests/pt_expt/test_multitask.py new file mode 100644 index 0000000000..e5c6955ac0 --- /dev/null +++ b/source/tests/pt_expt/test_multitask.py @@ -0,0 +1,2299 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Tests for multi-task training in the pt_expt backend. + +Verifies that: +1. Multi-task training completes without error for various descriptors +2. Shared descriptor parameters are identical between tasks +3. lcurve.out has per-model columns +4. Checkpoint save/load roundtrip works +5. Multi-task freeze extracts single head correctly +6. Shared fitting_net with case_embd works (share_fitting) +7. Shared fitting stat (fparam_avg/fparam_inv_std) are shared between models +8. Case embedding with 3 models and dim_case_embd=3 works correctly +9. 
Multi-task descriptor gradients match sum of single-task gradients +""" + +import os +import shutil +import tempfile +import unittest +from copy import ( + deepcopy, +) +from pathlib import ( + Path, +) +from unittest.mock import ( + patch, +) + +import numpy as np +import torch + +import deepmd.utils.random as dp_random +from deepmd.pt_expt.entrypoints.main import ( + get_trainer, +) +from deepmd.pt_expt.model import ( + get_model, +) +from deepmd.pt_expt.train.wrapper import ( + ModelWrapper, +) +from deepmd.pt_expt.utils.env import ( + DEVICE, +) +from deepmd.pt_expt.utils.multi_task import ( + preprocess_shared_params, +) +from deepmd.pt_expt.utils.stat import ( + make_stat_input, +) +from deepmd.utils.argcheck import ( + normalize, +) +from deepmd.utils.compat import ( + update_deepmd_input, +) +from deepmd.utils.data import ( + DataRequirementItem, +) +from deepmd.utils.data_system import ( + DeepmdDataSystem, + process_systems, +) + +_energy_data_requirement = [ + DataRequirementItem("energy", ndof=1, atomic=False, must=False, high_prec=True), + DataRequirementItem("force", ndof=3, atomic=True, must=False, high_prec=False), + DataRequirementItem("virial", ndof=9, atomic=False, must=False, high_prec=False), +] + +# Paths to the water data used by PT tests +_PT_DATA = str(Path(__file__).parent.parent / "pt" / "water" / "data" / "data_0") + + +def _skip_if_no_data() -> None: + if not os.path.isdir(_PT_DATA): + raise unittest.SkipTest(f"Test data not found: {_PT_DATA}") + + +# --------------------------------------------------------------------------- +# Descriptor configs (small models for fast testing) +# --------------------------------------------------------------------------- +_descriptor_se_e2_a = { + "type": "se_e2_a", + "sel": [6, 12], + "rcut_smth": 0.50, + "rcut": 3.00, + "neuron": [8, 16], + "resnet_dt": False, + "axis_neuron": 4, + "type_one_side": True, + "seed": 1, +} + +_descriptor_dpa1 = { + "type": "se_atten", + "sel": 18, + "rcut_smth": 0.5, + 
"rcut": 3.0, + "neuron": [8, 16], + "axis_neuron": 4, + "attn": 16, + "attn_layer": 2, + "attn_dotr": True, + "attn_mask": False, + "set_davg_zero": True, + "type_one_side": True, + "seed": 1, +} + +_descriptor_dpa2 = { + "type": "dpa2", + "repinit": { + "rcut": 4.0, + "rcut_smth": 0.5, + "nsel": 18, + "neuron": [2, 4, 8], + "axis_neuron": 4, + "activation_function": "tanh", + }, + "repformer": { + "rcut": 3.0, + "rcut_smth": 0.5, + "nsel": 12, + "nlayers": 2, + "g1_dim": 8, + "g2_dim": 5, + "attn2_hidden": 3, + "attn2_nhead": 1, + "attn1_hidden": 5, + "attn1_nhead": 1, + "axis_neuron": 4, + "update_h2": False, + "update_g1_has_conv": True, + "update_g1_has_grrg": True, + "update_g1_has_drrd": True, + "update_g1_has_attn": True, + "update_g2_has_g1g1": True, + "update_g2_has_attn": True, + "attn2_has_gate": True, + }, + "seed": 1, + "add_tebd_to_repinit_out": False, +} + +_descriptor_dpa3 = { + "type": "dpa3", + "repflow": { + "n_dim": 8, + "e_dim": 5, + "a_dim": 4, + "nlayers": 2, + "e_rcut": 3.0, + "e_rcut_smth": 0.5, + "e_sel": 12, + "a_rcut": 3.0, + "a_rcut_smth": 0.5, + "a_sel": 8, + "axis_neuron": 4, + "a_compress_rate": 1, + "a_compress_e_rate": 2, + "a_compress_use_split": True, + "update_angle": True, + "update_style": "res_residual", + "update_residual": 0.1, + "update_residual_init": "const", + "smooth_edge_update": True, + }, + "activation_function": "silut:10.0", + "use_tebd_bias": False, + "precision": "float32", + "concat_output_tebd": False, +} + +_fitting_net = { + "neuron": [16, 16], + "resnet_dt": True, + "seed": 1, +} + + +def _make_multitask_config( + descriptor: dict, + data_dir: str = _PT_DATA, + numb_steps: int = 1, + share_fitting: bool = False, +) -> dict: + """Build a multi-task config with the given descriptor.""" + type_map = ["O", "H"] + fitting = deepcopy(_fitting_net) + + shared_dict: dict = { + "my_type_map": type_map, + "my_descriptor": deepcopy(descriptor), + } + + if share_fitting: + shared_fitting = deepcopy(fitting) + 
shared_fitting["dim_case_embd"] = 2 + shared_dict["my_fitting"] = shared_fitting + fitting_ref_1: dict | str = "my_fitting" + fitting_ref_2: dict | str = "my_fitting" + else: + fitting_ref_1 = deepcopy(fitting) + fitting_ref_2 = deepcopy(fitting) + + config = { + "model": { + "shared_dict": shared_dict, + "model_dict": { + "model_1": { + "type_map": "my_type_map", + "descriptor": "my_descriptor", + "fitting_net": fitting_ref_1, + "data_stat_nbatch": 1, + }, + "model_2": { + "type_map": "my_type_map", + "descriptor": "my_descriptor", + "fitting_net": fitting_ref_2, + "data_stat_nbatch": 1, + }, + }, + }, + "learning_rate": { + "type": "exp", + "decay_steps": 500, + "start_lr": 0.001, + "stop_lr": 3.51e-8, + }, + "loss_dict": { + "model_1": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "start_pref_v": 0, + "limit_pref_v": 0, + }, + "model_2": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "start_pref_v": 0, + "limit_pref_v": 0, + }, + }, + "training": { + "model_prob": { + "model_1": 0.5, + "model_2": 0.5, + }, + "data_dict": { + "model_1": { + "stat_file": "./stat_files/model_1", + "training_data": { + "systems": [data_dir], + "batch_size": 1, + }, + "validation_data": { + "systems": [data_dir], + "batch_size": 1, + "numb_btch": 1, + }, + }, + "model_2": { + "stat_file": "./stat_files/model_2", + "training_data": { + "systems": [data_dir], + "batch_size": 1, + }, + "validation_data": { + "systems": [data_dir], + "batch_size": 1, + "numb_btch": 1, + }, + }, + }, + "numb_steps": numb_steps, + "seed": 10, + "disp_file": "lcurve.out", + "disp_freq": 1, + "save_freq": numb_steps, + }, + } + return config + + +class MultiTaskTrainTest: + """Mixin that tests multi-task training for a particular descriptor type. + + Subclasses must set ``self.config``, ``self.shared_links``, + and ``self.share_fitting`` before calling these test methods. 
+ """ + + def test_multitask_train(self) -> None: + """Train, verify lcurve format and shared params.""" + trainer = get_trainer(deepcopy(self.config), shared_links=self.shared_links) + trainer.run() + + # --- lcurve.out format --- + lcurve_path = "lcurve.out" + self.assertTrue(os.path.exists(lcurve_path), "lcurve.out not created") + with open(lcurve_path) as f: + lines = f.readlines() + header_line = lines[0] + header_cols = header_line.strip().lstrip("#").split() + model_keys = list(self.config["training"]["model_prob"].keys()) + for mk in model_keys: + cols_for_model = [c for c in header_cols if mk in c] + self.assertGreater( + len(cols_for_model), 0, f"No lcurve columns found for {mk}" + ) + data_lines = [line for line in lines if not line.startswith("#")] + self.assertGreater(len(data_lines), 0, "No data lines in lcurve.out") + data_cols = data_lines[0].split() + self.assertEqual(len(data_cols), len(header_cols)) + + # --- model keys --- + self.assertEqual(len(trainer.wrapper.model), 2) + self.assertIn("model_1", trainer.wrapper.model) + self.assertIn("model_2", trainer.wrapper.model) + + # --- shared descriptor params are identical --- + multi_state_dict = trainer.wrapper.model.state_dict() + for state_key in multi_state_dict: + if "model_1" in state_key: + partner_key = state_key.replace("model_1", "model_2") + self.assertIn(partner_key, multi_state_dict) + if "model_2" in state_key: + partner_key = state_key.replace("model_2", "model_1") + self.assertIn(partner_key, multi_state_dict) + + is_descriptor = "model_1.atomic_model.descriptor" in state_key + is_shared_fitting = ( + self.share_fitting + and "model_1.atomic_model.fitting_net" in state_key + and "fitting_net.bias_atom_e" not in state_key + and "fitting_net.case_embd" not in state_key + ) + if is_descriptor or is_shared_fitting: + partner_key = state_key.replace("model_1", "model_2") + torch.testing.assert_close( + multi_state_dict[state_key], + multi_state_dict[partner_key], + msg=f"Shared param 
mismatch: {state_key}", + ) + + # --- checkpoint exists --- + ckpt_files = [f for f in os.listdir(".") if f.endswith(".pt")] + self.assertGreater(len(ckpt_files), 0, "No checkpoint files saved") + + # --- case_embd verification (share_fitting only) --- + # Verify that each branch's case_embd is a distinct one-hot vector + # matching the alphabetical sort order, so the shared fitting net + # can distinguish which training dataset is being used. + if self.share_fitting: + ce1 = trainer.wrapper.model["model_1"].atomic_model.fitting_net.case_embd + ce2 = trainer.wrapper.model["model_2"].atomic_model.fitting_net.case_embd + self.assertIsNotNone(ce1, "case_embd not set on model_1") + self.assertIsNotNone(ce2, "case_embd not set on model_2") + dim = ce1.shape[0] + # Sorted keys: ["model_1", "model_2"] → indices 0, 1 + expected_eye = torch.eye(dim, dtype=ce1.dtype, device=ce1.device) + torch.testing.assert_close( + ce1, + expected_eye[0], + msg="model_1 case_embd should be one-hot index 0 (alphabetical order)", + ) + torch.testing.assert_close( + ce2, + expected_eye[1], + msg="model_2 case_embd should be one-hot index 1 (alphabetical order)", + ) + # case_embd should NOT be shared in state_dict + for state_key in multi_state_dict: + if ( + "model_1.atomic_model.fitting_net" in state_key + and "case_embd" in state_key + ): + partner_key = state_key.replace("model_1", "model_2") + self.assertFalse( + torch.equal( + multi_state_dict[state_key], + multi_state_dict[partner_key], + ), + f"case_embd should NOT be shared: {state_key}", + ) + + def test_multitask_finetune(self) -> None: + """Train, then finetune with 4 branches from pretrained 2-branch model. + + For mixed_types descriptors, uses extended type_map ["O","H","B"] to test + change_type_map + model_with_new_type_stat integration. For non-mixed_types + descriptors, uses same type_map ["O","H"]. 
+ + Builds a reference state_dict by manually replicating the trainer's + finetune operations (load pretrained, change_type_map, weight copy) and + verifies per-branch weight inheritance: + - model_1: resume (ALL weights match reference) + - model_2: finetune from model_2 (all except out_bias/out_std match) + - model_3: finetune from model_2 as new head (cross-branch key remap) + - model_4: random fitting (descriptor from pretrained, random fitting_net) + """ + from deepmd.pt_expt.utils.finetune import ( + get_finetune_rules, + ) + + # Phase 1: train pretrained 2-branch model (2 steps) + config_pretrain = _make_multitask_config( + self.descriptor, share_fitting=self.share_fitting, numb_steps=2 + ) + config_pretrain["training"]["save_freq"] = 2 + config_pretrain["model"], shared_links_pre = preprocess_shared_params( + config_pretrain["model"] + ) + config_pretrain = update_deepmd_input(config_pretrain, warning=False) + config_pretrain = normalize(config_pretrain, multi_task=True) + trainer = get_trainer(config_pretrain, shared_links=shared_links_pre) + trainer.run() + + ckpt_path = os.path.join(os.getcwd(), "model.ckpt.pt") + self.assertTrue(os.path.exists(ckpt_path), "Pretrained checkpoint not created") + + # Phase 2: build reference state_dict + # For mixed_types: extend type_map to ["O","H","B"], build + # model_with_new_type_stat with computed stats, and apply + # change_type_map on pretrained. + # For non-mixed_types: use pretrained state directly (no extension). 
+ ft_type_map = ["O", "H", "B"] if self.mixed_types else ["O", "H"] + + state_dict_full = torch.load(ckpt_path, map_location=DEVICE, weights_only=True) + state_dict_ckpt = ( + state_dict_full["model"] if "model" in state_dict_full else state_dict_full + ) + pretrained_model_params = state_dict_ckpt["_extra_state"]["model_params"] + + # Build pretrained wrapper (separate model per branch) + pretrained_models = {} + for pk in pretrained_model_params["model_dict"]: + pretrained_models[pk] = get_model( + deepcopy(pretrained_model_params["model_dict"][pk]) + ).to(DEVICE) + pretrained_wrapper = ModelWrapper(pretrained_models) + pretrained_wrapper.load_state_dict(state_dict_ckpt) + + # Record pretrained state BEFORE change_type_map — used later to + # verify O/H stats are inherited from pretrained, not recomputed. + pretrained_oh_state = { + k: v.clone() for k, v in pretrained_wrapper.model.state_dict().items() + } + + if self.mixed_types: + # Build a model with extended type_map and compute stats so that + # the new type ("B", unseen in data) gets proper default stats + # (davg=0, dstd=0.1) instead of the no-stat defaults (0/1). 
+ ref_model_params = deepcopy( + pretrained_model_params["model_dict"]["model_1"] + ) + ref_model_params["type_map"] = ft_type_map + ref_model = get_model(ref_model_params).to(DEVICE) + + data_systems = process_systems([_PT_DATA]) + data = DeepmdDataSystem( + systems=data_systems, + batch_size=1, + test_size=1, + type_map=ft_type_map, + trn_all_set=True, + ) + data.add_data_requirements(_energy_data_requirement) + ref_model.compute_or_load_stat( + sampled_func=lambda: make_stat_input(data, 1), + stat_file_path=None, + ) + + # Apply change_type_map on each pretrained branch + for pk in pretrained_model_params["model_dict"]: + pretrained_wrapper.model[pk].change_type_map( + ft_type_map, + model_with_new_type_stat=ref_model, + ) + + ref_state_dict = pretrained_wrapper.model.state_dict() + + # Phase 3: build 4-branch finetune config + finetune_config = _make_multitask_config( + self.descriptor, share_fitting=self.share_fitting + ) + if self.mixed_types: + finetune_config["model"]["shared_dict"]["my_type_map"] = ft_type_map + + # Add model_3 and model_4 (copies of model_2) + finetune_config["model"]["model_dict"]["model_3"] = deepcopy( + finetune_config["model"]["model_dict"]["model_2"] + ) + finetune_config["model"]["model_dict"]["model_4"] = deepcopy( + finetune_config["model"]["model_dict"]["model_2"] + ) + finetune_config["loss_dict"]["model_3"] = deepcopy( + finetune_config["loss_dict"]["model_2"] + ) + finetune_config["loss_dict"]["model_4"] = deepcopy( + finetune_config["loss_dict"]["model_2"] + ) + finetune_config["training"]["model_prob"]["model_3"] = deepcopy( + finetune_config["training"]["model_prob"]["model_2"] + ) + finetune_config["training"]["model_prob"]["model_4"] = deepcopy( + finetune_config["training"]["model_prob"]["model_2"] + ) + finetune_config["training"]["data_dict"]["model_3"] = deepcopy( + finetune_config["training"]["data_dict"]["model_2"] + ) + finetune_config["training"]["data_dict"]["model_3"]["stat_file"] = ( + 
finetune_config["training"]["data_dict"]["model_3"]["stat_file"].replace( + "model_2", "model_3" + ) + ) + finetune_config["training"]["data_dict"]["model_4"] = deepcopy( + finetune_config["training"]["data_dict"]["model_2"] + ) + finetune_config["training"]["data_dict"]["model_4"]["stat_file"] = ( + finetune_config["training"]["data_dict"]["model_4"]["stat_file"].replace( + "model_2", "model_4" + ) + ) + + # Set finetune rules: + # model_1: no finetune_head → resume from model_1 (resuming=True) + # model_2: finetune_head="model_2" → finetune from model_2 + finetune_config["model"]["model_dict"]["model_2"]["finetune_head"] = "model_2" + # model_3: finetune_head="model_2" → finetune from model_2 (new head) + finetune_config["model"]["model_dict"]["model_3"]["finetune_head"] = "model_2" + # model_4: no finetune_head, new name → random fitting + + finetune_config["training"]["numb_steps"] = 1 + finetune_config["training"]["save_freq"] = 1 + + finetune_config["model"], shared_links_ft = preprocess_shared_params( + finetune_config["model"] + ) + finetune_config["model"], finetune_links = get_finetune_rules( + ckpt_path, finetune_config["model"] + ) + finetune_config = update_deepmd_input(finetune_config, warning=False) + finetune_config = normalize(finetune_config, multi_task=True) + + trainer_ft = get_trainer( + deepcopy(finetune_config), + finetune_model=ckpt_path, + shared_links=shared_links_ft, + finetune_links=finetune_links, + ) + + # Phase 4: verify weight inheritance against reference + ft_state_dict = trainer_ft.wrapper.model.state_dict() + + # When type_map is extended, type_embedding weights for the new type + # are randomly initialized (np.random.default_rng) during + # change_type_map; since reference and trainer build separate + # pretrained wrappers, these random values differ — skip them. 
+ _skip_type_embed = self.mixed_types + + for state_key in ft_state_dict: + if _skip_type_embed and "type_embedding" in state_key: + continue + if "model_1" in state_key: + # model_1: resume — ALL weights match reference model_1 + torch.testing.assert_close( + ref_state_dict[state_key], + ft_state_dict[state_key], + msg=f"model_1 (resume) weight mismatch: {state_key}", + ) + elif ( + "model_2" in state_key + and "out_bias" not in state_key + and "out_std" not in state_key + ): + # model_2: finetune — all except out_bias/out_std + torch.testing.assert_close( + ref_state_dict[state_key], + ft_state_dict[state_key], + msg=f"model_2 (finetune) weight mismatch: {state_key}", + ) + elif ( + "model_3" in state_key + and "out_bias" not in state_key + and "out_std" not in state_key + ): + # model_3: finetune from model_2 — cross-branch key remap + ref_key = state_key.replace("model_3", "model_2") + torch.testing.assert_close( + ref_state_dict[ref_key], + ft_state_dict[state_key], + msg=f"model_3 (finetune from model_2) weight mismatch: {state_key}", + ) + elif ( + "model_4" in state_key + and "fitting_net" not in state_key + and "out_bias" not in state_key + and "out_std" not in state_key + ): + # model_4: random fitting — descriptor from pretrained + # (RANDOM + from_multitask uses first pretrained key = model_1; + # since descriptors are shared, model_1 == model_2 in pretrained) + ref_key = state_key.replace("model_4", "model_2") + torch.testing.assert_close( + ref_state_dict[ref_key], + ft_state_dict[state_key], + msg=f"model_4 (random fitting) descriptor mismatch: {state_key}", + ) + + # Phase 5: verify O/H descriptor stats are inherited from pretrained + # (not recomputed from finetune data). + # For mixed_types: pretrained has shape [2,...] (O,H); finetuned has + # shape [3,...] (O,H,B). The first 2 entries must match pretrained. + # For non-mixed_types: shapes are identical, already fully checked above. 
+ _STAT_SUFFIXES = ("mean", "stddev", "davg", "dstd") + if self.mixed_types: + n_old = len(["O", "H"]) + n_new = len(ft_type_map) + checked_count = 0 + for key in ft_state_dict: + if not any(key.endswith(s) for s in _STAT_SUFFIXES): + continue + # Use model_1 (all branches share descriptor after share_params) + if "model_1" not in key: + continue + pre_key = key # same key in pretrained_oh_state + if pre_key not in pretrained_oh_state: + continue + pre_val = pretrained_oh_state[pre_key] + ft_val = ft_state_dict[key] + # Find the type axis (size grew from n_old to n_new) + for ax in range(pre_val.ndim): + if pre_val.shape[ax] == n_old and ft_val.shape[ax] == n_new: + for ti, tname in enumerate(["O", "H"]): + torch.testing.assert_close( + ft_val.select(ax, ti), + pre_val.select(ax, ti), + msg=( + f"{tname} stat not inherited from pretrained: {key}" + ), + ) + checked_count += 1 + break + self.assertGreater( + checked_count, + 0, + "No descriptor stat keys found for O/H inheritance check", + ) + + # Phase 6: verify case_embd inheritance (share_fitting only) + # Pretrained branches keep their case_embd (dataset correspondence). + # New branches (model_3 finetune from model_2, model_4 random) get + # case_embd from the weight copy: model_3 copies model_2's, model_4 + # keeps target default (zeros since set_case_embd is skipped on finetune). 
+ if self.share_fitting: + + def _get_case_embd(mk): + return trainer_ft.wrapper.model[mk].atomic_model.fitting_net.case_embd + + ce1 = _get_case_embd("model_1") + ce2 = _get_case_embd("model_2") + ce3 = _get_case_embd("model_3") + ce4 = _get_case_embd("model_4") + # Pretrained had sorted keys ["model_1","model_2"] → one-hot [1,0], [0,1] + dim = ce1.shape[0] + expected_eye = torch.eye(dim, dtype=ce1.dtype, device=ce1.device) + # model_1 (resume): inherits pretrained model_1's case_embd + torch.testing.assert_close( + ce1, + expected_eye[0], + msg="model_1 case_embd should match pretrained model_1", + ) + # model_2 (finetune from model_2): inherits pretrained model_2's case_embd + torch.testing.assert_close( + ce2, + expected_eye[1], + msg="model_2 case_embd should match pretrained model_2", + ) + # model_3 (finetune from model_2): weight copy from model_2 + torch.testing.assert_close( + ce3, + expected_eye[1], + msg="model_3 case_embd should match pretrained model_2 (finetune source)", + ) + # model_4 (random fitting): target default (zeros, set_case_embd skipped) + torch.testing.assert_close( + ce4, + torch.zeros_like(ce4), + msg="model_4 case_embd should be zeros (random fitting, no re-init on finetune)", + ) + + # Run 1 step to verify no crash + trainer_ft.run() + + def test_multitask_finetune_from_single_task(self) -> None: + """Finetune multi-task model from a single-task pretrained .pt checkpoint. + + Tests the single-task pretrained → multi-task finetune path + (finetune_from_multi_task=False, training.py:714-721). + + model_1: finetune_head="Default" → copies from single-task pretrained + model_2: no finetune_head, not in pretrained_keys=["Default"] → RANDOM fitting + """ + if self.share_fitting: + # Single-task pretrained has no dim_case_embd; incompatible with + # shared fitting multi-task target. 
+ return + + from deepmd.pt_expt.utils.finetune import ( + get_finetune_rules, + ) + + # Phase 1: train single-task model (2 steps) + single_config = { + "model": { + "type_map": ["O", "H"], + "descriptor": deepcopy(self.descriptor), + "fitting_net": deepcopy(_fitting_net), + "data_stat_nbatch": 1, + }, + "learning_rate": { + "type": "exp", + "decay_steps": 500, + "start_lr": 0.001, + "stop_lr": 3.51e-8, + }, + "loss": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "start_pref_v": 0, + "limit_pref_v": 0, + }, + "training": { + "training_data": {"systems": [_PT_DATA], "batch_size": 1}, + "validation_data": { + "systems": [_PT_DATA], + "batch_size": 1, + "numb_btch": 1, + }, + "numb_steps": 2, + "seed": 10, + "disp_file": "lcurve.out", + "disp_freq": 1, + "save_freq": 2, + }, + } + single_config = update_deepmd_input(single_config, warning=False) + single_config = normalize(single_config, multi_task=False) + trainer_st = get_trainer(single_config) + trainer_st.run() + + ckpt_path = os.path.join(os.getcwd(), "model.ckpt.pt") + self.assertTrue(os.path.exists(ckpt_path), "Single-task checkpoint not created") + + # Phase 2: build reference state_dict from single-task checkpoint + state_dict_full = torch.load(ckpt_path, map_location=DEVICE, weights_only=True) + state_dict_ckpt = ( + state_dict_full["model"] if "model" in state_dict_full else state_dict_full + ) + pretrained_model_params = state_dict_ckpt["_extra_state"]["model_params"] + + # Single-task pretrained → wrap as {"Default": model} + ref_model = get_model(deepcopy(pretrained_model_params)).to(DEVICE) + pretrained_wrapper = ModelWrapper(ref_model) + pretrained_wrapper.load_state_dict(state_dict_ckpt) + ref_state_dict = pretrained_wrapper.model.state_dict() + + # Phase 3: build 2-branch multi-task finetune config + finetune_config = _make_multitask_config( + self.descriptor, share_fitting=self.share_fitting + ) + # model_1: finetune_head="Default" → 
copy from single-task + finetune_config["model"]["model_dict"]["model_1"]["finetune_head"] = "Default" + # model_2: no finetune_head, "model_2" not in pretrained_keys=["Default"] → RANDOM + finetune_config["training"]["numb_steps"] = 1 + finetune_config["training"]["save_freq"] = 1 + + finetune_config["model"], shared_links_ft = preprocess_shared_params( + finetune_config["model"] + ) + finetune_config["model"], finetune_links = get_finetune_rules( + ckpt_path, finetune_config["model"] + ) + finetune_config = update_deepmd_input(finetune_config, warning=False) + finetune_config = normalize(finetune_config, multi_task=True) + + trainer_ft = get_trainer( + deepcopy(finetune_config), + finetune_model=ckpt_path, + shared_links=shared_links_ft, + finetune_links=finetune_links, + ) + + # Phase 4: verify weight inheritance + ft_state_dict = trainer_ft.wrapper.model.state_dict() + + for state_key in ft_state_dict: + if "model_1" in state_key: + # model_1: finetune from "Default" — all except out_bias/out_std + if "out_bias" in state_key or "out_std" in state_key: + continue + ref_key = state_key.replace("model_1", "Default") + self.assertIn(ref_key, ref_state_dict, f"Missing ref key: {ref_key}") + torch.testing.assert_close( + ref_state_dict[ref_key], + ft_state_dict[state_key], + msg=f"model_1 (from Default) weight mismatch: {state_key}", + ) + elif "model_2" in state_key: + if "out_bias" in state_key or "out_std" in state_key: + continue + ref_key = state_key.replace("model_2", "Default") + if ".descriptor." 
in state_key: + # Descriptor from pretrained (RANDOM uses first pretrained key) + self.assertIn( + ref_key, ref_state_dict, f"Missing ref key: {ref_key}" + ) + torch.testing.assert_close( + ref_state_dict[ref_key], + ft_state_dict[state_key], + msg=f"model_2 (RANDOM) descriptor mismatch: {state_key}", + ) + + # model_2 fitting NN weights (networks.*) should differ (random init) + fitting_nn_mismatch = 0 + for state_key in ft_state_dict: + if ( + "model_2" in state_key + and ".fitting_net." in state_key + and "networks" in state_key + ): + ref_key = state_key.replace("model_2", "Default") + if ref_key in ref_state_dict and not torch.equal( + ref_state_dict[ref_key], ft_state_dict[state_key] + ): + fitting_nn_mismatch += 1 + self.assertGreater( + fitting_nn_mismatch, + 0, + "model_2 fitting NN weights should differ from pretrained (random init)", + ) + + # Phase 5: run 1 step to verify no crash + trainer_ft.run() + + def test_multitask_finetune_no_change_model_params(self) -> None: + """Test change_model_params=False preserves user config in multi-task finetune. + + Contrasts with change_model_params=True which overwrites descriptor/fitting + from pretrained (preserving trainable flags). 
+ """ + from deepmd.pt_expt.utils.finetune import ( + get_finetune_rules, + ) + + # Phase 1: train 2-branch multi-task model (2 steps) + config_pretrain = _make_multitask_config( + self.descriptor, share_fitting=self.share_fitting, numb_steps=2 + ) + config_pretrain["training"]["save_freq"] = 2 + config_pretrain["model"], shared_links_pre = preprocess_shared_params( + config_pretrain["model"] + ) + config_pretrain = update_deepmd_input(config_pretrain, warning=False) + config_pretrain = normalize(config_pretrain, multi_task=True) + trainer = get_trainer(config_pretrain, shared_links=shared_links_pre) + trainer.run() + + ckpt_path = os.path.join(os.getcwd(), "model.ckpt.pt") + self.assertTrue(os.path.exists(ckpt_path), "Pretrained checkpoint not created") + + # Phase 2: build finetune config with marker in descriptor + ft_config = _make_multitask_config( + self.descriptor, share_fitting=self.share_fitting + ) + # model_1: no finetune_head → resume (model_1 in pretrained_keys) + # model_2: finetune_head="model_2" → finetune + ft_config["model"]["model_dict"]["model_2"]["finetune_head"] = "model_2" + ft_config["training"]["numb_steps"] = 1 + ft_config["training"]["save_freq"] = 1 + + # Add markers to descriptor in each branch (before preprocess_shared_params + # resolves shared_dict references) + ft_config["model"]["shared_dict"]["my_descriptor"]["_test_marker"] = True + + # Phase 3: test change_model_params=False + ft_config_false = deepcopy(ft_config) + ft_config_false["model"], _ = preprocess_shared_params(ft_config_false["model"]) + model_config_false, finetune_links_false = get_finetune_rules( + ckpt_path, deepcopy(ft_config_false["model"]), change_model_params=False + ) + + # User config preserved: marker still present + self.assertTrue( + model_config_false["model_dict"]["model_1"]["descriptor"].get( + "_test_marker", False + ), + "model_1 descriptor should preserve _test_marker with change_model_params=False", + ) + self.assertTrue( + 
model_config_false["model_dict"]["model_2"]["descriptor"].get( + "_test_marker", False + ), + "model_2 descriptor should preserve _test_marker with change_model_params=False", + ) + # FinetuneRuleItem has correct type_map + for mk in ("model_1", "model_2"): + self.assertEqual( + finetune_links_false[mk].get_finetune_tmap(), + ["O", "H"], + f"{mk} finetune tmap should be ['O','H']", + ) + # model_1 is resuming, model_2 is not + self.assertTrue( + finetune_links_false["model_1"].resuming, + "model_1 should be resuming (no finetune_head, name in pretrained_keys)", + ) + self.assertFalse( + finetune_links_false["model_2"].resuming, + "model_2 should not be resuming (has finetune_head)", + ) + + # Phase 4: test change_model_params=True (contrast) + ft_config_true = deepcopy(ft_config) + # Also set trainable=False to verify it's preserved + ft_config_true["model"]["shared_dict"]["my_descriptor"]["trainable"] = False + ft_config_true["model"], _ = preprocess_shared_params(ft_config_true["model"]) + model_config_true, _finetune_links_true = get_finetune_rules( + ckpt_path, deepcopy(ft_config_true["model"]), change_model_params=True + ) + + # Marker overwritten from pretrained + self.assertFalse( + model_config_true["model_dict"]["model_1"]["descriptor"].get( + "_test_marker", False + ), + "model_1 descriptor should NOT have _test_marker with change_model_params=True", + ) + self.assertFalse( + model_config_true["model_dict"]["model_2"]["descriptor"].get( + "_test_marker", False + ), + "model_2 descriptor should NOT have _test_marker with change_model_params=True", + ) + # trainable=False should be preserved + self.assertFalse( + model_config_true["model_dict"]["model_1"]["descriptor"].get( + "trainable", True + ), + "model_1 descriptor trainable should be preserved as False", + ) + self.assertFalse( + model_config_true["model_dict"]["model_2"]["descriptor"].get( + "trainable", True + ), + "model_2 descriptor trainable should be preserved as False", + ) + + # Phase 5: build 
trainer with change_model_params=False → run 1 step + ft_config_run = deepcopy(ft_config) + ft_config_run["model"], shared_links_ft = preprocess_shared_params( + ft_config_run["model"] + ) + ft_config_run["model"], finetune_links_run = get_finetune_rules( + ckpt_path, ft_config_run["model"], change_model_params=False + ) + ft_config_run = update_deepmd_input(ft_config_run, warning=False) + ft_config_run = normalize(ft_config_run, multi_task=True) + trainer_ft = get_trainer( + deepcopy(ft_config_run), + finetune_model=ckpt_path, + shared_links=shared_links_ft, + finetune_links=finetune_links_run, + ) + trainer_ft.run() + + def test_change_type_map_stat(self) -> None: + """Validate change_type_map preserves existing types' stats. + + Tests two modes: + 1. WITHOUT model_with_new_type_stat: existing types preserved, + new type gets default values (zeros for davg/bias, ones for dstd/std). + 2. WITH model_with_new_type_stat: existing types preserved, + new type gets data-computed values (davg=0, dstd=0.1 for zero + observations via StatItem defaults). 
+ """ + if not self.mixed_types: + return + + old_tmap = ["O", "H"] + new_tmap = ["O", "H", "B"] + + model_config = deepcopy(self.config["model"]["model_dict"]["model_1"]) + + # Build model with old type_map and compute stats + model = get_model(deepcopy(model_config)).to(DEVICE) + data_systems = process_systems([_PT_DATA]) + data = DeepmdDataSystem( + systems=data_systems, + batch_size=1, + test_size=1, + type_map=old_tmap, + trn_all_set=True, + ) + data.add_data_requirements(_energy_data_requirement) + model.compute_or_load_stat( + sampled_func=lambda: make_stat_input(data, 1), + stat_file_path=None, + ) + sd_before = {k: v.clone() for k, v in model.state_dict().items()} + + # ---- Test 1: change_type_map WITHOUT model_with_new_type_stat ---- + model.change_type_map(new_tmap, model_with_new_type_stat=None) + sd_no_stat = model.state_dict() + + # Stat-like keys: descriptor mean/stddev/davg/dstd and atomic out_bias/out_std + _STAT_SUFFIXES = ("mean", "stddev", "davg", "dstd", "out_bias", "out_std") + + def _is_stat_key(k: str) -> bool: + return any(k.endswith(s) for s in _STAT_SUFFIXES) + + def _is_std_like(k: str) -> bool: + return k.endswith(("stddev", "dstd", "out_std")) + + for key in sd_no_stat: + if key not in sd_before or not _is_stat_key(key): + continue + old_val = sd_before[key] + new_val = sd_no_stat[key] + if old_val.shape == new_val.shape: + continue + # Find the type axis: size went from len(old_tmap) to len(new_tmap) + for ax in range(old_val.ndim): + if old_val.shape[ax] == len(old_tmap) and new_val.shape[ax] == len( + new_tmap + ): + # Existing types preserved + torch.testing.assert_close( + new_val.select(ax, 0), + old_val.select(ax, 0), + msg=f"O stat changed (no model_with_new_type_stat): {key}", + ) + torch.testing.assert_close( + new_val.select(ax, 1), + old_val.select(ax, 1), + msg=f"H stat changed (no model_with_new_type_stat): {key}", + ) + # New type B: defaults (zeros for mean/davg/bias, ones for std) + new_B = new_val.select(ax, 2) + if 
_is_std_like(key): + torch.testing.assert_close( + new_B, + torch.ones_like(new_B), + msg=f"B default should be ones: {key}", + ) + else: + torch.testing.assert_close( + new_B, + torch.zeros_like(new_B), + msg=f"B default should be zeros: {key}", + ) + break + + # ---- Test 2: change_type_map WITH model_with_new_type_stat ---- + # Build fresh model with old type_map + model2 = get_model(deepcopy(model_config)).to(DEVICE) + model2.compute_or_load_stat( + sampled_func=lambda: make_stat_input(data, 1), + stat_file_path=None, + ) + sd_before2 = {k: v.clone() for k, v in model2.state_dict().items()} + + # Build model_with_new_type_stat with extended type_map + model_ext_config = deepcopy(model_config) + model_ext_config["type_map"] = new_tmap + model_ext = get_model(model_ext_config).to(DEVICE) + data_ext = DeepmdDataSystem( + systems=data_systems, + batch_size=1, + test_size=1, + type_map=new_tmap, + trn_all_set=True, + ) + data_ext.add_data_requirements(_energy_data_requirement) + model_ext.compute_or_load_stat( + sampled_func=lambda: make_stat_input(data_ext, 1), + stat_file_path=None, + ) + + model2.change_type_map(new_tmap, model_with_new_type_stat=model_ext) + sd_with_stat = model2.state_dict() + + for key in sd_with_stat: + if key not in sd_before2 or not _is_stat_key(key): + continue + old_val = sd_before2[key] + new_val = sd_with_stat[key] + if old_val.shape == new_val.shape: + continue + for ax in range(old_val.ndim): + if old_val.shape[ax] == len(old_tmap) and new_val.shape[ax] == len( + new_tmap + ): + # Existing types preserved + torch.testing.assert_close( + new_val.select(ax, 0), + old_val.select(ax, 0), + msg=f"O stat changed (with model_with_new_type_stat): {key}", + ) + torch.testing.assert_close( + new_val.select(ax, 1), + old_val.select(ax, 1), + msg=f"H stat changed (with model_with_new_type_stat): {key}", + ) + # New type B: descriptor stats should use model_ext's + # computed values, NOT the no-stat defaults (ones) + new_B = new_val.select(ax, 2) 
+ is_descrpt_std = key.endswith(("stddev", "dstd")) + if is_descrpt_std: + # B has zero observations → StatItem default = 0.1 + # (not ones like the no-stat default) + self.assertFalse( + torch.allclose(new_B, torch.ones_like(new_B)), + f"B descriptor stat should NOT be ones " + f"(should be 0.1 from StatItem default): {key}", + ) + break + + def test_multitask_restart(self) -> None: + """Train, then restart from checkpoint and verify.""" + # Phase 1: train + config1 = deepcopy(self.config) + config1["training"]["numb_steps"] = 2 + config1["training"]["save_freq"] = 2 + trainer1 = get_trainer(config1, shared_links=self.shared_links) + trainer1.run() + + ckpt_path = "model.ckpt.pt" + self.assertTrue(os.path.exists(ckpt_path), "Checkpoint not created") + + # Phase 2: restart to step 4 + config2 = deepcopy(self.config) + config2["training"]["numb_steps"] = 4 + config2["training"]["save_freq"] = 4 + trainer2 = get_trainer( + config2, + restart_model=ckpt_path, + shared_links=self.shared_links, + ) + self.assertEqual(trainer2.start_step, 2) + trainer2.run() + + def test_multitask_freeze(self) -> None: + """Train, then freeze with --head and verify. + + Only runs for dpa3 descriptor to avoid redundant slow freeze tests. 
+ """ + if self.descriptor.get("type") != "dpa3": + return + + from deepmd.pt_expt.entrypoints.main import ( + freeze, + ) + + # Train + config = deepcopy(self.config) + trainer = get_trainer(config, shared_links=self.shared_links) + trainer.run() + + # Freeze head model_1 + ckpt_path = "model.ckpt.pt" + output_path = "frozen_model_1.pte" + freeze(model=ckpt_path, output=output_path, head="model_1") + self.assertTrue(os.path.exists(output_path), "Frozen model not created") + + # Verify frozen model loads + from deepmd.pt_expt.model import ( + BaseModel, + ) + from deepmd.pt_expt.utils.serialization import ( + serialize_from_file, + ) + + data = serialize_from_file(output_path) + self.assertIn("model", data) + frozen_model = BaseModel.deserialize(data["model"]) + self.assertIsInstance(frozen_model, torch.nn.Module) + + def test_multitask_freeze_no_head_raises(self) -> None: + """Freezing multi-task model without --head raises ValueError. + + Only runs for dpa3 descriptor to avoid redundant slow freeze tests. + """ + if self.descriptor.get("type") != "dpa3": + return + + from deepmd.pt_expt.entrypoints.main import ( + freeze, + ) + + config = deepcopy(self.config) + trainer = get_trainer(config, shared_links=self.shared_links) + trainer.run() + + ckpt_path = "model.ckpt.pt" + with self.assertRaises(ValueError, msg="Should require --head"): + freeze(model=ckpt_path, output="frozen.pte", head=None) + + def test_multitask_freeze_invalid_head_raises(self) -> None: + """Freezing multi-task model with invalid --head raises ValueError. + + Only runs for dpa3 descriptor to avoid redundant slow freeze tests. 
+ """ + if self.descriptor.get("type") != "dpa3": + return + + from deepmd.pt_expt.entrypoints.main import ( + freeze, + ) + + config = deepcopy(self.config) + trainer = get_trainer(config, shared_links=self.shared_links) + trainer.run() + + ckpt_path = "model.ckpt.pt" + with self.assertRaises(ValueError, msg="Should reject invalid head"): + freeze(model=ckpt_path, output="frozen.pte", head="nonexistent") + + def tearDown(self) -> None: + for f in os.listdir("."): + if f.startswith("model") and f.endswith(".pt"): + os.remove(f) + if f == "lcurve.out": + os.remove(f) + if f.endswith(".pte"): + os.remove(f) + if os.path.isdir("stat_files"): + shutil.rmtree("stat_files") + + +class TestMultiTaskSeA(unittest.TestCase, MultiTaskTrainTest): + """Multi-task training with se_e2_a descriptor.""" + + @classmethod + def setUpClass(cls) -> None: + _skip_if_no_data() + + def setUp(self) -> None: + self.tmpdir = tempfile.mkdtemp(prefix="pt_expt_mt_sea_") + self._old_cwd = os.getcwd() + os.chdir(self.tmpdir) + + self.descriptor = _descriptor_se_e2_a + config = _make_multitask_config(self.descriptor, share_fitting=False) + config["model"], self.shared_links = preprocess_shared_params(config["model"]) + config = update_deepmd_input(config, warning=False) + config = normalize(config, multi_task=True) + self.config = config + self.share_fitting = False + self.mixed_types = False + + def tearDown(self) -> None: + os.chdir(self._old_cwd) + shutil.rmtree(self.tmpdir, ignore_errors=True) + + +class TestMultiTaskSeAShareFit(unittest.TestCase, MultiTaskTrainTest): + """Multi-task training with se_e2_a descriptor and shared fitting_net.""" + + @classmethod + def setUpClass(cls) -> None: + _skip_if_no_data() + + def setUp(self) -> None: + self.tmpdir = tempfile.mkdtemp(prefix="pt_expt_mt_sea_sf_") + self._old_cwd = os.getcwd() + os.chdir(self.tmpdir) + + self.descriptor = _descriptor_se_e2_a + config = _make_multitask_config(self.descriptor, share_fitting=True) + config["model"], 
self.shared_links = preprocess_shared_params(config["model"]) + config = update_deepmd_input(config, warning=False) + config = normalize(config, multi_task=True) + self.config = config + self.share_fitting = True + self.mixed_types = False + + def tearDown(self) -> None: + os.chdir(self._old_cwd) + shutil.rmtree(self.tmpdir, ignore_errors=True) + + +class TestMultiTaskDPA1(unittest.TestCase, MultiTaskTrainTest): + """Multi-task training with DPA1 (se_atten) descriptor.""" + + @classmethod + def setUpClass(cls) -> None: + _skip_if_no_data() + + def setUp(self) -> None: + self.tmpdir = tempfile.mkdtemp(prefix="pt_expt_mt_dpa1_") + self._old_cwd = os.getcwd() + os.chdir(self.tmpdir) + + self.descriptor = _descriptor_dpa1 + config = _make_multitask_config(self.descriptor, share_fitting=False) + config["model"], self.shared_links = preprocess_shared_params(config["model"]) + config = update_deepmd_input(config, warning=False) + config = normalize(config, multi_task=True) + self.config = config + self.share_fitting = False + self.mixed_types = True + + def tearDown(self) -> None: + os.chdir(self._old_cwd) + shutil.rmtree(self.tmpdir, ignore_errors=True) + + +class TestMultiTaskDPA1ShareFit(unittest.TestCase, MultiTaskTrainTest): + """Multi-task training with DPA1 descriptor and shared fitting_net.""" + + @classmethod + def setUpClass(cls) -> None: + _skip_if_no_data() + + def setUp(self) -> None: + self.tmpdir = tempfile.mkdtemp(prefix="pt_expt_mt_dpa1_sf_") + self._old_cwd = os.getcwd() + os.chdir(self.tmpdir) + + self.descriptor = _descriptor_dpa1 + config = _make_multitask_config(self.descriptor, share_fitting=True) + config["model"], self.shared_links = preprocess_shared_params(config["model"]) + config = update_deepmd_input(config, warning=False) + config = normalize(config, multi_task=True) + self.config = config + self.share_fitting = True + self.mixed_types = True + + def tearDown(self) -> None: + os.chdir(self._old_cwd) + shutil.rmtree(self.tmpdir, 
ignore_errors=True) + + +class TestMultiTaskDPA2(unittest.TestCase, MultiTaskTrainTest): + """Multi-task training with DPA2 descriptor.""" + + @classmethod + def setUpClass(cls) -> None: + _skip_if_no_data() + + def setUp(self) -> None: + self.tmpdir = tempfile.mkdtemp(prefix="pt_expt_mt_dpa2_") + self._old_cwd = os.getcwd() + os.chdir(self.tmpdir) + + self.descriptor = _descriptor_dpa2 + config = _make_multitask_config(self.descriptor, share_fitting=False) + config["model"], self.shared_links = preprocess_shared_params(config["model"]) + config = update_deepmd_input(config, warning=False) + config = normalize(config, multi_task=True) + self.config = config + self.share_fitting = False + self.mixed_types = True + + def tearDown(self) -> None: + os.chdir(self._old_cwd) + shutil.rmtree(self.tmpdir, ignore_errors=True) + + +class TestMultiTaskDPA2ShareFit(unittest.TestCase, MultiTaskTrainTest): + """Multi-task training with DPA2 descriptor and shared fitting_net.""" + + @classmethod + def setUpClass(cls) -> None: + _skip_if_no_data() + + def setUp(self) -> None: + self.tmpdir = tempfile.mkdtemp(prefix="pt_expt_mt_dpa2_sf_") + self._old_cwd = os.getcwd() + os.chdir(self.tmpdir) + + self.descriptor = _descriptor_dpa2 + config = _make_multitask_config(self.descriptor, share_fitting=True) + config["model"], self.shared_links = preprocess_shared_params(config["model"]) + config = update_deepmd_input(config, warning=False) + config = normalize(config, multi_task=True) + self.config = config + self.share_fitting = True + self.mixed_types = True + + def tearDown(self) -> None: + os.chdir(self._old_cwd) + shutil.rmtree(self.tmpdir, ignore_errors=True) + + +class TestMultiTaskDPA3(unittest.TestCase, MultiTaskTrainTest): + """Multi-task training with DPA3 descriptor.""" + + @classmethod + def setUpClass(cls) -> None: + _skip_if_no_data() + + def setUp(self) -> None: + self.tmpdir = tempfile.mkdtemp(prefix="pt_expt_mt_dpa3_") + self._old_cwd = os.getcwd() + os.chdir(self.tmpdir) 
+ + self.descriptor = _descriptor_dpa3 + config = _make_multitask_config(self.descriptor, share_fitting=False) + config["model"], self.shared_links = preprocess_shared_params(config["model"]) + config = update_deepmd_input(config, warning=False) + config = normalize(config, multi_task=True) + self.config = config + self.share_fitting = False + self.mixed_types = True + + def tearDown(self) -> None: + os.chdir(self._old_cwd) + shutil.rmtree(self.tmpdir, ignore_errors=True) + + +class TestMultiTaskDPA3ShareFit(unittest.TestCase, MultiTaskTrainTest): + """Multi-task training with DPA3 descriptor and shared fitting_net.""" + + @classmethod + def setUpClass(cls) -> None: + _skip_if_no_data() + + def setUp(self) -> None: + self.tmpdir = tempfile.mkdtemp(prefix="pt_expt_mt_dpa3_sf_") + self._old_cwd = os.getcwd() + os.chdir(self.tmpdir) + + self.descriptor = _descriptor_dpa3 + config = _make_multitask_config(self.descriptor, share_fitting=True) + config["model"], self.shared_links = preprocess_shared_params(config["model"]) + config = update_deepmd_input(config, warning=False) + config = normalize(config, multi_task=True) + self.config = config + self.share_fitting = True + self.mixed_types = True + + def tearDown(self) -> None: + os.chdir(self._old_cwd) + shutil.rmtree(self.tmpdir, ignore_errors=True) + + +class TestMultiTaskCompile(unittest.TestCase): + """Verify that multi-task + torch.compile works correctly.""" + + @classmethod + def setUpClass(cls) -> None: + _skip_if_no_data() + + def _check_compile_correctness(self, share_fitting: bool = False) -> None: + """Compiled multi-task model predictions and gradients match uncompiled. + + For each branch: feed the same batch through wrapper (which computes + loss), call loss.backward(), then compare: + 1. model predictions (atom_energy, energy, force, virial) + 2. loss values + 3. 
parameter gradients (second-order, through force loss) + """ + from deepmd.pt_expt.train.training import ( + _CompiledModel, + ) + + # Build uncompiled trainer + config_uc = _make_multitask_config( + _descriptor_se_e2_a, share_fitting=share_fitting + ) + config_uc["model"], shared_links_uc = preprocess_shared_params( + config_uc["model"] + ) + config_uc = update_deepmd_input(config_uc, warning=False) + config_uc = normalize(config_uc, multi_task=True) + + # Build compiled trainer + config_c = _make_multitask_config( + _descriptor_se_e2_a, share_fitting=share_fitting + ) + config_c["training"]["enable_compile"] = True + config_c["model"], shared_links_c = preprocess_shared_params(config_c["model"]) + config_c = update_deepmd_input(config_c, warning=False) + config_c = normalize(config_c, multi_task=True) + + tmpdir = tempfile.mkdtemp(prefix="pt_expt_mt_compile_corr_") + old_cwd = os.getcwd() + os.chdir(tmpdir) + try: + trainer_uc = get_trainer(config_uc, shared_links=shared_links_uc) + trainer_c = get_trainer(config_c, shared_links=shared_links_c) + for mk in ("model_1", "model_2"): + self.assertIsInstance(trainer_c.wrapper.model[mk], _CompiledModel) + + # Copy uncompiled weights → compiled (same starting point) + for mk in ("model_1", "model_2"): + trainer_c.wrapper.model[mk].original_model.load_state_dict( + trainer_uc.wrapper.model[mk].state_dict() + ) + + # For each branch, run one forward+backward and compare + for task_key in ("model_1", "model_2"): + trainer_uc.optimizer.zero_grad(set_to_none=True) + trainer_c.optimizer.zero_grad(set_to_none=True) + + input_dict, label_dict = trainer_uc.get_data( + is_train=True, task_key=task_key + ) + + cur_lr = trainer_uc.scheduler.get_last_lr()[0] + pred_uc, loss_uc, _ = trainer_uc.wrapper( + **input_dict, + cur_lr=cur_lr, + label=label_dict, + task_key=task_key, + ) + pred_c, loss_c, _ = trainer_c.wrapper( + **input_dict, + cur_lr=cur_lr, + label=label_dict, + task_key=task_key, + ) + + # Compare predictions: 
atom_energy, energy, force, virial. + # Atomic virial is not exercised because training does not + # pass ``do_atomic_virial=True``; the compiled graph is + # traced with the default (False) so per-atom virial is not + # computed by the compiled path. + for key in ("atom_energy", "energy", "force", "virial"): + self.assertIn( + key, pred_uc, f"uncompiled missing '{key}' (task={task_key})" + ) + self.assertIn( + key, pred_c, f"compiled missing '{key}' (task={task_key})" + ) + torch.testing.assert_close( + pred_c[key], + pred_uc[key], + atol=1e-10, + rtol=1e-10, + msg=f"{key} mismatch (task={task_key})", + ) + torch.testing.assert_close(loss_c, loss_uc, atol=1e-10, rtol=1e-10) + + # Compare gradients (second-order, through force loss) + loss_uc.backward() + loss_c.backward() + for (name_uc, p_uc), (name_c, p_c) in zip( + trainer_uc.wrapper.model[task_key].named_parameters(), + trainer_c.wrapper.model[task_key].original_model.named_parameters(), + strict=True, + ): + if p_uc.grad is not None: + self.assertIsNotNone( + p_c.grad, + msg=f"grad is None for {name_c} (task={task_key})", + ) + torch.testing.assert_close( + p_c.grad, + p_uc.grad, + atol=1e-10, + rtol=1e-10, + msg=f"grad mismatch on {name_uc} (task={task_key})", + ) + finally: + os.chdir(old_cwd) + shutil.rmtree(tmpdir, ignore_errors=True) + + def test_compile_multitask_correctness(self) -> None: + """Compiled multi-task predictions and gradients match uncompiled.""" + self._check_compile_correctness(share_fitting=False) + + def test_compile_multitask_correctness_share_fitting(self) -> None: + """Compiled multi-task with shared fitting: predictions and gradients match.""" + self._check_compile_correctness(share_fitting=True) + + def test_compile_multitask_train(self) -> None: + """Train multi-task model with torch.compile for a few steps.""" + config = _make_multitask_config(_descriptor_se_e2_a) + config["training"]["enable_compile"] = True + config["training"]["numb_steps"] = 2 + 
config["training"]["save_freq"] = 2 + config["model"], shared_links = preprocess_shared_params(config["model"]) + config = update_deepmd_input(config, warning=False) + config = normalize(config, multi_task=True) + + tmpdir = tempfile.mkdtemp(prefix="pt_expt_mt_compile_train_") + old_cwd = os.getcwd() + os.chdir(tmpdir) + try: + trainer = get_trainer(config, shared_links=shared_links) + trainer.run() + finally: + os.chdir(old_cwd) + shutil.rmtree(tmpdir, ignore_errors=True) + + def test_compile_multitask_train_share_fitting(self) -> None: + """Train multi-task model with shared fitting + compile for a few steps.""" + config = _make_multitask_config(_descriptor_se_e2_a, share_fitting=True) + config["training"]["enable_compile"] = True + config["training"]["numb_steps"] = 2 + config["training"]["save_freq"] = 2 + config["model"], shared_links = preprocess_shared_params(config["model"]) + config = update_deepmd_input(config, warning=False) + config = normalize(config, multi_task=True) + + tmpdir = tempfile.mkdtemp(prefix="pt_expt_mt_compile_sf_") + old_cwd = os.getcwd() + os.chdir(tmpdir) + try: + trainer = get_trainer(config, shared_links=shared_links) + trainer.run() + finally: + os.chdir(old_cwd) + shutil.rmtree(tmpdir, ignore_errors=True) + + +# --------------------------------------------------------------------------- +# Gradient accumulation test helpers +# --------------------------------------------------------------------------- + + +def _generate_random_data_dir( + path: str, + atom_types: list[int], + nframes: int, + seed: int, + nfparam: int = 0, + naparam: int = 0, +) -> None: + """Create a minimal deepmd data directory with random data.""" + rng = np.random.RandomState(seed) + natoms = len(atom_types) + os.makedirs(os.path.join(path, "set.000"), exist_ok=True) + + # type.raw + with open(os.path.join(path, "type.raw"), "w") as f: + for t in atom_types: + f.write(f"{t}\n") + + # box: diagonal 20x20x20 + box = np.tile(np.diag([20.0, 20.0, 20.0]).flatten(), 
(nframes, 1)) + np.save(os.path.join(path, "set.000", "box.npy"), box) + + # coord + coord = rng.random((nframes, natoms * 3)) * 20.0 + np.save(os.path.join(path, "set.000", "coord.npy"), coord) + + # energy + energy = rng.random((nframes,)) + np.save(os.path.join(path, "set.000", "energy.npy"), energy) + + # force + force = rng.random((nframes, natoms * 3)) + np.save(os.path.join(path, "set.000", "force.npy"), force) + + # fparam (frame parameters) + if nfparam > 0: + fparam = rng.random((nframes, nfparam)) + np.save(os.path.join(path, "set.000", "fparam.npy"), fparam) + + # aparam (atomic parameters) + if naparam > 0: + aparam = rng.random((nframes, natoms * naparam)) + np.save(os.path.join(path, "set.000", "aparam.npy"), aparam) + + +def _make_gradient_test_mt_config( + data_dir_0: str, + data_dir_1: str, + numb_fparam: int = 0, + numb_aparam: int = 0, +) -> dict: + """Multi-task config for gradient accumulation test.""" + type_map = ["O", "H", "C"] + descriptor = deepcopy(_descriptor_dpa3) + fitting_1: dict = { + "neuron": [16, 16], + "resnet_dt": True, + "seed": 1, + } + fitting_2: dict = { + "neuron": [16, 16], + "resnet_dt": True, + "seed": 2, + } + if numb_fparam > 0: + fitting_1["numb_fparam"] = numb_fparam + fitting_2["numb_fparam"] = numb_fparam + if numb_aparam > 0: + fitting_1["numb_aparam"] = numb_aparam + fitting_2["numb_aparam"] = numb_aparam + return { + "model": { + "shared_dict": { + "my_type_map": type_map, + "my_descriptor": descriptor, + }, + "model_dict": { + "model_1": { + "type_map": "my_type_map", + "descriptor": "my_descriptor", + "fitting_net": fitting_1, + "data_stat_nbatch": 1, + }, + "model_2": { + "type_map": "my_type_map", + "descriptor": "my_descriptor", + "fitting_net": fitting_2, + "data_stat_nbatch": 1, + }, + }, + }, + "learning_rate": { + "type": "exp", + "decay_steps": 500, + "start_lr": 0.001, + "stop_lr": 0.001, + }, + "loss_dict": { + "model_1": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + 
"start_pref_f": 1000, + "limit_pref_f": 1, + "start_pref_v": 0, + "limit_pref_v": 0, + }, + "model_2": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "start_pref_v": 0, + "limit_pref_v": 0, + }, + }, + "training": { + "model_prob": { + "model_1": 0.5, + "model_2": 0.5, + }, + "data_dict": { + "model_1": { + "stat_file": "./stat_files/model_1", + "training_data": { + "systems": [data_dir_0], + "batch_size": 1, + }, + "validation_data": { + "systems": [data_dir_0], + "batch_size": 1, + "numb_btch": 1, + }, + }, + "model_2": { + "stat_file": "./stat_files/model_2", + "training_data": { + "systems": [data_dir_1], + "batch_size": 1, + }, + "validation_data": { + "systems": [data_dir_1], + "batch_size": 1, + "numb_btch": 1, + }, + }, + }, + "numb_steps": 2, + "seed": 10, + "disp_file": "lcurve.out", + "disp_freq": 100, + "save_freq": 100, + }, + } + + +def _make_gradient_test_st_config( + data_dir: str, + fitting_seed: int, + numb_fparam: int = 0, + numb_aparam: int = 0, +) -> dict: + """Single-task config for gradient accumulation test.""" + type_map = ["O", "H", "C"] + descriptor = deepcopy(_descriptor_dpa3) + fitting: dict = { + "neuron": [16, 16], + "resnet_dt": True, + "seed": fitting_seed, + } + if numb_fparam > 0: + fitting["numb_fparam"] = numb_fparam + if numb_aparam > 0: + fitting["numb_aparam"] = numb_aparam + return { + "model": { + "type_map": type_map, + "descriptor": descriptor, + "fitting_net": fitting, + "data_stat_nbatch": 1, + }, + "learning_rate": { + "type": "exp", + "decay_steps": 500, + "start_lr": 0.001, + "stop_lr": 0.001, + }, + "loss": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "start_pref_v": 0, + "limit_pref_v": 0, + }, + "training": { + "training_data": { + "systems": [data_dir], + "batch_size": 1, + }, + "validation_data": { + "systems": [data_dir], + "batch_size": 1, + "numb_btch": 1, + }, + "numb_steps": 1, 
+ "seed": 10, + "disp_file": "lcurve.out", + "disp_freq": 100, + "save_freq": 100, + }, + } + + +def _deterministic_task_choice(task_sequence: list[int]): + """Return a patched dp_random.choice that forces task selection order.""" + original = dp_random.choice + it = iter(task_sequence) + + def patched(a, size=None, replace=True, p=None): + # Task selection: array with >=2 elements and probability vector + if hasattr(a, "__len__") and len(a) >= 2 and p is not None: + return next(it) + return original(a, size=size, replace=replace, p=p) + + return patched + + +def _make_recording_step( + trainer, + modules_to_record: dict, + recorded_grads: list[dict], +): + """Patch _optimizer_step: record grads from listed modules, skip optimizer. + + Parameters + ---------- + trainer : Trainer + The trainer whose scheduler.step() is called. + modules_to_record : dict[str, torch.nn.Module] + Named modules whose parameter gradients to record. + recorded_grads : list[dict[str, torch.Tensor]] + Appended with {module_key/param_name: grad} at each step. 
+ """ + + def recording_step(): + grads = {} + for mod_key, mod in modules_to_record.items(): + for n, p in mod.named_parameters(): + if p.grad is not None: + grads[f"{mod_key}/{n}"] = p.grad.clone() + recorded_grads.append(grads) + trainer.scheduler.step() + + return recording_step + + +class TestMultiTaskGradient(unittest.TestCase): + """Verify multi-task descriptor gradients match sum of single-task gradients.""" + + def setUp(self) -> None: + self.tmpdir = tempfile.mkdtemp(prefix="pt_expt_mt_grad_") + self._old_cwd = os.getcwd() + os.chdir(self.tmpdir) + + self.nfparam = 2 + self.naparam = 3 + + self.data_dir_0 = os.path.join(self.tmpdir, "data_task0") + _generate_random_data_dir( + self.data_dir_0, + atom_types=[0, 0, 1, 1, 1, 2], + nframes=1, + seed=42, + nfparam=self.nfparam, + naparam=self.naparam, + ) + self.data_dir_1 = os.path.join(self.tmpdir, "data_task1") + _generate_random_data_dir( + self.data_dir_1, + atom_types=[0, 1, 1, 2, 2, 2, 2], + nframes=1, + seed=137, + nfparam=self.nfparam, + naparam=self.naparam, + ) + + def tearDown(self) -> None: + os.chdir(self._old_cwd) + shutil.rmtree(self.tmpdir, ignore_errors=True) + + def test_gradient_accumulation(self) -> None: + """Sum of per-task grads from multi-task run() == + sum of grads from two single-task run() calls. 
+ """ + # ===== Multi-task trainer ===== + mt_config = _make_gradient_test_mt_config( + self.data_dir_0, + self.data_dir_1, + numb_fparam=self.nfparam, + numb_aparam=self.naparam, + ) + mt_config["model"], shared_links = preprocess_shared_params(mt_config["model"]) + mt_config = update_deepmd_input(mt_config, warning=False) + mt_config = normalize(mt_config, multi_task=True) + + mt_trainer = get_trainer(deepcopy(mt_config), shared_links=shared_links) + mt_desc = mt_trainer.wrapper.model["model_1"].atomic_model.descriptor + mt_fit_1 = mt_trainer.wrapper.model["model_1"].atomic_model.fitting_net + mt_fit_2 = mt_trainer.wrapper.model["model_2"].atomic_model.fitting_net + + # Verify descriptor params are aliased (share_params) + mt_desc_2 = mt_trainer.wrapper.model["model_2"].atomic_model.descriptor + for (n1, p1), (_n2, p2) in zip( + mt_desc.named_parameters(), mt_desc_2.named_parameters(), strict=True + ): + assert p1.data_ptr() == p2.data_ptr(), ( + f"Descriptor params not aliased: {n1}" + ) + + # Record grads for descriptor + both fitting heads + mt_grads: list[dict[str, torch.Tensor]] = [] + mt_trainer._optimizer_step = _make_recording_step( + mt_trainer, + {"desc": mt_desc, "fit_1": mt_fit_1, "fit_2": mt_fit_2}, + mt_grads, + ) + with patch( + "deepmd.utils.random.choice", + _deterministic_task_choice([0, 1]), + ): + mt_trainer.run() # 2 steps: task_0 then task_1 + + assert len(mt_grads) == 2 + + # ===== Single-task trainer for task_0 ===== + st0_config = _make_gradient_test_st_config( + self.data_dir_0, + fitting_seed=1, # same as model_1 + numb_fparam=self.nfparam, + numb_aparam=self.naparam, + ) + st0_config = update_deepmd_input(st0_config, warning=False) + st0_config = normalize(st0_config) + + os.chdir(tempfile.mkdtemp(dir=self.tmpdir)) # fresh cwd + st0_trainer = get_trainer(deepcopy(st0_config)) + + # Copy MT model_1 state → ST0 to ensure identical params+buffers + # (stat buffers like davg/dstd/bias_atom_e differ due to data) + mt_m1 = 
mt_trainer.wrapper.model["model_1"] + st0_m = st0_trainer.wrapper.model["Default"] + st0_m.load_state_dict(mt_m1.state_dict()) + + st0_desc = st0_m.atomic_model.descriptor + st0_fit = st0_m.atomic_model.fitting_net + + st0_grads: list[dict[str, torch.Tensor]] = [] + st0_trainer._optimizer_step = _make_recording_step( + st0_trainer, {"desc": st0_desc, "fit": st0_fit}, st0_grads + ) + st0_trainer.run() # 1 step + assert len(st0_grads) == 1 + + # ===== Single-task trainer for task_1 ===== + st1_config = _make_gradient_test_st_config( + self.data_dir_1, + fitting_seed=2, # same as model_2 + numb_fparam=self.nfparam, + numb_aparam=self.naparam, + ) + st1_config = update_deepmd_input(st1_config, warning=False) + st1_config = normalize(st1_config) + + os.chdir(tempfile.mkdtemp(dir=self.tmpdir)) # fresh cwd + st1_trainer = get_trainer(deepcopy(st1_config)) + + # Copy MT model_2 state → ST1 to ensure identical params+buffers + mt_m2 = mt_trainer.wrapper.model["model_2"] + st1_m = st1_trainer.wrapper.model["Default"] + st1_m.load_state_dict(mt_m2.state_dict()) + + st1_desc = st1_m.atomic_model.descriptor + st1_fit = st1_m.atomic_model.fitting_net + + st1_grads: list[dict[str, torch.Tensor]] = [] + st1_trainer._optimizer_step = _make_recording_step( + st1_trainer, {"desc": st1_desc, "fit": st1_fit}, st1_grads + ) + st1_trainer.run() # 1 step + assert len(st1_grads) == 1 + + # ===== Comparison: descriptor gradients ===== + # Multi-task descriptor grad at each step should match single-task + desc_keys = [k for k in mt_grads[0] if k.startswith("desc/")] + assert len(desc_keys) > 0, "No descriptor gradients" + + # Per-task descriptor grad: mt step_0 == st_0, mt step_1 == st_1 + for name in desc_keys: + np.testing.assert_allclose( + mt_grads[0][name].detach().cpu().numpy(), + st0_grads[0][name].detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + err_msg=f"Task_0 descriptor grad mismatch: {name}", + ) + np.testing.assert_allclose( + mt_grads[1][name].detach().cpu().numpy(), + 
st1_grads[0][name].detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + err_msg=f"Task_1 descriptor grad mismatch: {name}", + ) + + # Descriptor grad sum: mt(step0 + step1) == st0 + st1 + for name in desc_keys: + mt_sum = mt_grads[0][name] + mt_grads[1][name] + st_sum = st0_grads[0][name] + st1_grads[0][name] + np.testing.assert_allclose( + mt_sum.detach().cpu().numpy(), + st_sum.detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + err_msg=f"Descriptor grad sum mismatch: {name}", + ) + + # ===== Comparison: fitting head gradients ===== + # Step 0 trains model_1 → mt fit_1 grads == st0 fit grads + fit1_keys = [k for k in mt_grads[0] if k.startswith("fit_1/")] + for name in fit1_keys: + st_name = name.replace("fit_1/", "fit/") + np.testing.assert_allclose( + mt_grads[0][name].detach().cpu().numpy(), + st0_grads[0][st_name].detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + err_msg=f"Fitting head grad mismatch (task_0): {name}", + ) + # Verify fit_2 not in step 0 grads (not part of computation graph) + assert not any(k.startswith("fit_2/") for k in mt_grads[0]), ( + "fit_2 should have no gradients in step 0 (task_0)" + ) + + # Step 1 trains model_2 → mt fit_2 grads == st1 fit grads + fit2_keys = [k for k in mt_grads[1] if k.startswith("fit_2/")] + for name in fit2_keys: + st_name = name.replace("fit_2/", "fit/") + np.testing.assert_allclose( + mt_grads[1][name].detach().cpu().numpy(), + st1_grads[0][st_name].detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + err_msg=f"Fitting head grad mismatch (task_1): {name}", + ) + # Verify fit_1 not in step 1 grads + assert not any(k.startswith("fit_1/") for k in mt_grads[1]), ( + "fit_1 should have no gradients in step 1 (task_1)" + ) + + +class TestCompileCaseEmbdVaryingNframes(unittest.TestCase): + """Compiled multi-task with ``dim_case_embd > 0`` and varying ``nframes``. 
+ + The shared-fitting path in ``GeneralFitting.call`` tiles the per-task + case embedding as ``xp.tile(reshape(case_embd, (1, 1, -1)), (nf, nloc, 1))`` + (see ``deepmd/dpmodel/fitting/general_fitting.py``). Under + ``tracing_mode="symbolic"`` the ``nf`` multiplier must stay symbolic; + otherwise the compiled graph hard-codes a specific batch size and + subsequent calls with a different ``nframes`` error out. + + The test uses two systems with different atom counts and per-system + ``batch_size=[2, 3]`` so every branch's compiled graph sees both + nframes values. ``dim_case_embd=2`` is deliberately chosen to also + collide numerically with the nframes=2 runtime case. ``dp_random.choice`` + is mocked so both tasks and both systems are sampled. + """ + + @classmethod + def setUpClass(cls) -> None: + cls.tmpdir = tempfile.mkdtemp(prefix="pt_expt_mt_case_embd_vary_") + cls.sys0_m1 = os.path.join(cls.tmpdir, "sys0_model1_6atoms") + cls.sys1_m1 = os.path.join(cls.tmpdir, "sys1_model1_4atoms") + cls.sys0_m2 = os.path.join(cls.tmpdir, "sys0_model2_6atoms") + cls.sys1_m2 = os.path.join(cls.tmpdir, "sys1_model2_4atoms") + for path, seed in ( + (cls.sys0_m1, 11), + (cls.sys1_m1, 12), + (cls.sys0_m2, 21), + (cls.sys1_m2, 22), + ): + _generate_random_data_dir( + path, + atom_types=[i % 2 for i in range(6 if "6atoms" in path else 4)], + nframes=4, + seed=seed, + ) + + @classmethod + def tearDownClass(cls) -> None: + shutil.rmtree(cls.tmpdir, ignore_errors=True) + + def _make_config(self) -> dict: + type_map = ["O", "H"] + fitting = deepcopy(_fitting_net) + fitting["dim_case_embd"] = 2 + shared_dict: dict = { + "my_type_map": type_map, + "my_descriptor": deepcopy(_descriptor_se_e2_a), + "my_fitting": fitting, + } + config = { + "model": { + "shared_dict": shared_dict, + "model_dict": { + "model_1": { + "type_map": "my_type_map", + "descriptor": "my_descriptor", + "fitting_net": "my_fitting", + "data_stat_nbatch": 1, + }, + "model_2": { + "type_map": "my_type_map", + "descriptor": 
"my_descriptor", + "fitting_net": "my_fitting", + "data_stat_nbatch": 1, + }, + }, + }, + "learning_rate": { + "type": "exp", + "decay_steps": 500, + "start_lr": 0.001, + "stop_lr": 3.51e-8, + }, + "loss_dict": { + "model_1": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "start_pref_v": 0, + "limit_pref_v": 0, + }, + "model_2": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "start_pref_v": 0, + "limit_pref_v": 0, + }, + }, + "training": { + "enable_compile": True, + "model_prob": {"model_1": 0.5, "model_2": 0.5}, + "data_dict": { + "model_1": { + "stat_file": "./stat_files/model_1", + "training_data": { + "systems": [self.sys0_m1, self.sys1_m1], + "batch_size": [2, 3], + }, + "validation_data": { + "systems": [self.sys0_m1], + "batch_size": 1, + "numb_btch": 1, + }, + }, + "model_2": { + "stat_file": "./stat_files/model_2", + "training_data": { + "systems": [self.sys0_m2, self.sys1_m2], + "batch_size": [2, 3], + }, + "validation_data": { + "systems": [self.sys0_m2], + "batch_size": 1, + "numb_btch": 1, + }, + }, + }, + "numb_steps": 1, + "seed": 10, + "disp_file": "lcurve.out", + "disp_freq": 100, + "save_freq": 100, + }, + } + config["model"], shared_links = preprocess_shared_params(config["model"]) + config = update_deepmd_input(config, warning=False) + config = normalize(config, multi_task=True) + return config, shared_links + + def test_compiled_varying_nframes_with_case_embd(self) -> None: + """Compiled shared-fitting graph handles nframes in {2, 3} per branch.""" + from deepmd.pt_expt.train.training import ( + _CompiledModel, + ) + + config, shared_links = self._make_config() + tmpdir = tempfile.mkdtemp(prefix="pt_expt_mt_case_embd_run_") + old_cwd = os.getcwd() + os.chdir(tmpdir) + try: + trainer = get_trainer(deepcopy(config), shared_links=shared_links) + # Both branches must be compiled. 
+ for mk in ("model_1", "model_2"): + self.assertIsInstance(trainer.wrapper.model[mk], _CompiledModel) + ce = trainer.wrapper.model[ + mk + ].original_model.atomic_model.fitting_net.case_embd + self.assertIsNotNone(ce, f"case_embd not set on {mk}") + self.assertEqual(int(ce.shape[0]), 2) + + # Drive 6 steps alternating (task, system_index) so each branch's + # compiled graph sees both nframes=2 (sys0) and nframes=3 (sys1). + trainer.wrapper.train() + task_sequence = ["model_1", "model_2"] * 3 + sys_sequence = [0, 1, 0, 1, 0, 1] + sys_iter = iter(sys_sequence) + + original_choice = dp_random.choice + + def task_or_system_choice(a, size=None, replace=True, p=None): + # Per-branch system selection: alternate between the two + # systems so every compiled graph sees both nframes values. + if hasattr(a, "__len__") and len(a) == 2 and p is not None: + return next(sys_iter) + return original_choice(a, size=size, replace=replace, p=p) + + seen_nframes: set[int] = set() + with patch.object(dp_random, "choice", side_effect=task_or_system_choice): + for task_key in task_sequence: + trainer.optimizer.zero_grad(set_to_none=True) + inp, lab = trainer.get_data(is_train=True, task_key=task_key) + seen_nframes.add(int(inp["coord"].shape[0])) + lr = trainer.scheduler.get_last_lr()[0] + _, loss, _ = trainer.wrapper( + **inp, cur_lr=lr, label=lab, task_key=task_key + ) + loss.backward() + trainer.optimizer.step() + self.assertFalse(torch.isnan(loss), "loss is NaN") + self.assertFalse(torch.isinf(loss), "loss is Inf") + + self.assertEqual( + seen_nframes, + {2, 3}, + msg=( + f"nframes did not vary across steps: {seen_nframes}. " + "Expected both 2 and 3 (matching and not matching dim_case_embd=2)." 
+ ), + ) + finally: + os.chdir(old_cwd) + shutil.rmtree(tmpdir, ignore_errors=True) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pt_expt/test_training.py b/source/tests/pt_expt/test_training.py index 3b3ab247bb..8c7f7e2a76 100644 --- a/source/tests/pt_expt/test_training.py +++ b/source/tests/pt_expt/test_training.py @@ -12,6 +12,9 @@ import shutil import tempfile import unittest +from unittest.mock import ( + patch, +) import torch @@ -37,6 +40,121 @@ "water", ) +# Keys present on the compiled path. ``atom_virial`` is intentionally excluded: +# training never passes ``do_atomic_virial=True``, so the compiled graph is +# traced with the default (False) and per-atom virial is not emitted. +_COMPILE_PRED_KEYS = ("atom_energy", "energy", "force", "virial") +_COMPILE_TOL = {"atol": 1e-10, "rtol": 1e-10} + +# DPA3 descriptor config used to extend the varying-natoms compile-correctness +# test to a non-trivial architecture (repflow with attention). ``precision: +# float64`` is set explicitly so the strict ``atol=rtol=1e-10`` comparison +# holds at machine epsilon. +# +# DPA1 (se_atten) is intentionally NOT covered here: inductor's compile of the +# se_atten attention path is intermittently incorrect — see the "known +# limitations" section of the multi-task compile memo for details. 
+_DESCRIPTOR_DPA2 = { + "type": "dpa2", + "repinit": { + "rcut": 4.0, + "rcut_smth": 0.5, + "nsel": 18, + "neuron": [2, 4, 8], + "axis_neuron": 4, + "activation_function": "tanh", + "use_three_body": True, + "three_body_sel": 12, + "three_body_rcut": 3.0, + "three_body_rcut_smth": 0.5, + }, + "repformer": { + "rcut": 3.0, + "rcut_smth": 0.5, + "nsel": 12, + "nlayers": 2, + "g1_dim": 8, + "g2_dim": 5, + "attn2_hidden": 3, + "attn2_nhead": 1, + "attn1_hidden": 5, + "attn1_nhead": 1, + "axis_neuron": 4, + "update_h2": False, + "update_g1_has_conv": True, + "update_g1_has_grrg": True, + "update_g1_has_drrd": True, + "update_g1_has_attn": True, + "update_g2_has_g1g1": True, + "update_g2_has_attn": True, + "attn2_has_gate": True, + }, + "precision": "float64", + "seed": 1, + "add_tebd_to_repinit_out": False, +} + +_DESCRIPTOR_DPA3 = { + "type": "dpa3", + "repflow": { + "n_dim": 8, + "e_dim": 5, + "a_dim": 4, + "nlayers": 2, + "e_rcut": 3.0, + "e_rcut_smth": 0.5, + "e_sel": 12, + "a_rcut": 3.0, + "a_rcut_smth": 0.5, + "a_sel": 8, + "axis_neuron": 4, + }, + "precision": "float64", + "concat_output_tebd": False, + "seed": 1, +} + + +def _assert_compile_predictions_match( + testcase: unittest.TestCase, + out_c: dict, + out_uc: dict, + *, + ctx: str = "", +) -> None: + for key in _COMPILE_PRED_KEYS: + testcase.assertIn(key, out_uc, f"{ctx}uncompiled missing '{key}'") + testcase.assertIn(key, out_c, f"{ctx}compiled missing '{key}'") + torch.testing.assert_close( + out_c[key], + out_uc[key], + **_COMPILE_TOL, + msg=f"{ctx}{key} mismatch between compiled and uncompiled", + ) + + +def _assert_compile_grads_match( + testcase: unittest.TestCase, + model_c: torch.nn.Module, + model_uc: torch.nn.Module, + *, + ctx: str = "", +) -> None: + for (name_uc, p_uc), (_, p_c) in zip( + model_uc.named_parameters(), + model_c.named_parameters(), + strict=True, + ): + if p_uc.grad is None: + continue + testcase.assertIsNotNone(p_c.grad, msg=f"{ctx}grad is None for {name_uc}") + 
torch.testing.assert_close( + p_c.grad, + p_uc.grad, + **_COMPILE_TOL, + msg=f"{ctx}grad mismatch on {name_uc}", + ) + def _make_config(data_dir: str, numb_steps: int = 5) -> dict: """Build a minimal config dict pointing at *data_dir*.""" @@ -163,9 +281,19 @@ def test_training_loop_compiled(self) -> None: config = normalize(config) self._run_training(config) + def test_training_loop_compiled_silu(self) -> None: + """Run compiled training with silu activation.""" + config = _make_config(self.data_dir, numb_steps=5) + config["model"]["descriptor"]["activation_function"] = "silu" + config["model"]["fitting_net"]["activation_function"] = "silu" + config["training"]["enable_compile"] = True + config = update_deepmd_input(config, warning=False) + config = normalize(config) + self._run_training(config) + -class TestCompiledRecompile(unittest.TestCase): - """Test that _CompiledModel recompiles when nall exceeds max_nall.""" +class TestCompiledDynamicShapes(unittest.TestCase): + """Test that _CompiledModel handles varying nall via dynamic shapes.""" @classmethod def setUpClass(cls) -> None: @@ -174,8 +302,13 @@ def setUpClass(cls) -> None: raise unittest.SkipTest(f"Example data not found: {data_dir}") cls.data_dir = data_dir - def test_nall_growth_triggers_recompile(self) -> None: - """Shrink max_nall to force a recompile, then verify training works.""" + def test_compiled_handles_varying_nall(self) -> None: + """Run several training steps, assert finite loss each step. + + With ``tracing_mode="symbolic"`` + ``dynamic=True``, nall is a + symbolic dim so nall growth across batches is handled without + any recompile or padding. 
+ """ from deepmd.pt_expt.train.training import ( _CompiledModel, ) @@ -185,7 +318,7 @@ def test_nall_growth_triggers_recompile(self) -> None: config = update_deepmd_input(config, warning=False) config = normalize(config) - tmpdir = tempfile.mkdtemp(prefix="pt_expt_recompile_") + tmpdir = tempfile.mkdtemp(prefix="pt_expt_dynamic_") try: old_cwd = os.getcwd() os.chdir(tmpdir) @@ -193,39 +326,21 @@ def test_nall_growth_triggers_recompile(self) -> None: trainer = get_trainer(config) # The wrapper.model should be a _CompiledModel - compiled_model = trainer.wrapper.model + compiled_model = trainer.wrapper.model["Default"] self.assertIsInstance(compiled_model, _CompiledModel) - original_max_nall = compiled_model._max_nall - self.assertGreater(original_max_nall, 0) - - # Artificially shrink max_nall to 1 so the next batch - # will certainly exceed it and trigger recompilation. - compiled_model._max_nall = 1 - old_compiled_lower = compiled_model.compiled_forward_lower - - # Run one training step — should trigger recompile trainer.wrapper.train() - trainer.optimizer.zero_grad(set_to_none=True) - inp, lab = trainer.get_data(is_train=True) - lr = trainer.scheduler.get_last_lr()[0] - _, loss, more_loss = trainer.wrapper(**inp, cur_lr=lr, label=lab) - loss.backward() - trainer.optimizer.step() - - # max_nall should have grown beyond 1 - new_max_nall = compiled_model._max_nall - self.assertGreater(new_max_nall, 1) - - # compiled_forward_lower should be a new object - self.assertIsNot( - compiled_model.compiled_forward_lower, - old_compiled_lower, - ) - - # Loss should be a finite scalar - self.assertFalse(torch.isnan(loss)) - self.assertFalse(torch.isinf(loss)) + for _ in range(3): + trainer.optimizer.zero_grad(set_to_none=True) + inp, lab = trainer.get_data(is_train=True) + lr = trainer.scheduler.get_last_lr()[0] + _, loss, _ = trainer.wrapper(**inp, cur_lr=lr, label=lab) + loss.backward() + trainer.optimizer.step() + + # Loss should be a finite scalar at every step + 
self.assertFalse(torch.isnan(loss)) + self.assertFalse(torch.isinf(loss)) finally: os.chdir(old_cwd) finally: @@ -242,38 +357,42 @@ def setUpClass(cls) -> None: raise unittest.SkipTest(f"Example data not found: {data_dir}") cls.data_dir = data_dir - def test_compiled_matches_uncompiled(self) -> None: - """Energy, force, virial from compiled model must match uncompiled.""" + def _check_consistency(self, activation: str | None = None) -> None: + """Compiled model predictions match uncompiled for the given activation. + + ``activation`` overrides both descriptor and fitting-net activation + functions when provided. ``None`` keeps the config default (tanh). + """ from deepmd.pt_expt.train.training import ( _CompiledModel, ) - config = _make_config(self.data_dir, numb_steps=1) - # enable virial in loss so the model returns it - config["loss"]["start_pref_v"] = 1.0 - config["loss"]["limit_pref_v"] = 1.0 - config = update_deepmd_input(config, warning=False) - config = normalize(config) + def _build_config(enable_compile: bool) -> dict: + config = _make_config(self.data_dir, numb_steps=1) + # enable virial in loss so the model returns it + config["loss"]["start_pref_v"] = 1.0 + config["loss"]["limit_pref_v"] = 1.0 + if activation is not None: + config["model"]["descriptor"]["activation_function"] = activation + config["model"]["fitting_net"]["activation_function"] = activation + if enable_compile: + config["training"]["enable_compile"] = True + config = update_deepmd_input(config, warning=False) + return normalize(config) tmpdir = tempfile.mkdtemp(prefix="pt_expt_consistency_") try: old_cwd = os.getcwd() os.chdir(tmpdir) try: - trainer = get_trainer(config) + trainer = get_trainer(_build_config(enable_compile=False)) # Uncompiled model reference uncompiled_model = trainer.model uncompiled_model.eval() # Build compiled model from the same weights - config_compiled = _make_config(self.data_dir, numb_steps=1) - config_compiled["loss"]["start_pref_v"] = 1.0 - 
config_compiled["loss"]["limit_pref_v"] = 1.0 - config_compiled["training"]["enable_compile"] = True - config_compiled = update_deepmd_input(config_compiled, warning=False) - config_compiled = normalize(config_compiled) - trainer_compiled = get_trainer(config_compiled) - compiled_model = trainer_compiled.wrapper.model + trainer_compiled = get_trainer(_build_config(enable_compile=True)) + compiled_model = trainer_compiled.wrapper.model["Default"] self.assertIsInstance(compiled_model, _CompiledModel) # Copy uncompiled weights to compiled model so they match @@ -297,34 +416,79 @@ def test_compiled_matches_uncompiled(self) -> None: pred_c = compiled_model(coord.clone(), atype, box) - # Energy - torch.testing.assert_close( - pred_c["energy"], - pred_uc["energy"], - atol=1e-10, - rtol=1e-10, - msg="energy mismatch between compiled and uncompiled", + _assert_compile_predictions_match(self, pred_c, pred_uc) + finally: + os.chdir(old_cwd) + finally: + shutil.rmtree(tmpdir, ignore_errors=True) + + def test_compiled_matches_uncompiled(self) -> None: + """Energy, force, virial from compiled model must match uncompiled.""" + self._check_consistency() + + def test_compiled_matches_uncompiled_silu(self) -> None: + """Same numerical equivalence under silu activation (full model).""" + self._check_consistency(activation="silu") + + def test_compiled_gradients_match_uncompiled(self) -> None: + """Parameter gradients from compiled model must match uncompiled. + + Verifies second-order derivatives are correct: the loss includes + force terms, and force is computed via autograd.grad(create_graph=True), + so loss.backward() requires second-order differentiation through the + make_fx-decomposed backward ops. 
+ """ + from deepmd.pt_expt.train.training import ( + _CompiledModel, + ) + + config_uc = _make_config(self.data_dir, numb_steps=1) + config_uc = update_deepmd_input(config_uc, warning=False) + config_uc = normalize(config_uc) + + config_c = _make_config(self.data_dir, numb_steps=1) + config_c["training"]["enable_compile"] = True + config_c = update_deepmd_input(config_c, warning=False) + config_c = normalize(config_c) + + tmpdir = tempfile.mkdtemp(prefix="pt_expt_grad_consistency_") + try: + old_cwd = os.getcwd() + os.chdir(tmpdir) + try: + trainer_uc = get_trainer(config_uc) + trainer_c = get_trainer(config_c) + compiled_model = trainer_c.wrapper.model["Default"] + self.assertIsInstance(compiled_model, _CompiledModel) + + # Match weights + compiled_model.original_model.load_state_dict( + trainer_uc.model.state_dict() ) - # Force - self.assertIn("force", pred_c, "compiled model missing 'force'") - self.assertIn("force", pred_uc, "uncompiled model missing 'force'") - torch.testing.assert_close( - pred_c["force"], - pred_uc["force"], - atol=1e-10, - rtol=1e-10, - msg="force mismatch between compiled and uncompiled", + + # Forward + backward through wrapper (includes loss) + trainer_uc.optimizer.zero_grad(set_to_none=True) + trainer_c.optimizer.zero_grad(set_to_none=True) + + input_dict, label_dict = trainer_uc.get_data(is_train=True) + cur_lr = trainer_uc.scheduler.get_last_lr()[0] + + _, loss_uc, _ = trainer_uc.wrapper( + **input_dict, + cur_lr=cur_lr, + label=label_dict, + ) + _, loss_c, _ = trainer_c.wrapper( + **input_dict, + cur_lr=cur_lr, + label=label_dict, + ) + loss_uc.backward() + loss_c.backward() + + _assert_compile_grads_match( + self, compiled_model.original_model, trainer_uc.model ) - # Virial - if "virial" in pred_uc: - self.assertIn("virial", pred_c, "compiled model missing 'virial'") - torch.testing.assert_close( - pred_c["virial"], - pred_uc["virial"], - atol=1e-10, - rtol=1e-10, - msg="virial mismatch between compiled and uncompiled", - ) 
finally: os.chdir(old_cwd) finally: @@ -496,7 +660,7 @@ def test_restart_with_compile(self) -> None: trainer2 = get_trainer(config2, restart_model=ckpt_path) self.assertEqual(trainer2.start_step, 5) - self.assertIsInstance(trainer2.wrapper.model, _CompiledModel) + self.assertIsInstance(trainer2.wrapper.model["Default"], _CompiledModel) trainer2.run() with open(os.path.join(tmpdir, "lcurve.out")) as f: @@ -626,5 +790,390 @@ def test_training_loop(self) -> None: shutil.rmtree(tmpdir, ignore_errors=True) +class TestCompiledVaryingNframesWithParams(unittest.TestCase): + """Compiled training with varying ``nframes`` + ``nall`` + fparam/aparam. + + Exercises the compiled forward path under all three kinds of shape + variation simultaneously: + + * Different systems have different atom counts -> varying ``nloc`` / ``nall``. + * Per-system ``batch_size: [2, 3]`` -> varying ``nframes`` (2 vs 3). + * Both ``fparam`` (per-frame) and ``aparam`` (per-atom) labels are + provided, covering the ``dim_fparam`` / ``dim_aparam`` > 0 branches + inside ``forward_lower``. + + The chosen values (``nframes`` in {2, 3}, ``numb_fparam=2``, + ``numb_aparam=3``) are deliberately chosen so the runtime ``nframes`` + collides with the per-frame / per-atom feature dims — this is the + exact pattern that previously caused PyTorch's symbolic tracer to + specialise the batch dim (see _trace_and_compile in training.py). + + ``dp_random.choice`` is mocked to alternate between the two systems + so both are guaranteed to be sampled across ``nsteps``. + """ + + NFPARAM = 2 + NAPARAM = 3 + + @classmethod + def setUpClass(cls) -> None: + # Reuse the data-dir helper from the multitask gradient tests so we + # don't duplicate the npy/raw layout boilerplate. 
+ from .test_multitask import ( + _generate_random_data_dir, + ) + + cls.tmpdir = tempfile.mkdtemp(prefix="pt_expt_varying_params_data_") + cls.sys0 = os.path.join(cls.tmpdir, "sys0_8atoms") + cls.sys1 = os.path.join(cls.tmpdir, "sys1_4atoms") + # Atom types alternate 0/1 to match the ["O", "H"] type_map below. + _generate_random_data_dir( + cls.sys0, + atom_types=[i % 2 for i in range(8)], + nframes=4, + seed=42, + nfparam=cls.NFPARAM, + naparam=cls.NAPARAM, + ) + _generate_random_data_dir( + cls.sys1, + atom_types=[i % 2 for i in range(4)], + nframes=4, + seed=137, + nfparam=cls.NFPARAM, + naparam=cls.NAPARAM, + ) + + @classmethod + def tearDownClass(cls) -> None: + shutil.rmtree(cls.tmpdir, ignore_errors=True) + + def _make_config(self, enable_compile: bool) -> dict: + config = { + "model": { + "type_map": ["O", "H"], + "descriptor": { + "type": "se_e2_a", + "sel": [6, 12], + "rcut_smth": 0.50, + "rcut": 3.00, + "neuron": [8, 16], + "resnet_dt": False, + "axis_neuron": 4, + "type_one_side": True, + "seed": 1, + }, + "fitting_net": { + "neuron": [16, 16], + "resnet_dt": True, + "numb_fparam": self.NFPARAM, + "numb_aparam": self.NAPARAM, + "seed": 1, + }, + "data_stat_nbatch": 1, + }, + "learning_rate": { + "type": "exp", + "decay_steps": 500, + "start_lr": 0.001, + "stop_lr": 3.51e-8, + }, + "loss": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "start_pref_v": 0, + "limit_pref_v": 0, + }, + "training": { + "training_data": { + "systems": [self.sys0, self.sys1], + # Per-system batch sizes: sys0 gets nframes=2, sys1 gets nframes=3. + # Combined with sys0=8 atoms / sys1=4 atoms this guarantees + # both `nframes` and `nall` vary across steps. 
+ "batch_size": [2, 3], + }, + "validation_data": { + "systems": [self.sys0], + "batch_size": 1, + "numb_btch": 1, + }, + "numb_steps": 6, + "seed": 10, + "disp_file": "lcurve.out", + "disp_freq": 100, + "save_freq": 100, + }, + } + if enable_compile: + config["training"]["enable_compile"] = True + config = update_deepmd_input(config, warning=False) + config = normalize(config) + return config + + def _run_steps(self, enable_compile: bool, nsteps: int = 6) -> None: + from deepmd.utils import data_system as _data_system + + config = self._make_config(enable_compile=enable_compile) + sys_sequence = [i % 2 for i in range(nsteps)] + + tmpdir = tempfile.mkdtemp(prefix="pt_expt_varying_params_run_") + try: + old_cwd = os.getcwd() + os.chdir(tmpdir) + try: + trainer = get_trainer(config) + if enable_compile: + from deepmd.pt_expt.train.training import ( + _CompiledModel, + ) + + self.assertIsInstance( + trainer.wrapper.model["Default"], _CompiledModel + ) + + trainer.wrapper.train() + seen_nframes = set() + seen_nall = set() + with patch.object( + _data_system.dp_random, + "choice", + side_effect=sys_sequence, + ): + for _ in range(nsteps): + trainer.optimizer.zero_grad(set_to_none=True) + inp, lab = trainer.get_data(is_train=True) + seen_nframes.add(int(inp["coord"].shape[0])) + seen_nall.add(int(inp["atype"].shape[1])) + # fparam/aparam must be present in every batch + self.assertIn("fparam", inp) + self.assertIn("aparam", inp) + lr = trainer.scheduler.get_last_lr()[0] + _, loss, _ = trainer.wrapper(**inp, cur_lr=lr, label=lab) + loss.backward() + trainer.optimizer.step() + self.assertFalse(torch.isnan(loss), "loss is NaN") + self.assertFalse(torch.isinf(loss), "loss is Inf") + + # The two systems differ in both batch-size-auto and natoms, + # so both nframes and nloc should have varied across steps. 
+ self.assertGreater( + len(seen_nframes), + 1, + msg=f"nframes did not vary across steps: {seen_nframes}", + ) + self.assertGreater( + len(seen_nall), + 1, + msg=f"nloc did not vary across steps: {seen_nall}", + ) + finally: + os.chdir(old_cwd) + finally: + shutil.rmtree(tmpdir, ignore_errors=True) + + def test_compiled(self) -> None: + """Compiled training with varying nframes + fparam/aparam.""" + self._run_steps(enable_compile=True) + + def test_uncompiled(self) -> None: + """Baseline: same config, uncompiled, should also succeed.""" + self._run_steps(enable_compile=False) + + +def _create_small_system( + path: str, natoms_o: int = 2, natoms_h: int = 4, nframes: int = 10 +) -> None: + """Create a minimal deepmd data system with few atoms.""" + import numpy as np + + natoms = natoms_o + natoms_h + set_dir = os.path.join(path, "set.000") + os.makedirs(set_dir, exist_ok=True) + + with open(os.path.join(path, "type.raw"), "w") as f: + for _ in range(natoms_o): + f.write("0\n") + for _ in range(natoms_h): + f.write("1\n") + with open(os.path.join(path, "type_map.raw"), "w") as f: + f.write("O\nH\n") + + rng = np.random.default_rng(42) + box_len = 5.0 + box = np.zeros((nframes, 9), dtype=np.float32) + box[:, 0] = box_len + box[:, 4] = box_len + box[:, 8] = box_len + coord = rng.uniform(0, box_len, size=(nframes, natoms * 3)).astype(np.float32) + energy = rng.normal(-100, 10, size=(nframes,)).astype(np.float32) + force = rng.normal(0, 1, size=(nframes, natoms * 3)).astype(np.float32) + virial = rng.normal(0, 1, size=(nframes, 9)).astype(np.float32) + np.save(os.path.join(set_dir, "coord.npy"), coord) + np.save(os.path.join(set_dir, "force.npy"), force) + np.save(os.path.join(set_dir, "energy.npy"), energy) + np.save(os.path.join(set_dir, "box.npy"), box) + np.save(os.path.join(set_dir, "virial.npy"), virial) + + +class TestCompiledVaryingNatoms(unittest.TestCase): + """Test compiled training with systems of different atom counts. 
+ + Uses the 192-atom ``data_0`` alongside a synthetic 6-atom system so that + different ``nloc`` / ``nall`` appear across steps, exercising the + dynamic-shape compile path. + + ``dp_random.choice`` is mocked to alternate [0, 1, 0, 1, ...] so that + both systems are guaranteed to be sampled. + + ``batch_size: "auto"`` assigns different batch sizes per system (based + on atom count), so both ``nframes`` and ``natoms`` vary across steps. + """ + + @classmethod + def setUpClass(cls) -> None: + data_dir = os.path.join(EXAMPLE_DIR, "data") + if not os.path.isdir(data_dir): + raise unittest.SkipTest(f"Example data not found: {data_dir}") + cls.data_dir = data_dir + cls.small_data_dir = tempfile.mkdtemp(prefix="pt_expt_small_data_") + _create_small_system(cls.small_data_dir) + + @classmethod + def tearDownClass(cls) -> None: + shutil.rmtree(cls.small_data_dir, ignore_errors=True) + + def _make_varying_config( + self, + enable_compile: bool, + descriptor: dict | None = None, + ) -> dict: + """Config with two systems of different natoms and auto batch size. + + ``descriptor`` overrides the default se_e2_a descriptor when given. + """ + config = _make_config(self.data_dir) + config["training"]["training_data"]["systems"].append(self.small_data_dir) + config["training"]["training_data"]["batch_size"] = "auto" + # enable virial in loss so the model returns it (virial.npy exists in + # both systems), exercising the compiled virial passthrough on each step + config["loss"]["start_pref_v"] = 1.0 + config["loss"]["limit_pref_v"] = 1.0 + if descriptor is not None: + config["model"]["descriptor"] = descriptor + if enable_compile: + config["training"]["enable_compile"] = True + config = update_deepmd_input(config, warning=False) + config = normalize(config) + return config + + def _check_varying_natoms(self, descriptor: dict | None = None) -> None: + """Per-step compiled-vs-uncompiled comparison for the given descriptor. 
+ + The loss config has ``start_pref_f=1000`` and ``start_pref_v=1.0``, + so ``loss.backward()`` propagates through ``F = -dE/dr`` (computed + via ``autograd.grad(..., create_graph=True)``); the per-parameter + grad comparison therefore exercises the second-order derivative + ``d^2 E / (dr d theta)`` on each step at each system size. + + Verifies multi-step training-trajectory equivalence: weights are + synced once at the start, then both trainers step their own Adam + states forward. All assertions use the strict + ``atol=rtol=1e-10`` tolerance; if a descriptor's compiled path + cannot meet that on float64 the descriptor has a real numerical + problem (see the DPA1 limitation note where this happened). + """ + from deepmd.pt_expt.train.training import ( + _CompiledModel, + ) + + nsteps = 4 + # Alternate between system 0 (192 atoms) and system 1 (6 atoms) + sys_sequence = [i % 2 for i in range(nsteps)] + + tmpdir = tempfile.mkdtemp(prefix="pt_expt_varying_") + try: + old_cwd = os.getcwd() + os.chdir(tmpdir) + try: + trainer_uc = get_trainer(self._make_varying_config(False, descriptor)) + trainer_c = get_trainer(self._make_varying_config(True, descriptor)) + compiled_model = trainer_c.wrapper.model["Default"] + self.assertIsInstance(compiled_model, _CompiledModel) + + # Sync weights so predictions can be compared exactly + compiled_model.original_model.load_state_dict( + trainer_uc.model.state_dict() + ) + trainer_uc.wrapper.train() + trainer_c.wrapper.train() + + with patch( + "deepmd.utils.data_system.dp_random.choice", + side_effect=sys_sequence, + ): + for step in range(nsteps): + trainer_uc.optimizer.zero_grad(set_to_none=True) + trainer_c.optimizer.zero_grad(set_to_none=True) + + # Single shared batch; mock yields one value per call + inp, lab = trainer_uc.get_data(is_train=True) + lr = trainer_uc.scheduler.get_last_lr()[0] + + out_uc, loss_uc, _ = trainer_uc.wrapper( + **inp, cur_lr=lr, label=lab + ) + out_c, loss_c, _ = trainer_c.wrapper( + **inp, cur_lr=lr, 
label=lab + ) + + ctx = f"step={step} " + _assert_compile_predictions_match(self, out_c, out_uc, ctx=ctx) + torch.testing.assert_close( + loss_c, + loss_uc, + **_COMPILE_TOL, + msg=f"{ctx}loss mismatch", + ) + + loss_uc.backward() + loss_c.backward() + _assert_compile_grads_match( + self, + compiled_model.original_model, + trainer_uc.model, + ctx=ctx, + ) + + trainer_uc._optimizer_step() + trainer_c._optimizer_step() + finally: + os.chdir(old_cwd) + finally: + shutil.rmtree(tmpdir, ignore_errors=True) + + def test_compiled_matches_uncompiled_varying_natoms_se_e2_a(self) -> None: + """se_e2_a: compiled vs uncompiled match across varying nframes/natoms.""" + self._check_varying_natoms() # uses default se_e2_a from _make_config + + def test_compiled_matches_uncompiled_varying_natoms_dpa2(self) -> None: + """DPA2: compiled vs uncompiled match across varying nframes/natoms. + + Exercises the DPA2 repinit + repformers stack; matches at machine + epsilon (~1e-12) on float64 just like se_e2_a. + """ + self._check_varying_natoms(_DESCRIPTOR_DPA2) + + def test_compiled_matches_uncompiled_varying_natoms_dpa3(self) -> None: + """DPA3: compiled vs uncompiled match across varying nframes/natoms. + + Exercises a non-trivial multi-layer repflow descriptor; matches at + machine epsilon (~1e-12) on float64 just like se_e2_a. + """ + self._check_varying_natoms(_DESCRIPTOR_DPA3) + + if __name__ == "__main__": unittest.main() diff --git a/source/tests/pt_expt/test_training_ddp.py b/source/tests/pt_expt/test_training_ddp.py new file mode 100644 index 0000000000..38771968ac --- /dev/null +++ b/source/tests/pt_expt/test_training_ddp.py @@ -0,0 +1,1872 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Tests for distributed (DDP) training in the pt_expt backend. + +Uses ``torch.multiprocessing.spawn`` + ``gloo`` backend to verify DDP on CPU. + +Verifies that: +1. Single-task DDP training completes and produces correct outputs +2. 
Multi-task DDP training completes and produces correct outputs +3. DDP gradient averaging matches manual average of per-rank gradients +4. Multi-task DDP gradient averaging works correctly +5. Finetune + DDP: selective weight copy via _unwrapped +6. Finetune + DDP with random fitting: descriptor from pretrained, fitting random +7. Finetune + DDP with new type: exercises _unwrapped.model["Default"] + stat broadcast +8. DDP + torch.compile: single-task and multi-task compile under DDP +""" + +import os +import shutil +import socket +import tempfile +import unittest +from copy import ( + deepcopy, +) +from pathlib import ( + Path, +) + +import numpy as np +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +from deepmd.pt_expt.entrypoints.main import ( + get_trainer, +) +from deepmd.pt_expt.utils.finetune import ( + get_finetune_rules, +) +from deepmd.pt_expt.utils.multi_task import ( + preprocess_shared_params, +) +from deepmd.utils.argcheck import ( + normalize, +) +from deepmd.utils.compat import ( + update_deepmd_input, +) + +# Paths to the water data used by PT tests +_PT_DATA = str(Path(__file__).parent.parent / "pt" / "water" / "data" / "data_0") + +EXAMPLE_DIR = os.path.join( + os.path.dirname(__file__), + "..", + "..", + "..", + "examples", + "water", +) + + +def _find_free_port(): + """Find a free TCP port on localhost.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + + +def _make_config(data_dir: str, numb_steps: int = 2) -> dict: + """Build a minimal single-task config.""" + return { + "model": { + "type_map": ["O", "H"], + "descriptor": { + "type": "se_e2_a", + "sel": [6, 12], + "rcut_smth": 0.50, + "rcut": 3.00, + "neuron": [8, 16], + "resnet_dt": False, + "axis_neuron": 4, + "type_one_side": True, + "seed": 1, + }, + "fitting_net": { + "neuron": [16, 16], + "resnet_dt": True, + "seed": 1, + }, + "data_stat_nbatch": 1, + }, + "learning_rate": { + 
"type": "exp", + "decay_steps": 500, + "start_lr": 0.001, + "stop_lr": 3.51e-8, + }, + "loss": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "start_pref_v": 0, + "limit_pref_v": 0, + }, + "training": { + "training_data": { + "systems": [data_dir], + "batch_size": 1, + }, + "validation_data": { + "systems": [data_dir], + "batch_size": 1, + "numb_btch": 1, + }, + "numb_steps": numb_steps, + "seed": 10, + "disp_file": "lcurve.out", + "disp_freq": 1, + "save_freq": numb_steps, + }, + } + + +def _make_multitask_config(data_dir: str, numb_steps: int = 2) -> dict: + """Build a minimal multi-task config with shared descriptor.""" + descriptor = { + "type": "se_e2_a", + "sel": [6, 12], + "rcut_smth": 0.50, + "rcut": 3.00, + "neuron": [8, 16], + "resnet_dt": False, + "axis_neuron": 4, + "type_one_side": True, + "seed": 1, + } + fitting = { + "neuron": [16, 16], + "resnet_dt": True, + "seed": 1, + } + return { + "model": { + "shared_dict": { + "my_type_map": ["O", "H"], + "my_descriptor": deepcopy(descriptor), + }, + "model_dict": { + "model_1": { + "type_map": "my_type_map", + "descriptor": "my_descriptor", + "fitting_net": deepcopy(fitting), + "data_stat_nbatch": 1, + }, + "model_2": { + "type_map": "my_type_map", + "descriptor": "my_descriptor", + "fitting_net": deepcopy(fitting), + "data_stat_nbatch": 1, + }, + }, + }, + "learning_rate": { + "type": "exp", + "decay_steps": 500, + "start_lr": 0.001, + "stop_lr": 3.51e-8, + }, + "loss_dict": { + "model_1": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "start_pref_v": 0, + "limit_pref_v": 0, + }, + "model_2": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "start_pref_v": 0, + "limit_pref_v": 0, + }, + }, + "training": { + "model_prob": { + "model_1": 0.5, + "model_2": 0.5, + }, + "data_dict": { + "model_1": { + "stat_file": 
"./stat_files/model_1", + "training_data": { + "systems": [data_dir], + "batch_size": 1, + }, + "validation_data": { + "systems": [data_dir], + "batch_size": 1, + "numb_btch": 1, + }, + }, + "model_2": { + "stat_file": "./stat_files/model_2", + "training_data": { + "systems": [data_dir], + "batch_size": 1, + }, + "validation_data": { + "systems": [data_dir], + "batch_size": 1, + "numb_btch": 1, + }, + }, + }, + "numb_steps": numb_steps, + "seed": 10, + "disp_file": "lcurve.out", + "disp_freq": 1, + "save_freq": numb_steps, + }, + } + + +def _make_dpa1_config(data_dir: str, numb_steps: int = 2) -> dict: + """Build a minimal DPA1 config (mixed_types) for finetune new-type tests.""" + return { + "model": { + "type_map": ["O", "H"], + "descriptor": { + "type": "dpa1", + "sel": 12, + "rcut_smth": 0.50, + "rcut": 3.00, + "neuron": [4, 8], + "axis_neuron": 4, + "attn": 4, + "attn_layer": 1, + "attn_dotr": True, + "seed": 1, + }, + "fitting_net": { + "neuron": [8, 8], + "resnet_dt": True, + "seed": 1, + }, + "data_stat_nbatch": 1, + }, + "learning_rate": { + "type": "exp", + "decay_steps": 500, + "start_lr": 0.001, + "stop_lr": 3.51e-8, + }, + "loss": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "start_pref_v": 0, + "limit_pref_v": 0, + }, + "training": { + "training_data": { + "systems": [data_dir], + "batch_size": 1, + }, + "validation_data": { + "systems": [data_dir], + "batch_size": 1, + "numb_btch": 1, + }, + "numb_steps": numb_steps, + "seed": 10, + "disp_file": "lcurve.out", + "disp_freq": 1, + "save_freq": numb_steps, + }, + } + + +def _subsample_data(src_dir: str, dst_dir: str, nframes: int = 2) -> None: + """Copy a data system, keeping only the first *nframes* frames.""" + shutil.copytree(src_dir, dst_dir, dirs_exist_ok=True) + set_dir = os.path.join(dst_dir, "set.000") + for name in os.listdir(set_dir): + if name.endswith(".npy"): + arr = np.load(os.path.join(set_dir, name)) + 
np.save(os.path.join(set_dir, name), arr[:nframes]) + + +# --------------------------------------------------------------------------- +# Worker functions for mp.spawn +# --------------------------------------------------------------------------- + + +def _worker_single_task_train(rank, world_size, port, data_dir, result_dict): + """Worker: run single-task DDP training.""" + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = str(port) + os.environ["DEVICE"] = "cpu" + dist.init_process_group(backend="gloo", rank=rank, world_size=world_size) + try: + tmpdir = tempfile.mkdtemp(prefix=f"ddp_st_rank{rank}_") + old_cwd = os.getcwd() + os.chdir(tmpdir) + try: + config = _make_config(data_dir, numb_steps=2) + config = update_deepmd_input(config, warning=False) + config = normalize(config) + trainer = get_trainer(config) + trainer.run() + + # Collect results + lcurve_exists = os.path.exists("lcurve.out") + ckpt_files = [f for f in os.listdir(".") if f.endswith(".pt")] + + # Get final weights + weights = { + name: p.detach().cpu().clone() + for name, p in trainer._unwrapped.named_parameters() + } + + result_dict[rank] = { + "lcurve_exists": lcurve_exists, + "num_ckpts": len(ckpt_files), + "weights": weights, + } + finally: + os.chdir(old_cwd) + shutil.rmtree(tmpdir, ignore_errors=True) + finally: + dist.destroy_process_group() + + +def _worker_multitask_train(rank, world_size, port, data_dir, result_dict): + """Worker: run multi-task DDP training.""" + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = str(port) + os.environ["DEVICE"] = "cpu" + dist.init_process_group(backend="gloo", rank=rank, world_size=world_size) + try: + tmpdir = tempfile.mkdtemp(prefix=f"ddp_mt_rank{rank}_") + old_cwd = os.getcwd() + os.chdir(tmpdir) + try: + config = _make_multitask_config(data_dir, numb_steps=2) + config["model"], shared_links = preprocess_shared_params(config["model"]) + config = update_deepmd_input(config, warning=False) + config = 
normalize(config, multi_task=True) + trainer = get_trainer(config, shared_links=shared_links) + trainer.run() + + lcurve_exists = os.path.exists("lcurve.out") + ckpt_files = [f for f in os.listdir(".") if f.endswith(".pt")] + + # Get shared descriptor params from model_1 + desc_params = {} + for name, p in trainer._unwrapped.model[ + "model_1" + ].atomic_model.descriptor.named_parameters(): + desc_params[name] = p.detach().cpu().clone() + + result_dict[rank] = { + "lcurve_exists": lcurve_exists, + "num_ckpts": len(ckpt_files), + "desc_params": desc_params, + } + finally: + os.chdir(old_cwd) + shutil.rmtree(tmpdir, ignore_errors=True) + finally: + dist.destroy_process_group() + + +def _worker_gradient_test(rank, world_size, port, data_dir, result_dict): + """Worker: run 1 step of DDP training, collect gradients and input data.""" + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = str(port) + os.environ["DEVICE"] = "cpu" + dist.init_process_group(backend="gloo", rank=rank, world_size=world_size) + try: + tmpdir = tempfile.mkdtemp(prefix=f"ddp_grad_rank{rank}_") + old_cwd = os.getcwd() + os.chdir(tmpdir) + try: + config = _make_config(data_dir, numb_steps=1) + config["model"]["descriptor"]["precision"] = "float64" + config["model"]["fitting_net"]["precision"] = "float64" + config = update_deepmd_input(config, warning=False) + config = normalize(config) + trainer = get_trainer(config) + + # Run one forward/backward step manually + trainer.wrapper.train() + trainer.optimizer.zero_grad(set_to_none=True) + input_dict, label_dict = trainer.get_data(is_train=True, task_key="Default") + + cur_lr_sched = trainer.scheduler.get_last_lr()[0] + _, loss, _ = trainer.wrapper( + **input_dict, + cur_lr=cur_lr_sched, + label=label_dict, + ) + loss.backward() # DDP all-reduces gradients here + + # Collect post-all-reduce gradients + grads = {} + for name, p in trainer._unwrapped.named_parameters(): + if p.grad is not None: + grads[name] = 
p.grad.detach().cpu().clone() + + # Collect input batch (for single-process replay) + batch = {} + for k, v in input_dict.items(): + if isinstance(v, torch.Tensor): + batch[k] = v.detach().cpu().clone() + else: + batch[k] = v + for k, v in label_dict.items(): + if isinstance(v, torch.Tensor): + batch[f"label_{k}"] = v.detach().cpu().clone() + + # Initial model state dict (before any optimizer step) + init_state = { + k: v.detach().cpu().clone() + for k, v in trainer._unwrapped.state_dict().items() + if k != "_extra_state" + } + + result_dict[rank] = { + "grads": grads, + "batch": batch, + "init_state": init_state, + "config": config, + } + finally: + os.chdir(old_cwd) + shutil.rmtree(tmpdir, ignore_errors=True) + finally: + dist.destroy_process_group() + + +def _worker_multitask_gradient_test(rank, world_size, port, data_dir, result_dict): + """Worker: run 1 step of multi-task DDP training, collect gradients.""" + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = str(port) + os.environ["DEVICE"] = "cpu" + dist.init_process_group(backend="gloo", rank=rank, world_size=world_size) + try: + tmpdir = tempfile.mkdtemp(prefix=f"ddp_mt_grad_rank{rank}_") + old_cwd = os.getcwd() + os.chdir(tmpdir) + try: + config = _make_multitask_config(data_dir, numb_steps=1) + config["model"], shared_links = preprocess_shared_params(config["model"]) + config = update_deepmd_input(config, warning=False) + config = normalize(config, multi_task=True) + trainer = get_trainer(config, shared_links=shared_links) + + # Run one step with deterministic task selection + + # Force task_key = "model_1" for all ranks (deterministic) + trainer.wrapper.train() + trainer.optimizer.zero_grad(set_to_none=True) + input_dict, label_dict = trainer.get_data(is_train=True, task_key="model_1") + cur_lr_sched = trainer.scheduler.get_last_lr()[0] + _, loss, _ = trainer.wrapper( + **input_dict, + cur_lr=cur_lr_sched, + label=label_dict, + task_key="model_1", + ) + loss.backward() + + grads = {} 
+ for name, p in trainer._unwrapped.named_parameters(): + if p.grad is not None: + grads[name] = p.grad.detach().cpu().clone() + + result_dict[rank] = { + "grads": grads, + } + finally: + os.chdir(old_cwd) + shutil.rmtree(tmpdir, ignore_errors=True) + finally: + dist.destroy_process_group() + + +def _worker_check_resume( + rank, world_size, port, data_dir, ckpt_path, numb_steps, is_restart, result_dict +): + """Worker: build DDP trainer from checkpoint, capture initial state, then train. + + Parameters + ---------- + is_restart : bool + True → restart_model (continue training, restore optimizer & step). + False → init_model (inherit weights, reset step to 0). + """ + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = str(port) + os.environ["DEVICE"] = "cpu" + dist.init_process_group(backend="gloo", rank=rank, world_size=world_size) + try: + tmpdir = tempfile.mkdtemp(prefix=f"ddp_resume_rank{rank}_") + old_cwd = os.getcwd() + os.chdir(tmpdir) + try: + config = _make_config(data_dir, numb_steps=numb_steps) + config = update_deepmd_input(config, warning=False) + config = normalize(config) + + if is_restart: + trainer = get_trainer(config, restart_model=ckpt_path) + else: + trainer = get_trainer(config, init_model=ckpt_path) + + # Capture initial state BEFORE training + init_weights = { + name: p.detach().cpu().clone() + for name, p in trainer._unwrapped.named_parameters() + } + start_step = trainer.start_step + init_lr = trainer.scheduler.get_last_lr()[0] + + trainer.run() + + result_dict[rank] = { + "init_weights": init_weights, + "start_step": start_step, + "init_lr": init_lr, + "lcurve_exists": os.path.exists("lcurve.out"), + } + finally: + os.chdir(old_cwd) + shutil.rmtree(tmpdir, ignore_errors=True) + finally: + dist.destroy_process_group() + + +def _worker_finetune( + rank, world_size, port, ckpt_path, config_dict, model_branch, result_dict +): + """Worker: DDP finetune from checkpoint. 
+ + Parameters + ---------- + ckpt_path : str + Absolute path to pretrained checkpoint (.pt). + config_dict : dict + Already normalized config with absolute data paths. + model_branch : str or None + ``"RANDOM"`` for random fitting, ``None`` for normal. + """ + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = str(port) + os.environ["DEVICE"] = "cpu" + dist.init_process_group(backend="gloo", rank=rank, world_size=world_size) + try: + tmpdir = tempfile.mkdtemp(prefix=f"ddp_ft_rank{rank}_") + old_cwd = os.getcwd() + os.chdir(tmpdir) + try: + config = deepcopy(config_dict) + config["model"], finetune_links = get_finetune_rules( + ckpt_path, + config["model"], + model_branch=model_branch or "", + ) + + trainer = get_trainer( + config, + finetune_model=ckpt_path, + finetune_links=finetune_links, + ) + + # Capture state after finetune setup (before training) + init_state = { + k: v.detach().cpu().clone() + for k, v in trainer._unwrapped.state_dict().items() + if k != "_extra_state" + } + + trainer.run() + + result_dict[rank] = { + "init_state": init_state, + "lcurve_exists": os.path.exists("lcurve.out"), + "ckpt_files": [f for f in os.listdir(".") if f.endswith(".pt")], + } + finally: + os.chdir(old_cwd) + shutil.rmtree(tmpdir, ignore_errors=True) + finally: + dist.destroy_process_group() + + +# --------------------------------------------------------------------------- +# Test classes +# --------------------------------------------------------------------------- + + +class TestDDPSingleTaskTrain(unittest.TestCase): + """Smoke test: single-task DDP training with 2 ranks.""" + + @classmethod + def setUpClass(cls) -> None: + data_dir = os.path.join(EXAMPLE_DIR, "data") + if not os.path.isdir(data_dir): + raise unittest.SkipTest(f"Example data not found: {data_dir}") + cls.data_dir = os.path.join(data_dir, "data_0") + + def test_ddp_single_task_trains(self) -> None: + """2 ranks, se_e2_a, 2 training steps — verify completion and outputs.""" + port = 
_find_free_port() + result_dict = mp.Manager().dict() + mp.spawn( + _worker_single_task_train, + args=(2, port, self.data_dir, result_dict), + nprocs=2, + join=True, + ) + results = dict(result_dict) + + # Only rank 0 produces lcurve.out and checkpoints + self.assertTrue(results[0]["lcurve_exists"], "rank 0 should produce lcurve.out") + self.assertFalse( + results[1]["lcurve_exists"], "rank 1 should NOT produce lcurve.out" + ) + self.assertGreater(results[0]["num_ckpts"], 0, "rank 0 should save checkpoints") + self.assertEqual( + results[1]["num_ckpts"], 0, "rank 1 should NOT save checkpoints" + ) + + # Final weights should be identical across ranks + for name in results[0]["weights"]: + torch.testing.assert_close( + results[0]["weights"][name], + results[1]["weights"][name], + msg=f"Weights differ across ranks: {name}", + ) + + +class TestDDPMultiTaskTrain(unittest.TestCase): + """Smoke test: multi-task DDP training with 2 ranks.""" + + @classmethod + def setUpClass(cls) -> None: + if not os.path.isdir(_PT_DATA): + raise unittest.SkipTest(f"Test data not found: {_PT_DATA}") + cls.data_dir = _PT_DATA + + def test_ddp_multitask_trains(self) -> None: + """2 ranks, multi-task, 2 steps — verify completion.""" + port = _find_free_port() + result_dict = mp.Manager().dict() + mp.spawn( + _worker_multitask_train, + args=(2, port, self.data_dir, result_dict), + nprocs=2, + join=True, + ) + results = dict(result_dict) + + # Only rank 0 produces output files + self.assertTrue(results[0]["lcurve_exists"]) + self.assertFalse(results[1]["lcurve_exists"]) + self.assertGreater(results[0]["num_ckpts"], 0) + self.assertEqual(results[1]["num_ckpts"], 0) + + # Shared descriptor params should be identical across ranks + for name in results[0]["desc_params"]: + torch.testing.assert_close( + results[0]["desc_params"][name], + results[1]["desc_params"][name], + msg=f"Shared descriptor param differs across ranks: {name}", + ) + + +class TestDDPGradientAveraging(unittest.TestCase): + 
"""Core DDP correctness: gradient averaging matches manual computation. + + Each DDP rank processes different data. After all-reduce, all ranks have + the averaged gradient. We verify: + 1. Both ranks have identical gradients (DDP guarantee) + 2. The DDP gradient equals the average of per-rank single-process gradients + """ + + @classmethod + def setUpClass(cls) -> None: + data_dir = os.path.join(EXAMPLE_DIR, "data") + if not os.path.isdir(data_dir): + raise unittest.SkipTest(f"Example data not found: {data_dir}") + cls.data_dir = os.path.join(data_dir, "data_0") + + def test_ddp_gradient_equals_average(self) -> None: + port = _find_free_port() + result_dict = mp.Manager().dict() + mp.spawn( + _worker_gradient_test, + args=(2, port, self.data_dir, result_dict), + nprocs=2, + join=True, + ) + results = dict(result_dict) + r0, r1 = results[0], results[1] + + # 1. Verify gradients are identical on both ranks (DDP guarantee) + for name in r0["grads"]: + self.assertIn(name, r1["grads"], f"Grad key missing on rank 1: {name}") + torch.testing.assert_close( + r0["grads"][name], + r1["grads"][name], + atol=0, + rtol=0, + msg=f"Gradients should be identical across ranks: {name}", + ) + + # 2. 
Rebuild model in single process, replay each rank's batch, + # compute manual average, compare to DDP gradient + config = r0["config"] + tmpdir = tempfile.mkdtemp(prefix="ddp_grad_verify_") + old_cwd = os.getcwd() + os.chdir(tmpdir) + try: + trainer = get_trainer(config) + # Load same initial state as DDP workers + state_to_load = dict(trainer._unwrapped.state_dict()) + for k in r0["init_state"]: + state_to_load[k] = r0["init_state"][k] + trainer._unwrapped.load_state_dict(state_to_load) + trainer.wrapper.train() + + # Forward+backward with rank 0's batch + trainer.optimizer.zero_grad(set_to_none=True) + input_0 = { + k: v.clone() + for k, v in r0["batch"].items() + if not k.startswith("label_") + } + label_0 = { + k[len("label_") :]: v.clone() + for k, v in r0["batch"].items() + if k.startswith("label_") + } + input_0["coord"] = input_0["coord"].requires_grad_(True) + cur_lr = trainer.scheduler.get_last_lr()[0] + _, loss_0, _ = trainer.wrapper(**input_0, cur_lr=cur_lr, label=label_0) + loss_0.backward() + grad_0 = { + name: p.grad.detach().clone() + for name, p in trainer._unwrapped.named_parameters() + if p.grad is not None + } + + # Forward+backward with rank 1's batch + trainer.optimizer.zero_grad(set_to_none=True) + input_1 = { + k: v.clone() + for k, v in r1["batch"].items() + if not k.startswith("label_") + } + label_1 = { + k[len("label_") :]: v.clone() + for k, v in r1["batch"].items() + if k.startswith("label_") + } + input_1["coord"] = input_1["coord"].requires_grad_(True) + _, loss_1, _ = trainer.wrapper(**input_1, cur_lr=cur_lr, label=label_1) + loss_1.backward() + grad_1 = { + name: p.grad.detach().clone() + for name, p in trainer._unwrapped.named_parameters() + if p.grad is not None + } + + # Expected = average of the two + for name in r0["grads"]: + if name in grad_0 and name in grad_1: + expected = (grad_0[name] + grad_1[name]) / 2.0 + torch.testing.assert_close( + r0["grads"][name], + expected, + atol=1e-10, + rtol=1e-10, + msg=f"DDP grad != 
avg(rank0, rank1) for {name}", + ) + finally: + os.chdir(old_cwd) + shutil.rmtree(tmpdir, ignore_errors=True) + + +class TestDDPMultiTaskGradient(unittest.TestCase): + """Verify DDP gradient averaging with multi-task training.""" + + @classmethod + def setUpClass(cls) -> None: + if not os.path.isdir(_PT_DATA): + raise unittest.SkipTest(f"Test data not found: {_PT_DATA}") + cls.data_dir = _PT_DATA + + def test_ddp_multitask_gradient(self) -> None: + """Both ranks pick same task; gradients should be identical after all-reduce.""" + port = _find_free_port() + result_dict = mp.Manager().dict() + mp.spawn( + _worker_multitask_gradient_test, + args=(2, port, self.data_dir, result_dict), + nprocs=2, + join=True, + ) + results = dict(result_dict) + r0, r1 = results[0], results[1] + + # Gradients should be identical across ranks + for name in r0["grads"]: + self.assertIn(name, r1["grads"], f"Grad key missing on rank 1: {name}") + torch.testing.assert_close( + r0["grads"][name], + r1["grads"][name], + atol=0, + rtol=0, + msg=f"Multi-task DDP gradients differ across ranks: {name}", + ) + + +class _DDPResumeBase(unittest.TestCase): + """Shared setup: train 2 steps in single process, save checkpoint + weights.""" + + @classmethod + def setUpClass(cls) -> None: + data_dir = os.path.join(EXAMPLE_DIR, "data") + if not os.path.isdir(data_dir): + raise unittest.SkipTest(f"Example data not found: {data_dir}") + cls.data_dir = os.path.join(data_dir, "data_0") + + cls._tmpdir = tempfile.mkdtemp(prefix="ddp_resume_setup_") + old_cwd = os.getcwd() + os.chdir(cls._tmpdir) + try: + config = _make_config(cls.data_dir, numb_steps=2) + config = update_deepmd_input(config, warning=False) + config = normalize(config) + trainer = get_trainer(config) + trainer.run() + + cls.ckpt_path = os.path.join(cls._tmpdir, "model.ckpt.pt") + assert os.path.exists(cls.ckpt_path), "Checkpoint not created" + + # Record phase-1 final weights for comparison + cls.phase1_weights = { + name: 
p.detach().cpu().clone() + for name, p in trainer.wrapper.named_parameters() + } + cls.lr_config = config["learning_rate"].copy() + finally: + os.chdir(old_cwd) + + @classmethod + def tearDownClass(cls) -> None: + shutil.rmtree(cls._tmpdir, ignore_errors=True) + + +class TestDDPInitModel(_DDPResumeBase): + """DDP init_model: inherits weights but resets step to 0.""" + + def test_ddp_init_model(self) -> None: + port = _find_free_port() + result_dict = mp.Manager().dict() + mp.spawn( + _worker_check_resume, + args=( + 2, + port, + self.data_dir, + self.ckpt_path, + 2, # numb_steps: train 2 fresh steps from step 0 + False, # is_restart=False → init_model + result_dict, + ), + nprocs=2, + join=True, + ) + results = dict(result_dict) + r0, r1 = results[0], results[1] + + from deepmd.dpmodel.utils.learning_rate import ( + LearningRateExp, + ) + + # init_model resets step to 0 + self.assertEqual(r0["start_step"], 0) + self.assertEqual(r1["start_step"], 0) + + # LR should be lr_schedule(0), i.e. 
start_lr + lr_params = self.lr_config.copy() + lr_params["num_steps"] = 2 # init_model config uses numb_steps=2 + expected_lr = LearningRateExp(**lr_params).value(0) + self.assertAlmostEqual(r0["init_lr"], expected_lr, places=10) + + # Only rank 0 produces lcurve + self.assertTrue(r0["lcurve_exists"]) + self.assertFalse(r1["lcurve_exists"]) + + # Initial weights (after checkpoint load) must match phase-1 final weights + for name in self.phase1_weights: + self.assertIn(name, r0["init_weights"], f"Missing param: {name}") + torch.testing.assert_close( + r0["init_weights"][name], + self.phase1_weights[name], + msg=f"init_model did not inherit weights correctly: {name}", + ) + + # Initial weights identical across ranks + for name in r0["init_weights"]: + torch.testing.assert_close( + r0["init_weights"][name], + r1["init_weights"][name], + msg=f"init_model weights differ across ranks: {name}", + ) + + +class TestDDPRestart(_DDPResumeBase): + """DDP restart: continues training from saved step with restored optimizer.""" + + def test_ddp_restart(self) -> None: + port = _find_free_port() + result_dict = mp.Manager().dict() + mp.spawn( + _worker_check_resume, + args=( + 2, + port, + self.data_dir, + self.ckpt_path, + 4, # numb_steps: continue to step 4 + True, # is_restart=True → restart_model + result_dict, + ), + nprocs=2, + join=True, + ) + results = dict(result_dict) + r0, r1 = results[0], results[1] + + from deepmd.dpmodel.utils.learning_rate import ( + LearningRateExp, + ) + + # restart restores the step counter + self.assertEqual(r0["start_step"], 2) + self.assertEqual(r1["start_step"], 2) + + # LR should be lr_schedule(2) with num_steps=4 (the restart config) + lr_params = self.lr_config.copy() + lr_params["num_steps"] = 4 # restart config uses numb_steps=4 + lr_sched = LearningRateExp(**lr_params) + expected_lr = lr_sched.value(2) + start_lr = lr_sched.value(0) + self.assertAlmostEqual(r0["init_lr"], expected_lr, places=10) + # Verify it is NOT equal to start_lr 
(i.e. the LR actually decayed) + self.assertNotAlmostEqual( + r0["init_lr"], + start_lr, + places=10, + msg="restart LR should differ from start_lr", + ) + + # Only rank 0 produces lcurve + self.assertTrue(r0["lcurve_exists"]) + self.assertFalse(r1["lcurve_exists"]) + + # Initial weights (after checkpoint load) must match phase-1 final weights + for name in self.phase1_weights: + self.assertIn(name, r0["init_weights"], f"Missing param: {name}") + torch.testing.assert_close( + r0["init_weights"][name], + self.phase1_weights[name], + msg=f"restart did not load weights correctly: {name}", + ) + + # Initial weights identical across ranks + for name in r0["init_weights"]: + torch.testing.assert_close( + r0["init_weights"][name], + r1["init_weights"][name], + msg=f"restart weights differ across ranks: {name}", + ) + + +# --------------------------------------------------------------------------- +# Finetune + DDP tests +# --------------------------------------------------------------------------- + + +class _DDPFinetuneBase(unittest.TestCase): + """Shared setup: train pretrained se_e2_a model, save checkpoint + weights.""" + + @classmethod + def setUpClass(cls) -> None: + data_dir = os.path.join(EXAMPLE_DIR, "data") + if not os.path.isdir(data_dir): + raise unittest.SkipTest(f"Example data not found: {data_dir}") + cls.data_dir = os.path.join(data_dir, "data_0") + + cls._tmpdir = tempfile.mkdtemp(prefix="ddp_ft_setup_") + old_cwd = os.getcwd() + os.chdir(cls._tmpdir) + try: + config = _make_config(cls.data_dir, numb_steps=1) + config = update_deepmd_input(config, warning=False) + config = normalize(config) + trainer = get_trainer(config) + trainer.run() + + cls.ckpt_path = os.path.join(cls._tmpdir, "model.ckpt.pt") + assert os.path.exists(cls.ckpt_path), "Pretrained checkpoint not created" + + # Save pretrained state for comparison (excluding _extra_state) + state = torch.load(cls.ckpt_path, map_location="cpu", weights_only=True) + model_state = state["model"] if "model" 
in state else state + cls.pretrained_state = { + k: v.clone() for k, v in model_state.items() if k != "_extra_state" + } + finally: + os.chdir(old_cwd) + + @classmethod + def tearDownClass(cls) -> None: + shutil.rmtree(cls._tmpdir, ignore_errors=True) + + +class TestDDPFinetune(_DDPFinetuneBase): + """DDP finetune: same type_map, descriptor + fitting from pretrained.""" + + def test_ddp_finetune(self) -> None: + port = _find_free_port() + result_dict = mp.Manager().dict() + + config = _make_config(self.data_dir, numb_steps=2) + config = update_deepmd_input(config, warning=False) + config = normalize(config) + + mp.spawn( + _worker_finetune, + args=(2, port, self.ckpt_path, config, None, result_dict), + nprocs=2, + join=True, + ) + results = dict(result_dict) + r0, r1 = results[0], results[1] + + # Only rank 0 writes output + self.assertTrue(r0["lcurve_exists"], "rank 0 should produce lcurve.out") + self.assertFalse(r1["lcurve_exists"], "rank 1 should NOT produce lcurve.out") + self.assertGreater(len(r0["ckpt_files"]), 0, "rank 0 should save checkpoints") + self.assertEqual(len(r1["ckpt_files"]), 0, "rank 1 should NOT save checkpoints") + + # Descriptor + fitting weights must match pretrained + for key in self.pretrained_state: + if key in r0["init_state"] and (".descriptor." 
in key or ".fitting" in key): + torch.testing.assert_close( + r0["init_state"][key], + self.pretrained_state[key], + msg=f"Weight should match pretrained: {key}", + ) + + # Init state identical across ranks + for name in r0["init_state"]: + torch.testing.assert_close( + r0["init_state"][name], + r1["init_state"][name], + msg=f"Finetune init state differs across ranks: {name}", + ) + + +class TestDDPFinetuneRandomFitting(_DDPFinetuneBase): + """DDP finetune with random fitting: descriptor from pretrained, fitting random.""" + + def test_ddp_finetune_random_fitting(self) -> None: + port = _find_free_port() + result_dict = mp.Manager().dict() + + config = _make_config(self.data_dir, numb_steps=2) + config = update_deepmd_input(config, warning=False) + config = normalize(config) + + mp.spawn( + _worker_finetune, + args=(2, port, self.ckpt_path, config, "RANDOM", result_dict), + nprocs=2, + join=True, + ) + results = dict(result_dict) + r0, r1 = results[0], results[1] + + # Descriptor weights must match pretrained + for key in self.pretrained_state: + if key in r0["init_state"] and ".descriptor." 
in key: + torch.testing.assert_close( + r0["init_state"][key], + self.pretrained_state[key], + msg=f"Descriptor weight should match pretrained: {key}", + ) + + # Fitting weights should NOT match pretrained (random init) + # bias_atom_e is set by bias adjustment, not random init — skip it + has_fitting_diff = False + for key in self.pretrained_state: + if ( + key in r0["init_state"] + and ".fitting" in key + and "bias_atom_e" not in key + and r0["init_state"][key].is_floating_point() + ): + if not torch.equal(r0["init_state"][key], self.pretrained_state[key]): + has_fitting_diff = True + self.assertTrue( + has_fitting_diff, "Random fitting should produce different weights" + ) + + # Init state identical across ranks + for name in r0["init_state"]: + torch.testing.assert_close( + r0["init_state"][name], + r1["init_state"][name], + msg=f"Finetune random fitting state differs across ranks: {name}", + ) + + +class TestDDPFinetuneNewType(unittest.TestCase): + """DDP finetune with type_map change (new type). + + Exercises the ``_unwrapped.model["Default"]`` path (line 712) when + ``finetune_rule.get_has_new_type()`` is True, plus stat recomputation + and broadcast for the new type. Uses DPA1 (mixed_types) which supports + ``change_type_map``. 
+ """ + + @classmethod + def setUpClass(cls) -> None: + data_dir = os.path.join(EXAMPLE_DIR, "data") + if not os.path.isdir(data_dir): + raise unittest.SkipTest(f"Example data not found: {data_dir}") + raw_data = os.path.join(data_dir, "data_0") + + # Subsample data for faster DPA1 test + cls._data_tmpdir = tempfile.mkdtemp(prefix="ddp_ft_nt_data_") + _subsample_data(raw_data, os.path.join(cls._data_tmpdir, "data_0")) + cls.data_dir = os.path.join(cls._data_tmpdir, "data_0") + + # Train pretrained DPA1 with type_map=["O", "H"] + cls._train_tmpdir = tempfile.mkdtemp(prefix="ddp_ft_nt_train_") + old_cwd = os.getcwd() + os.chdir(cls._train_tmpdir) + try: + config = _make_dpa1_config(cls.data_dir, numb_steps=1) + config = update_deepmd_input(config, warning=False) + config = normalize(config) + trainer = get_trainer(config) + trainer.run() + + cls.ckpt_path = os.path.join(cls._train_tmpdir, "model.ckpt.pt") + assert os.path.exists(cls.ckpt_path), ( + "DPA1 pretrained checkpoint not created" + ) + finally: + os.chdir(old_cwd) + + @classmethod + def tearDownClass(cls) -> None: + shutil.rmtree(cls._data_tmpdir, ignore_errors=True) + shutil.rmtree(cls._train_tmpdir, ignore_errors=True) + + def test_ddp_finetune_new_type(self) -> None: + """Finetune DPA1 from ["O","H"] to ["O","H","B"] under DDP.""" + port = _find_free_port() + result_dict = mp.Manager().dict() + + # Finetune config with new type "B" added + config = _make_dpa1_config(self.data_dir, numb_steps=2) + config["model"]["type_map"] = ["O", "H", "B"] + config = update_deepmd_input(config, warning=False) + config = normalize(config) + + mp.spawn( + _worker_finetune, + args=(2, port, self.ckpt_path, config, None, result_dict), + nprocs=2, + join=True, + ) + results = dict(result_dict) + r0, r1 = results[0], results[1] + + # Training completes without error + self.assertTrue(r0["lcurve_exists"], "rank 0 should produce lcurve.out") + self.assertFalse(r1["lcurve_exists"], "rank 1 should NOT produce lcurve.out") + + # 
Init state identical across ranks (stat broadcast worked) + for name in r0["init_state"]: + torch.testing.assert_close( + r0["init_state"][name], + r1["init_state"][name], + msg=f"Finetune new_type init state differs across ranks: {name}", + ) + + +def _make_dpa1_multitask_config( + data_dir: str, numb_steps: int = 2, type_map: list | None = None +) -> dict: + """Build a minimal multi-task DPA1 config (mixed_types) for finetune tests.""" + if type_map is None: + type_map = ["O", "H"] + descriptor = { + "type": "dpa1", + "sel": 12, + "rcut_smth": 0.50, + "rcut": 3.00, + "neuron": [4, 8], + "axis_neuron": 4, + "attn": 4, + "attn_layer": 1, + "attn_dotr": True, + "seed": 1, + } + fitting = { + "neuron": [8, 8], + "resnet_dt": True, + "seed": 1, + } + return { + "model": { + "shared_dict": { + "my_type_map": list(type_map), + "my_descriptor": deepcopy(descriptor), + }, + "model_dict": { + "model_1": { + "type_map": "my_type_map", + "descriptor": "my_descriptor", + "fitting_net": deepcopy(fitting), + "data_stat_nbatch": 1, + }, + "model_2": { + "type_map": "my_type_map", + "descriptor": "my_descriptor", + "fitting_net": deepcopy(fitting), + "data_stat_nbatch": 1, + }, + }, + }, + "learning_rate": { + "type": "exp", + "decay_steps": 500, + "start_lr": 0.001, + "stop_lr": 3.51e-8, + }, + "loss_dict": { + "model_1": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "start_pref_v": 0, + "limit_pref_v": 0, + }, + "model_2": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "start_pref_v": 0, + "limit_pref_v": 0, + }, + }, + "training": { + "model_prob": { + "model_1": 0.5, + "model_2": 0.5, + }, + "data_dict": { + "model_1": { + "stat_file": "./stat_files/model_1", + "training_data": { + "systems": [data_dir], + "batch_size": 1, + }, + "validation_data": { + "systems": [data_dir], + "batch_size": 1, + "numb_btch": 1, + }, + }, + "model_2": { + 
"stat_file": "./stat_files/model_2", + "training_data": { + "systems": [data_dir], + "batch_size": 1, + }, + "validation_data": { + "systems": [data_dir], + "batch_size": 1, + "numb_btch": 1, + }, + }, + }, + "numb_steps": numb_steps, + "seed": 10, + "disp_file": "lcurve.out", + "disp_freq": 1, + "save_freq": numb_steps, + }, + } + + +def _worker_multitask_finetune( + rank, world_size, port, data_dir, ckpt_path, finetune_config, result_dict +): + """Worker: DDP multi-task finetune from checkpoint.""" + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = str(port) + os.environ["DEVICE"] = "cpu" + dist.init_process_group(backend="gloo", rank=rank, world_size=world_size) + try: + tmpdir = tempfile.mkdtemp(prefix=f"ddp_mt_ft_rank{rank}_") + old_cwd = os.getcwd() + os.chdir(tmpdir) + try: + config = deepcopy(finetune_config) + config["model"], shared_links = preprocess_shared_params(config["model"]) + config = update_deepmd_input(config, warning=False) + config = normalize(config, multi_task=True) + config["model"], finetune_links = get_finetune_rules( + ckpt_path, config["model"] + ) + trainer = get_trainer( + config, + finetune_model=ckpt_path, + finetune_links=finetune_links, + shared_links=shared_links, + ) + # Capture init state before training + init_state = { + k: v.detach().cpu().clone() + for k, v in trainer._unwrapped.state_dict().items() + if k != "_extra_state" + } + trainer.run() + result_dict[rank] = { + "init_state": init_state, + "lcurve_exists": os.path.exists("lcurve.out"), + "ckpt_files": [f for f in os.listdir(".") if f.endswith(".pt")], + } + finally: + os.chdir(old_cwd) + shutil.rmtree(tmpdir, ignore_errors=True) + finally: + dist.destroy_process_group() + + +def _worker_single_task_compile_train(rank, world_size, port, data_dir, result_dict): + """Worker: run single-task DDP training with torch.compile enabled. 
+ + This exercises the ``_compile_model`` code path under DDP, which must + unwrap ``DistributedDataParallel`` to access ``wrapper.module.model``. + Before the fix, ``self.wrapper.model[task_key]`` raised ``AttributeError`` + because ``DistributedDataParallel`` does not expose ``.model`` directly. + """ + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = str(port) + os.environ["DEVICE"] = "cpu" + dist.init_process_group(backend="gloo", rank=rank, world_size=world_size) + try: + tmpdir = tempfile.mkdtemp(prefix=f"ddp_compile_st_rank{rank}_") + old_cwd = os.getcwd() + os.chdir(tmpdir) + try: + config = _make_config(data_dir, numb_steps=2) + config["training"]["enable_compile"] = True + config = update_deepmd_input(config, warning=False) + config = normalize(config) + trainer = get_trainer(config) + trainer.run() + + from deepmd.pt_expt.train.training import ( + _CompiledModel, + ) + + # Check the compiled model is a _CompiledModel + is_compiled = isinstance( + trainer._unwrapped.model["Default"], _CompiledModel + ) + + lcurve_exists = os.path.exists("lcurve.out") + ckpt_files = [f for f in os.listdir(".") if f.endswith(".pt")] + + weights = { + name: p.detach().cpu().clone() + for name, p in trainer._unwrapped.named_parameters() + } + + result_dict[rank] = { + "lcurve_exists": lcurve_exists, + "num_ckpts": len(ckpt_files), + "weights": weights, + "is_compiled": is_compiled, + } + finally: + os.chdir(old_cwd) + shutil.rmtree(tmpdir, ignore_errors=True) + finally: + dist.destroy_process_group() + + +def _worker_multitask_compile_train(rank, world_size, port, data_dir, result_dict): + """Worker: run multi-task DDP training with torch.compile enabled. + + Exercises the per-branch compilation loop in ``_compile_model`` under DDP. 
+ """ + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = str(port) + os.environ["DEVICE"] = "cpu" + dist.init_process_group(backend="gloo", rank=rank, world_size=world_size) + try: + tmpdir = tempfile.mkdtemp(prefix=f"ddp_compile_mt_rank{rank}_") + old_cwd = os.getcwd() + os.chdir(tmpdir) + try: + config = _make_multitask_config(data_dir, numb_steps=2) + config["training"]["enable_compile"] = True + config["model"], shared_links = preprocess_shared_params(config["model"]) + config = update_deepmd_input(config, warning=False) + config = normalize(config, multi_task=True) + trainer = get_trainer(config, shared_links=shared_links) + trainer.run() + + from deepmd.pt_expt.train.training import ( + _CompiledModel, + ) + + # Check both branch models are compiled + compiled_flags = {} + for mk in ("model_1", "model_2"): + compiled_flags[mk] = isinstance( + trainer._unwrapped.model[mk], _CompiledModel + ) + + lcurve_exists = os.path.exists("lcurve.out") + ckpt_files = [f for f in os.listdir(".") if f.endswith(".pt")] + + # Get shared descriptor params from model_1 + desc_params = {} + for name, p in trainer._unwrapped.model[ + "model_1" + ].original_model.atomic_model.descriptor.named_parameters(): + desc_params[name] = p.detach().cpu().clone() + + result_dict[rank] = { + "lcurve_exists": lcurve_exists, + "num_ckpts": len(ckpt_files), + "desc_params": desc_params, + "compiled_flags": compiled_flags, + } + finally: + os.chdir(old_cwd) + shutil.rmtree(tmpdir, ignore_errors=True) + finally: + dist.destroy_process_group() + + +class TestDDPCompileSingleTask(unittest.TestCase): + """DDP + torch.compile: single-task training with 2 ranks. + + Exercises ``_compile_model`` under DDP, which requires unwrapping + ``DistributedDataParallel`` to access ``wrapper.module.model``. 
+ """ + + @classmethod + def setUpClass(cls) -> None: + data_dir = os.path.join(EXAMPLE_DIR, "data") + if not os.path.isdir(data_dir): + raise unittest.SkipTest(f"Example data not found: {data_dir}") + cls.data_dir = os.path.join(data_dir, "data_0") + + def test_ddp_compile_single_task(self) -> None: + """2 ranks, se_e2_a, enable_compile=True, 2 steps.""" + port = _find_free_port() + result_dict = mp.Manager().dict() + mp.spawn( + _worker_single_task_compile_train, + args=(2, port, self.data_dir, result_dict), + nprocs=2, + join=True, + ) + results = dict(result_dict) + + # Both ranks have compiled models + self.assertTrue(results[0]["is_compiled"], "rank 0 model should be compiled") + self.assertTrue(results[1]["is_compiled"], "rank 1 model should be compiled") + + # Only rank 0 produces output files + self.assertTrue(results[0]["lcurve_exists"]) + self.assertFalse(results[1]["lcurve_exists"]) + self.assertGreater(results[0]["num_ckpts"], 0) + self.assertEqual(results[1]["num_ckpts"], 0) + + # Final weights identical across ranks + for name in results[0]["weights"]: + torch.testing.assert_close( + results[0]["weights"][name], + results[1]["weights"][name], + msg=f"Compiled DDP weights differ across ranks: {name}", + ) + + +class TestDDPCompileMultiTask(unittest.TestCase): + """DDP + torch.compile: multi-task training with 2 ranks. + + Exercises the per-branch compilation loop in ``_compile_model`` under DDP. 
+ """ + + @classmethod + def setUpClass(cls) -> None: + if not os.path.isdir(_PT_DATA): + raise unittest.SkipTest(f"Test data not found: {_PT_DATA}") + cls.data_dir = _PT_DATA + + def test_ddp_compile_multitask(self) -> None: + """2 ranks, multi-task, enable_compile=True, 2 steps.""" + port = _find_free_port() + result_dict = mp.Manager().dict() + mp.spawn( + _worker_multitask_compile_train, + args=(2, port, self.data_dir, result_dict), + nprocs=2, + join=True, + ) + results = dict(result_dict) + + # Both ranks have compiled models for both branches + for mk in ("model_1", "model_2"): + self.assertTrue( + results[0]["compiled_flags"][mk], + f"rank 0 {mk} should be compiled", + ) + self.assertTrue( + results[1]["compiled_flags"][mk], + f"rank 1 {mk} should be compiled", + ) + + # Only rank 0 produces output files + self.assertTrue(results[0]["lcurve_exists"]) + self.assertFalse(results[1]["lcurve_exists"]) + self.assertGreater(results[0]["num_ckpts"], 0) + self.assertEqual(results[1]["num_ckpts"], 0) + + # Shared descriptor params identical across ranks + for name in results[0]["desc_params"]: + torch.testing.assert_close( + results[0]["desc_params"][name], + results[1]["desc_params"][name], + msg=f"Compiled DDP shared descriptor param differs: {name}", + ) + + +class TestDDPMultiTaskFinetune(unittest.TestCase): + """DDP multi-task finetune with type_map change (new type). + + Trains a 2-branch multi-task DPA1 model with type_map ["O","H"], then + finetunes 4 branches with extended type_map ["O","H","B"] under DDP. + Builds a reference state_dict by manually replicating the trainer's + finetune operations (load pretrained, change_type_map with computed + model_with_new_type_stat, weight copy) to verify correctness. 
+ """ + + @classmethod + def setUpClass(cls) -> None: + from deepmd.pt_expt.model import ( + get_model, + ) + from deepmd.pt_expt.train.wrapper import ( + ModelWrapper, + ) + from deepmd.pt_expt.utils.env import ( + DEVICE, + ) + from deepmd.pt_expt.utils.stat import ( + make_stat_input, + ) + from deepmd.utils.data import ( + DataRequirementItem, + ) + from deepmd.utils.data_system import ( + DeepmdDataSystem, + process_systems, + ) + + data_dir = os.path.join(EXAMPLE_DIR, "data") + if not os.path.isdir(data_dir): + raise unittest.SkipTest(f"Example data not found: {data_dir}") + raw_data = os.path.join(data_dir, "data_0") + + # Subsample data for faster test + cls._data_tmpdir = tempfile.mkdtemp(prefix="ddp_mt_ft_data_") + _subsample_data(raw_data, os.path.join(cls._data_tmpdir, "data_0")) + cls.data_dir = os.path.join(cls._data_tmpdir, "data_0") + + ft_type_map = ["O", "H", "B"] + + # Train pretrained 2-branch multi-task DPA1 model + cls._train_tmpdir = tempfile.mkdtemp(prefix="ddp_mt_ft_train_") + old_cwd = os.getcwd() + os.chdir(cls._train_tmpdir) + try: + config = _make_dpa1_multitask_config(cls.data_dir, numb_steps=2) + config["model"], shared_links = preprocess_shared_params(config["model"]) + config = update_deepmd_input(config, warning=False) + config = normalize(config, multi_task=True) + trainer = get_trainer(config, shared_links=shared_links) + trainer.run() + + cls.ckpt_path = os.path.join(cls._train_tmpdir, "model.ckpt.pt") + assert os.path.exists(cls.ckpt_path), ( + "DPA1 multi-task pretrained checkpoint not created" + ) + + # Build reference state_dict with extended type_map + state_dict_full = torch.load( + cls.ckpt_path, map_location=DEVICE, weights_only=True + ) + state_dict_ckpt = ( + state_dict_full["model"] + if "model" in state_dict_full + else state_dict_full + ) + pretrained_model_params = state_dict_ckpt["_extra_state"]["model_params"] + + pretrained_models = {} + for pk in pretrained_model_params["model_dict"]: + pretrained_models[pk] = 
get_model( + deepcopy(pretrained_model_params["model_dict"][pk]) + ).to(DEVICE) + pretrained_wrapper = ModelWrapper(pretrained_models) + pretrained_wrapper.load_state_dict(state_dict_ckpt) + + # Record pretrained state BEFORE change_type_map for O/H + # inheritance verification + cls.pretrained_oh_state = { + k: v.cpu().clone() + for k, v in pretrained_wrapper.model.state_dict().items() + } + + # Build model_with_new_type_stat with computed stats + ref_model_params = deepcopy( + pretrained_model_params["model_dict"]["model_1"] + ) + ref_model_params["type_map"] = ft_type_map + ref_model = get_model(ref_model_params).to(DEVICE) + + energy_data_req = [ + DataRequirementItem( + "energy", ndof=1, atomic=False, must=False, high_prec=True + ), + DataRequirementItem( + "force", ndof=3, atomic=True, must=False, high_prec=False + ), + DataRequirementItem( + "virial", ndof=9, atomic=False, must=False, high_prec=False + ), + ] + data_systems = process_systems([cls.data_dir]) + data = DeepmdDataSystem( + systems=data_systems, + batch_size=1, + test_size=1, + type_map=ft_type_map, + trn_all_set=True, + ) + data.add_data_requirements(energy_data_req) + ref_model.compute_or_load_stat( + sampled_func=lambda: make_stat_input(data, 1), + stat_file_path=None, + ) + + for pk in pretrained_model_params["model_dict"]: + pretrained_wrapper.model[pk].change_type_map( + ft_type_map, + model_with_new_type_stat=ref_model, + ) + + cls.ref_state_dict = pretrained_wrapper.model.state_dict() + finally: + os.chdir(old_cwd) + + @classmethod + def tearDownClass(cls) -> None: + shutil.rmtree(cls._data_tmpdir, ignore_errors=True) + shutil.rmtree(cls._train_tmpdir, ignore_errors=True) + + def test_ddp_multitask_finetune(self) -> None: + """Finetune 4-branch DPA1 from 2-branch with extended type_map under DDP.""" + ft_type_map = ["O", "H", "B"] + ft_config = _make_dpa1_multitask_config( + self.data_dir, numb_steps=1, type_map=ft_type_map + ) + + # Add model_3 and model_4 + 
ft_config["model"]["model_dict"]["model_3"] = deepcopy( + ft_config["model"]["model_dict"]["model_2"] + ) + ft_config["model"]["model_dict"]["model_4"] = deepcopy( + ft_config["model"]["model_dict"]["model_2"] + ) + ft_config["loss_dict"]["model_3"] = deepcopy(ft_config["loss_dict"]["model_2"]) + ft_config["loss_dict"]["model_4"] = deepcopy(ft_config["loss_dict"]["model_2"]) + ft_config["training"]["model_prob"]["model_3"] = 0.25 + ft_config["training"]["model_prob"]["model_4"] = 0.25 + ft_config["training"]["model_prob"]["model_1"] = 0.25 + ft_config["training"]["model_prob"]["model_2"] = 0.25 + ft_config["training"]["data_dict"]["model_3"] = deepcopy( + ft_config["training"]["data_dict"]["model_2"] + ) + ft_config["training"]["data_dict"]["model_3"]["stat_file"] = ( + "./stat_files/model_3" + ) + ft_config["training"]["data_dict"]["model_4"] = deepcopy( + ft_config["training"]["data_dict"]["model_2"] + ) + ft_config["training"]["data_dict"]["model_4"]["stat_file"] = ( + "./stat_files/model_4" + ) + + # Finetune rules: + # model_1: no finetune_head → resume + # model_2: finetune from model_2 + ft_config["model"]["model_dict"]["model_2"]["finetune_head"] = "model_2" + # model_3: finetune from model_2 + ft_config["model"]["model_dict"]["model_3"]["finetune_head"] = "model_2" + # model_4: no finetune_head, new key → random fitting + + port = _find_free_port() + result_dict = mp.Manager().dict() + mp.spawn( + _worker_multitask_finetune, + args=( + 2, + port, + self.data_dir, + self.ckpt_path, + ft_config, + result_dict, + ), + nprocs=2, + join=True, + ) + results = dict(result_dict) + r0, r1 = results[0], results[1] + + # Only rank 0 writes output + self.assertTrue(r0["lcurve_exists"], "rank 0 should produce lcurve.out") + self.assertFalse(r1["lcurve_exists"], "rank 1 should NOT produce lcurve.out") + + # Init state identical across ranks (DDP sync for finetune) + for name in r0["init_state"]: + torch.testing.assert_close( + r0["init_state"][name], + 
r1["init_state"][name], + msg=f"Multi-task finetune init state differs across ranks: {name}", + ) + + # Verify weight inheritance against reference (with extended type_map) + # Keys in init_state have "model." prefix from wrapper; ref_state_dict + # is from wrapper.model.state_dict() so keys don't have "model." prefix + ref = self.ref_state_dict + init = r0["init_state"] + for key in init: + # Skip type_embedding (random init for new type B differs) + if "type_embedding" in key: + continue + # Strip "model." prefix + model_key = key.split("model.", 1)[-1] if key.startswith("model.") else key + if "model_1" in key: + # model_1: resume — ALL weights match reference + if model_key in ref: + torch.testing.assert_close( + ref[model_key], + init[key], + msg=f"model_1 (resume) DDP mismatch: {key}", + ) + elif "model_2" in key and "out_bias" not in key and "out_std" not in key: + if model_key in ref: + torch.testing.assert_close( + ref[model_key], + init[key], + msg=f"model_2 (finetune) DDP mismatch: {key}", + ) + elif "model_3" in key and "out_bias" not in key and "out_std" not in key: + ref_key = model_key.replace("model_3", "model_2") + if ref_key in ref: + torch.testing.assert_close( + ref[ref_key], + init[key], + msg=f"model_3 (from model_2) DDP mismatch: {key}", + ) + elif ( + "model_4" in key + and "fitting_net" not in key + and "out_bias" not in key + and "out_std" not in key + ): + ref_key = model_key.replace("model_4", "model_2") + if ref_key in ref: + torch.testing.assert_close( + ref[ref_key], + init[key], + msg=f"model_4 (random) descriptor DDP mismatch: {key}", + ) + + # Verify O/H descriptor stats are inherited from pretrained (not + # recomputed). pretrained_oh_state has shape [2,...] for O,H; + # finetuned init has shape [3,...] for O,H,B. 
+ _STAT_SUFFIXES = ("mean", "stddev", "davg", "dstd") + n_old = 2 # ["O", "H"] + n_new = 3 # ["O", "H", "B"] + checked_count = 0 + pretrained_oh = self.pretrained_oh_state + for key in init: + if "type_embedding" in key: + continue + if not any(key.endswith(s) for s in _STAT_SUFFIXES): + continue + # Use model_1 (all branches share descriptor after share_params) + if "model_1" not in key: + continue + # init_state has "model." prefix; pretrained_oh_state doesn't + pre_key = key.split("model.", 1)[-1] if key.startswith("model.") else key + if pre_key not in pretrained_oh: + continue + pre_val = pretrained_oh[pre_key] + ft_val = init[key] + # Find the type axis (size grew from n_old to n_new) + for ax in range(pre_val.ndim): + if pre_val.shape[ax] == n_old and ft_val.shape[ax] == n_new: + for ti, tname in enumerate(["O", "H"]): + torch.testing.assert_close( + ft_val.select(ax, ti), + pre_val.select(ax, ti), + msg=(f"{tname} stat not inherited from pretrained: {key}"), + ) + checked_count += 1 + break + self.assertGreater( + checked_count, + 0, + "No descriptor stat keys found for O/H inheritance check", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pt_expt/utils/test_activation.py b/source/tests/pt_expt/utils/test_activation.py new file mode 100644 index 0000000000..23550d3315 --- /dev/null +++ b/source/tests/pt_expt/utils/test_activation.py @@ -0,0 +1,119 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import numpy as np +import torch +from torch.fx.experimental.proxy_tensor import ( + make_fx, +) + +from deepmd.dpmodel.utils.network import ( + get_activation_fn, +) +from deepmd.pt_expt.utils.network import ( + _torch_activation, +) + + +class TestSilutActivation: + """Tests for silut activation in _torch_activation.""" + + def setup_method(self) -> None: + # x values spanning both branches: below threshold and above + self.x_np = np.array( + [-5.0, -1.0, 0.0, 1.0, 2.5, 3.0, 5.0, 10.0, 15.0, 20.0], + dtype=np.float64, + ) + 
self.x_torch = torch.tensor(self.x_np, dtype=torch.float64) + + def test_silut_with_threshold(self) -> None: + """silut:10.0 matches dpmodel numerically.""" + result = _torch_activation(self.x_torch, "silut:10.0") + dp_fn = get_activation_fn("silut:10.0") + expected = dp_fn(self.x_np) + np.testing.assert_allclose( + result.detach().numpy(), expected, rtol=1e-12, atol=1e-12 + ) + + def test_silut_default_threshold(self) -> None: + """Silut without parameter uses default threshold 3.0.""" + result = _torch_activation(self.x_torch, "silut") + dp_fn = get_activation_fn("silut") + expected = dp_fn(self.x_np) + np.testing.assert_allclose( + result.detach().numpy(), expected, rtol=1e-12, atol=1e-12 + ) + + def test_silut_custom_silu_alias(self) -> None: + """custom_silu:5.0 is an alias for silut:5.0.""" + result = _torch_activation(self.x_torch, "custom_silu:5.0") + dp_fn = get_activation_fn("custom_silu:5.0") + expected = dp_fn(self.x_np) + np.testing.assert_allclose( + result.detach().numpy(), expected, rtol=1e-12, atol=1e-12 + ) + + def test_silut_gradient(self) -> None: + """Gradient flows through both branches of silut.""" + x = self.x_torch.clone().requires_grad_(True) + y = _torch_activation(x, "silut:3.0") + loss = y.sum() + loss.backward() + grad = x.grad + assert grad is not None + # gradient should be finite everywhere + assert torch.all(torch.isfinite(grad)) + # gradient should be non-zero for non-zero inputs + nonzero_mask = self.x_np != 0.0 + assert torch.all(grad[nonzero_mask] != 0.0) + + def test_silut_make_fx(self) -> None: + """make_fx can trace through silut activation.""" + + def fn(x: torch.Tensor) -> torch.Tensor: + return _torch_activation(x, "silut:10.0") + + traced = make_fx(fn)(self.x_torch) + result = traced(self.x_torch) + expected = _torch_activation(self.x_torch, "silut:10.0") + np.testing.assert_allclose( + result.detach().numpy(), expected.detach().numpy(), rtol=1e-12, atol=1e-12 + ) + + def test_silut_below_threshold_is_silu(self) -> None: 
+ """Below threshold, silut equals silu exactly.""" + x_below = torch.tensor([-5.0, 0.0, 1.0, 5.0, 9.9], dtype=torch.float64) + result = _torch_activation(x_below, "silut:10.0") + silu = x_below * torch.sigmoid(x_below) + np.testing.assert_allclose( + result.detach().numpy(), silu.detach().numpy(), rtol=1e-14, atol=1e-14 + ) + + def test_silut_above_threshold_is_tanh_branch(self) -> None: + """Above threshold, silut equals tanh(slope*(x-T))+const.""" + import math + + threshold = 3.0 + sig_t = 1.0 / (1.0 + math.exp(-threshold)) + slope = sig_t + threshold * sig_t * (1.0 - sig_t) + const = threshold * sig_t + + x_above = torch.tensor([3.5, 5.0, 10.0, 20.0], dtype=torch.float64) + result = _torch_activation(x_above, "silut:3.0") + expected = torch.tanh(slope * (x_above - threshold)) + const + np.testing.assert_allclose( + result.detach().numpy(), expected.detach().numpy(), rtol=1e-14, atol=1e-14 + ) + + def test_silut_export(self) -> None: + """torch.export.export can trace through silut activation.""" + + class SilutModule(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return _torch_activation(x, "silut:10.0") + + mod = SilutModule() + exported = torch.export.export(mod, (self.x_torch,)) + result = exported.module()(self.x_torch) + expected = _torch_activation(self.x_torch, "silut:10.0") + np.testing.assert_allclose( + result.detach().numpy(), expected.detach().numpy(), rtol=1e-12, atol=1e-12 + )