From bca1cd71256331e2fd3da616de81b3f387781954 Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Mon, 25 May 2026 18:02:45 +0800 Subject: [PATCH 01/14] fix: try compile only once --- deepmd/pt_expt/infer/deep_eval.py | 5 + deepmd/pt_expt/train/training.py | 218 ++++++++++++++++++++++-------- 2 files changed, 165 insertions(+), 58 deletions(-) diff --git a/deepmd/pt_expt/infer/deep_eval.py b/deepmd/pt_expt/infer/deep_eval.py index 8f40600ffc..60c1e72e4f 100644 --- a/deepmd/pt_expt/infer/deep_eval.py +++ b/deepmd/pt_expt/infer/deep_eval.py @@ -366,10 +366,15 @@ def _load_pt(self, model_file: str, head: str | None = None) -> None: # eager inference). Drop the latter and unwrap the former. cleaned: dict[str, Any] = {} compiled_marker = ".compiled_forward_lower." + # Per-task buffer copies registered on _CompiledModel (bias_atom_e, + # case_embd) — real values live on the original model's fitting net. + task_buf_marker = "._task_" wrapper_infix = ".original_model." for key, value in state_dict.items(): if compiled_marker in key: continue + if task_buf_marker in key: + continue if wrapper_infix in key: key = key.replace(wrapper_infix, ".", 1) cleaned[key] = value diff --git a/deepmd/pt_expt/train/training.py b/deepmd/pt_expt/train/training.py index 1059af0be6..579c1bb63a 100644 --- a/deepmd/pt_expt/train/training.py +++ b/deepmd/pt_expt/train/training.py @@ -70,6 +70,42 @@ log = logging.getLogger(__name__) +# Buffer names that differ per task after share_params; everything else in the +# fitting net is literally the same Python object across shared tasks. +_TASK_SPECIFIC_BUFFER_NAMES: tuple[str, ...] = ("bias_atom_e", "case_embd") + + +def _get_task_buffers(model: torch.nn.Module) -> dict[str, torch.Tensor]: + """Return per-task fitting-net buffers that vary across shared tasks.""" + try: + fitting = model.get_fitting_net() + except AttributeError: + return {} + result: dict[str, torch.Tensor] = {} + for name in _TASK_SPECIFIC_BUFFER_NAMES: + val = getattr(fitting, name, None) + if val is not None and torch.is_tensor(val): + result[name] = val.detach().clone() + return result + + +def _get_model_structure_key(model: torch.nn.Module) -> int: + """Return an id that is identical for all tasks that share a fitting net. + + After ``share_params``, the fitting net's child sub-modules are literally + the same Python objects across tasks. The first non-task-specific child's + ``id()`` is therefore the same for all shared tasks and unique across + unrelated models. + """ + try: + fitting = model.get_fitting_net() + for name, child in fitting.named_children(): + if name not in _TASK_SPECIFIC_BUFFER_NAMES: + return id(child) + except AttributeError: + pass + return id(model) + # --------------------------------------------------------------------------- # Helper: loss factory (reused from pt) @@ -214,7 +250,8 @@ def _trace_and_compile( aparam: torch.Tensor | None, compile_opts: dict[str, Any] | None = None, charge_spin: torch.Tensor | None = None, -) -> torch.nn.Module: + task_buffers: dict[str, torch.Tensor] | None = None, +) -> tuple[torch.nn.Module, tuple[str, ...]]: """Symbolic-trace ``forward_lower`` and compile with inductor + dynamic=True. Parameters @@ -226,11 +263,17 @@ def _trace_and_compile( compile_opts : dict or None User-supplied inductor options. These are merged on top of the built-in defaults (user values take precedence). + task_buffers : dict or None + Per-task fitting-net buffers (``bias_atom_e``, ``case_embd``) to + promote to explicit FX ``placeholder`` nodes so the compiled graph is + reusable across tasks that share the same structure. Returns ------- - torch.nn.Module + compiled : torch.nn.Module The compiled ``forward_lower`` callable. + task_buf_order : tuple[str, ...] + Ordered names of the promoted buffers (empty when none). """ from torch.fx.experimental.proxy_tensor import ( make_fx, @@ -244,6 +287,19 @@ def _trace_and_compile( # backprop cannot reach the weights and force RMSE never decreases. model.train() + task_buf_order: tuple[str, ...] = tuple(task_buffers.keys()) if task_buffers else () + task_buf_vals_trace: tuple[torch.Tensor, ...] = ( + tuple(task_buffers[k] for k in task_buf_order) if task_buffers else () + ) + + # Resolve fitting net once for buffer patching inside fn. + _fitting: torch.nn.Module | None = None + if task_buf_order: + try: + _fitting = model.get_fitting_net() + except AttributeError: + pass + def fn( extended_coord: torch.Tensor, extended_atype: torch.Tensor, @@ -252,17 +308,30 @@ def fn( fparam: torch.Tensor | None, aparam: torch.Tensor | None, charge_spin: torch.Tensor | None, + *task_buf_vals: torch.Tensor, ) -> dict[str, torch.Tensor]: extended_coord = extended_coord.detach().requires_grad_(True) - return model.forward_lower( - extended_coord, - extended_atype, - nlist, - mapping, - fparam=fparam, - aparam=aparam, - charge_spin=charge_spin, - ) + # Temporarily patch task-specific buffers with the proxy tensors so + # make_fx records them as FX placeholders rather than baked-in constants. + # This makes the compiled graph reusable for any buffer values. + originals: dict[str, torch.Tensor | None] = {} + if _fitting is not None and task_buf_order: + for name, val in zip(task_buf_order, task_buf_vals): + originals[name] = _fitting._buffers.get(name) + _fitting._buffers[name] = val + try: + return model.forward_lower( + extended_coord, + extended_atype, + nlist, + mapping, + fparam=fparam, + aparam=aparam, + charge_spin=charge_spin, + ) + finally: + for name, orig in originals.items(): + _fitting._buffers[name] = orig # Pick a trace-time nframes that's unlikely to collide with any other # tensor dim in the graph. The symbolic tracer merges symbols that @@ -309,7 +378,7 @@ def _expand(t: torch.Tensor | None) -> torch.Tensor | None: tracing_mode="symbolic", _allow_non_fake_inputs=True, decomposition_table=decomp_table, - )(ext_coord, ext_atype, nlist, mapping, fparam, aparam, charge_spin) + )(ext_coord, ext_atype, nlist, mapping, fparam, aparam, charge_spin, *task_buf_vals_trace) # make_fx inserts aten.detach.default for saved tensors used in the # decomposed autograd.grad backward ops. These detach nodes break @@ -344,7 +413,7 @@ def _expand(t: torch.Tensor | None) -> torch.Tensor | None: backend="inductor", dynamic=True, options=inductor_options, - ) + ), task_buf_order class _CompiledModel(torch.nn.Module): @@ -354,10 +423,17 @@ def __init__( self, original_model: torch.nn.Module, compiled_forward_lower: torch.nn.Module, + task_buf_order: tuple[str, ...] = (), + task_buffers: dict[str, torch.Tensor] | None = None, ) -> None: super().__init__() self.original_model = original_model self.compiled_forward_lower = compiled_forward_lower + self._task_buf_order = task_buf_order + if task_buf_order and task_buffers: + for name in task_buf_order: + if name in task_buffers: + self.register_buffer(f"_task_{name}", task_buffers[name]) def forward( self, @@ -404,8 +480,12 @@ def forward( ext_coord = ext_coord.reshape(nframes, -1, 3) ext_coord = ext_coord.detach().requires_grad_(True) + task_buf_vals = tuple( + getattr(self, f"_task_{name}") for name in self._task_buf_order + ) result = self.compiled_forward_lower( - ext_coord, ext_atype, nlist, mapping, fparam, aparam, charge_spin + ext_coord, ext_atype, nlist, mapping, fparam, aparam, charge_spin, + *task_buf_vals, ) # Translate forward_lower keys -> forward keys. @@ -947,6 +1027,13 @@ def _compile_model(self, compile_opts: dict[str, Any]) -> None: else self.wrapper ) + from deepmd.dpmodel.descriptor.dpa1 import DescrptDPA1 as DescrptDPA1DP + + # structure_key -> (compiled_lower, task_buf_order) + # Shared-fitting tasks produce the same structure key so only the first + # task triggers make_fx + torch.compile; the rest reuse the result. + _compiled_by_structure: dict[int, tuple] = {} + for task_key in self.model_keys: model = wrapper_mod.model[task_key] @@ -957,8 +1044,6 @@ def _compile_model(self, compile_opts: dict[str, Any]) -> None: # is hardware-dependent. Warn but do not reject — energies # remain well within training tolerance and the user may # accept the trade-off for compile speed. - from deepmd.dpmodel.descriptor.dpa1 import DescrptDPA1 as DescrptDPA1DP - descriptor = model.get_descriptor() if isinstance(descriptor, DescrptDPA1DP): n_attn = descriptor.get_numb_attn_layer() @@ -974,54 +1059,71 @@ def _compile_model(self, compile_opts: dict[str, Any]) -> None: task_key, ) - inp, _ = self.get_data(is_train=True, task_key=task_key) - coord = inp["coord"].detach() - atype = inp["atype"].detach() - box = inp.get("box") - if box is not None: - box = box.detach() + structure_key = _get_model_structure_key(model) + task_bufs = _get_task_buffers(model) - nframes, nloc = atype.shape[:2] - coord_3d = coord.reshape(nframes, nloc, 3) - box_flat = box.reshape(nframes, 9) if box is not None else None - - if box_flat is not None: - coord_norm = normalize_coord(coord_3d, box_flat.reshape(nframes, 3, 3)) + if structure_key in _compiled_by_structure: + # Shared structure: reuse the already-compiled graph. + compiled_lower, task_buf_order = _compiled_by_structure[structure_key] + log.info( + "Reusing compiled graph for task=%s (shared model structure).", + task_key, + ) else: - coord_norm = coord_3d - - ext_coord, ext_atype, mapping = extend_coord_with_ghosts( - coord_norm, atype, box_flat, model.get_rcut() - ) - nlist_t = build_neighbor_list( - ext_coord, - ext_atype, - nloc, - model.get_rcut(), - model.get_sel(), - distinguish_types=False, - ) - ext_coord = ext_coord.reshape(nframes, -1, 3) + inp, _ = self.get_data(is_train=True, task_key=task_key) + coord = inp["coord"].detach() + atype = inp["atype"].detach() + box = inp.get("box") + if box is not None: + box = box.detach() + + nframes, nloc = atype.shape[:2] + coord_3d = coord.reshape(nframes, nloc, 3) + box_flat = box.reshape(nframes, 9) if box is not None else None + + if box_flat is not None: + coord_norm = normalize_coord( + coord_3d, box_flat.reshape(nframes, 3, 3) + ) + else: + coord_norm = coord_3d - fparam = inp.get("fparam") - aparam = inp.get("aparam") - charge_spin = inp.get("charge_spin") + ext_coord, ext_atype, mapping = extend_coord_with_ghosts( + coord_norm, atype, box_flat, model.get_rcut() + ) + nlist_t = build_neighbor_list( + ext_coord, + ext_atype, + nloc, + model.get_rcut(), + model.get_sel(), + distinguish_types=False, + ) + ext_coord = ext_coord.reshape(nframes, -1, 3) + + fparam = inp.get("fparam") + aparam = inp.get("aparam") + charge_spin = inp.get("charge_spin") + + compiled_lower, task_buf_order = _trace_and_compile( + model, + ext_coord, + ext_atype, + nlist_t, + mapping, + fparam, + aparam, + charge_spin=charge_spin, + task_buffers=task_bufs if task_bufs else None, + compile_opts=compile_opts, + ) + _compiled_by_structure[structure_key] = (compiled_lower, task_buf_order) - compiled_lower = _trace_and_compile( - model, - ext_coord, - ext_atype, - nlist_t, - mapping, - fparam, - aparam, - charge_spin=charge_spin, - compile_opts=compile_opts, + wrapper_mod.model[task_key] = _CompiledModel( + model, compiled_lower, task_buf_order, task_bufs ) - - wrapper_mod.model[task_key] = _CompiledModel(model, compiled_lower) log.info( - "Model compiled (task=%s, tracing_mode=symbolic, " + "Model compiled/reused (task=%s, tracing_mode=symbolic, " "dynamic=True, backend=inductor).", task_key, ) From 07ce0257d2c294608585f64484b89a156fca20dc Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Tue, 26 May 2026 10:01:16 +0800 Subject: [PATCH 02/14] fix: detach graph in log --- deepmd/pt_expt/train/training.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/deepmd/pt_expt/train/training.py b/deepmd/pt_expt/train/training.py index 579c1bb63a..c8ee26c8d1 100644 --- a/deepmd/pt_expt/train/training.py +++ b/deepmd/pt_expt/train/training.py @@ -1312,9 +1312,14 @@ def run(self) -> None: self.wrapper.eval() if self.rank == 0: + def _to_float(v: Any) -> float: + return v.detach().item() if torch.is_tensor(v) else float(v) + if not self.multi_task: train_results = { - k: v for k, v in more_loss.items() if "l2_" not in k + k: _to_float(v) + for k, v in more_loss.items() + if "l2_" not in k } # validation @@ -1335,7 +1340,8 @@ def run(self) -> None: for k, v in _vmore.items(): if "l2_" not in k: valid_results[k] = ( - valid_results.get(k, 0.0) + v * natoms + valid_results.get(k, 0.0) + + _to_float(v) * natoms ) if sum_natoms > 0: valid_results = { @@ -1348,7 +1354,9 @@ def run(self) -> None: # current task already has loss train_results[task_key] = { - k: v for k, v in more_loss.items() if "l2_" not in k + k: _to_float(v) + for k, v in more_loss.items() + if "l2_" not in k } # compute loss for other tasks @@ -1363,7 +1371,9 @@ def run(self) -> None: task_key=_key, ) train_results[_key] = { - k: v for k, v in _more.items() if "l2_" not in k + k: _to_float(v) + for k, v in _more.items() + if "l2_" not in k } # validation for each task @@ -1387,7 +1397,10 @@ def run(self) -> None: _sum_natoms += natoms for k, v in _vmore.items(): if "l2_" not in k: - _vres[k] = _vres.get(k, 0.0) + v * natoms + _vres[k] = ( + _vres.get(k, 0.0) + + _to_float(v) * natoms + ) if _sum_natoms > 0: _vres = { k: v / _sum_natoms for k, v in _vres.items() From 4cee0bf95e8562ce1ebdc3a53a853c2f676ca17f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 26 May 2026 02:05:09 +0000 Subject: [PATCH 03/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/pt_expt/train/training.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/deepmd/pt_expt/train/training.py b/deepmd/pt_expt/train/training.py index c8ee26c8d1..b4fa41c948 100644 --- a/deepmd/pt_expt/train/training.py +++ b/deepmd/pt_expt/train/training.py @@ -378,7 +378,16 @@ def _expand(t: torch.Tensor | None) -> torch.Tensor | None: tracing_mode="symbolic", _allow_non_fake_inputs=True, decomposition_table=decomp_table, - )(ext_coord, ext_atype, nlist, mapping, fparam, aparam, charge_spin, *task_buf_vals_trace) + )( + ext_coord, + ext_atype, + nlist, + mapping, + fparam, + aparam, + charge_spin, + *task_buf_vals_trace, + ) # make_fx inserts aten.detach.default for saved tensors used in the # decomposed autograd.grad backward ops. These detach nodes break @@ -484,7 +493,13 @@ def forward( getattr(self, f"_task_{name}") for name in self._task_buf_order ) result = self.compiled_forward_lower( - ext_coord, ext_atype, nlist, mapping, fparam, aparam, charge_spin, + ext_coord, + ext_atype, + nlist, + mapping, + fparam, + aparam, + charge_spin, *task_buf_vals, ) @@ -1312,6 +1327,7 @@ def run(self) -> None: self.wrapper.eval() if self.rank == 0: + def _to_float(v: Any) -> float: return v.detach().item() if torch.is_tensor(v) else float(v) From f3e29fe5cc47fdaba61c2608a69a6aa4280a04c8 Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Tue, 26 May 2026 11:12:30 +0800 Subject: [PATCH 04/14] fix: buffer register --- deepmd/pt_expt/train/training.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/deepmd/pt_expt/train/training.py b/deepmd/pt_expt/train/training.py index b4fa41c948..3489c94c16 100644 --- a/deepmd/pt_expt/train/training.py +++ b/deepmd/pt_expt/train/training.py @@ -439,10 +439,9 @@ def __init__( self.original_model = original_model self.compiled_forward_lower = compiled_forward_lower self._task_buf_order = task_buf_order - if task_buf_order and task_buffers: - for name in task_buf_order: - if name in task_buffers: - self.register_buffer(f"_task_{name}", task_buffers[name]) + # task_buffers is intentionally not stored: buffers are read from + # original_model.get_fitting_net() at forward time so that weight + # updates (load_state_dict, optimiser steps) are always reflected. def forward( self, @@ -489,9 +488,16 @@ def forward( ext_coord = ext_coord.reshape(nframes, -1, 3) ext_coord = ext_coord.detach().requires_grad_(True) - task_buf_vals = tuple( - getattr(self, f"_task_{name}") for name in self._task_buf_order - ) + if self._task_buf_order: + try: + _fitting = self.original_model.get_fitting_net() + task_buf_vals: tuple = tuple( + getattr(_fitting, name) for name in self._task_buf_order + ) + except AttributeError: + task_buf_vals = () + else: + task_buf_vals = () result = self.compiled_forward_lower( ext_coord, ext_atype, From 8e9c2a5bcbbfd2e8eda7be6c4ab15a4a26fc22de Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Tue, 26 May 2026 17:19:09 +0800 Subject: [PATCH 05/14] fix: promote out_bias --- deepmd/pt_expt/train/training.py | 82 ++++++++++++++++++++++++-------- 1 file changed, 63 insertions(+), 19 deletions(-) diff --git a/deepmd/pt_expt/train/training.py b/deepmd/pt_expt/train/training.py index 3489c94c16..88d07615d9 100644 --- a/deepmd/pt_expt/train/training.py +++ b/deepmd/pt_expt/train/training.py @@ -70,22 +70,41 @@ log = logging.getLogger(__name__) -# Buffer names that differ per task after share_params; everything else in the -# fitting net is literally the same Python object across shared tasks. +# Buffer names in the fitting net that differ per task after share_params; +# everything else in the fitting net is the same Python object across tasks. _TASK_SPECIFIC_BUFFER_NAMES: tuple[str, ...] = ("bias_atom_e", "case_embd") +# Buffer names in atomic_model that are per-task (energy/output statistics). +# These live one level above the fitting net and are not reached by +# fitting-net share_params, so they must also be promoted to FX placeholders. +_ATOMIC_MODEL_TASK_BUFFER_NAMES: tuple[str, ...] = ("out_bias", "out_std") + +# Prefix used in task_buf_order keys to distinguish atomic_model buffers +# from fitting-net buffers. +_AM_PREFIX = "am/" + def _get_task_buffers(model: torch.nn.Module) -> dict[str, torch.Tensor]: - """Return per-task fitting-net buffers that vary across shared tasks.""" + """Return per-task buffers (fitting net + atomic model) that vary across shared tasks.""" + result: dict[str, torch.Tensor] = {} + # fitting-net task buffers try: fitting = model.get_fitting_net() + for name in _TASK_SPECIFIC_BUFFER_NAMES: + val = fitting._buffers.get(name) + if val is not None and torch.is_tensor(val): + result[name] = val.detach().clone() except AttributeError: - return {} - result: dict[str, torch.Tensor] = {} - for name in _TASK_SPECIFIC_BUFFER_NAMES: - val = getattr(fitting, name, None) - if val is not None and torch.is_tensor(val): - result[name] = val.detach().clone() + pass + # atomic_model task buffers (out_bias, out_std) + try: + am = model.atomic_model + for name in _ATOMIC_MODEL_TASK_BUFFER_NAMES: + val = am._buffers.get(name) + if val is not None and torch.is_tensor(val): + result[_AM_PREFIX + name] = val.detach().clone() + except AttributeError: + pass return result @@ -292,13 +311,18 @@ def _trace_and_compile( tuple(task_buffers[k] for k in task_buf_order) if task_buffers else () ) - # Resolve fitting net once for buffer patching inside fn. + # Resolve fitting net and atomic_model once for buffer patching inside fn. _fitting: torch.nn.Module | None = None + _atomic_model: torch.nn.Module | None = None if task_buf_order: try: _fitting = model.get_fitting_net() except AttributeError: - pass + pass # no fitting net → no fitting-net buffers to patch + try: + _atomic_model = model.atomic_model + except AttributeError: + pass # no atomic_model → no atomic-model buffers to patch def fn( extended_coord: torch.Tensor, @@ -313,12 +337,20 @@ def fn( extended_coord = extended_coord.detach().requires_grad_(True) # Temporarily patch task-specific buffers with the proxy tensors so # make_fx records them as FX placeholders rather than baked-in constants. - # This makes the compiled graph reusable for any buffer values. + # Keys prefixed with _AM_PREFIX are atomic_model buffers; the rest are + # fitting-net buffers. originals: dict[str, torch.Tensor | None] = {} - if _fitting is not None and task_buf_order: + if task_buf_order: for name, val in zip(task_buf_order, task_buf_vals): - originals[name] = _fitting._buffers.get(name) - _fitting._buffers[name] = val + if name.startswith(_AM_PREFIX): + actual = name[len(_AM_PREFIX):] + if _atomic_model is not None: + originals[name] = _atomic_model._buffers.get(actual) + _atomic_model._buffers[actual] = val + else: + if _fitting is not None: + originals[name] = _fitting._buffers.get(name) + _fitting._buffers[name] = val try: return model.forward_lower( extended_coord, @@ -331,7 +363,13 @@ def fn( ) finally: for name, orig in originals.items(): - _fitting._buffers[name] = orig + if name.startswith(_AM_PREFIX): + actual = name[len(_AM_PREFIX):] + if _atomic_model is not None: + _atomic_model._buffers[actual] = orig + else: + if _fitting is not None: + _fitting._buffers[name] = orig # Pick a trace-time nframes that's unlikely to collide with any other # tensor dim in the graph. The symbolic tracer merges symbols that @@ -491,9 +529,15 @@ def forward( if self._task_buf_order: try: _fitting = self.original_model.get_fitting_net() - task_buf_vals: tuple = tuple( - getattr(_fitting, name) for name in self._task_buf_order - ) + _am = getattr(self.original_model, "atomic_model", None) + _vals: list[torch.Tensor] = [] + for _name in self._task_buf_order: + if _name.startswith(_AM_PREFIX): + _actual = _name[len(_AM_PREFIX):] + _vals.append(_am._buffers[_actual]) + else: + _vals.append(getattr(_fitting, _name)) + task_buf_vals: tuple = tuple(_vals) except AttributeError: task_buf_vals = () else: From 9ce8d3e135e7f9aa590a0c9d090a9cf301a49838 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 26 May 2026 09:20:28 +0000 Subject: [PATCH 06/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/pt_expt/train/training.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/deepmd/pt_expt/train/training.py b/deepmd/pt_expt/train/training.py index 88d07615d9..e48420f8fb 100644 --- a/deepmd/pt_expt/train/training.py +++ b/deepmd/pt_expt/train/training.py @@ -343,7 +343,7 @@ def fn( if task_buf_order: for name, val in zip(task_buf_order, task_buf_vals): if name.startswith(_AM_PREFIX): - actual = name[len(_AM_PREFIX):] + actual = name[len(_AM_PREFIX) :] if _atomic_model is not None: originals[name] = _atomic_model._buffers.get(actual) _atomic_model._buffers[actual] = val @@ -364,7 +364,7 @@ def fn( finally: for name, orig in originals.items(): if name.startswith(_AM_PREFIX): - actual = name[len(_AM_PREFIX):] + actual = name[len(_AM_PREFIX) :] if _atomic_model is not None: _atomic_model._buffers[actual] = orig else: @@ -533,7 +533,7 @@ def forward( _vals: list[torch.Tensor] = [] for _name in self._task_buf_order: if _name.startswith(_AM_PREFIX): - _actual = _name[len(_AM_PREFIX):] + _actual = _name[len(_AM_PREFIX) :] _vals.append(_am._buffers[_actual]) else: _vals.append(getattr(_fitting, _name)) From 8c25a2289f8d4cbe96a03e4123cb600889d65997 Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Thu, 28 May 2026 10:27:40 +0800 Subject: [PATCH 07/14] fix: ensure distinct compiled graphs for shared fitting with different descriptors --- deepmd/pt_expt/train/training.py | 117 +++++++++---- source/tests/pt_expt/test_training.py | 228 ++++++++++++++++++++++++++ 2 files changed, 314 insertions(+), 31 deletions(-) diff --git a/deepmd/pt_expt/train/training.py b/deepmd/pt_expt/train/training.py index e48420f8fb..97f4d75fd5 100644 --- a/deepmd/pt_expt/train/training.py +++ b/deepmd/pt_expt/train/training.py @@ -70,13 +70,11 @@ log = logging.getLogger(__name__) -# Buffer names in the fitting net that differ per task after share_params; -# everything else in the fitting net is the same Python object across tasks. -_TASK_SPECIFIC_BUFFER_NAMES: tuple[str, ...] = ("bias_atom_e", "case_embd") - # Buffer names in atomic_model that are per-task (energy/output statistics). # These live one level above the fitting net and are not reached by -# fitting-net share_params, so they must also be promoted to FX placeholders. +# fitting-net share_params. They are always promoted to FX placeholders +# because model_change_out_bias may replace them out-of-place after +# compilation, so the compiled forward must read them fresh each call. _ATOMIC_MODEL_TASK_BUFFER_NAMES: tuple[str, ...] = ("out_bias", "out_std") # Prefix used in task_buf_order keys to distinguish atomic_model buffers @@ -84,19 +82,43 @@ _AM_PREFIX = "am/" -def _get_task_buffers(model: torch.nn.Module) -> dict[str, torch.Tensor]: - """Return per-task buffers (fitting net + atomic model) that vary across shared tasks.""" +def _detect_task_buffers( + model: torch.nn.Module, + group_models: list["torch.nn.Module"], +) -> dict[str, torch.Tensor]: + """Collect per-task buffers to promote to FX placeholders. + + Fitting-net buffers are auto-detected by identity diff across + *group_models* (all tasks that share this model's structure key after + ``share_params``). Any buffer that is a *different* Python object in at + least one other group member is task-specific and gets promoted. + + Atomic-model buffers listed in ``_ATOMIC_MODEL_TASK_BUFFER_NAMES`` are + always promoted because ``model_change_out_bias`` may replace them + out-of-place after compilation. + """ result: dict[str, torch.Tensor] = {} - # fitting-net task buffers + + # Auto-detect fitting-net task buffers by identity diff across the group. try: fitting = model.get_fitting_net() - for name in _TASK_SPECIFIC_BUFFER_NAMES: - val = fitting._buffers.get(name) - if val is not None and torch.is_tensor(val): - result[name] = val.detach().clone() + for name, val in fitting._buffers.items(): + if val is None or not torch.is_tensor(val): + continue + for other in group_models: + if other is model: + continue + try: + other_val = other.get_fitting_net()._buffers.get(name) + if other_val is not val: + result[name] = val.detach().clone() + break + except AttributeError: + pass except AttributeError: pass - # atomic_model task buffers (out_bias, out_std) + + # Atomic-model task buffers (always promote). try: am = model.atomic_model for name in _ATOMIC_MODEL_TASK_BUFFER_NAMES: @@ -105,25 +127,35 @@ def _get_task_buffers(model: torch.nn.Module) -> dict[str, torch.Tensor]: result[_AM_PREFIX + name] = val.detach().clone() except AttributeError: pass + return result -def _get_model_structure_key(model: torch.nn.Module) -> int: - """Return an id that is identical for all tasks that share a fitting net. +def _get_model_structure_key(model: torch.nn.Module) -> tuple[int, ...]: + """Return a key that is identical iff two tasks can safely share a compiled graph. + + The key captures both the descriptor identity and the fitting-net + structure so that tasks sharing a fitting net but using *different* + descriptors (which bake distinct descriptor constants into the traced + graph) are never assigned the same compiled graph. - After ``share_params``, the fitting net's child sub-modules are literally - the same Python objects across tasks. The first non-task-specific child's - ``id()`` is therefore the same for all shared tasks and unique across - unrelated models. + After ``share_params``, the fitting net's child sub-modules are the same + Python objects across tasks, so ``id(first_child)`` is equal for all + shared tasks and unique across unrelated models. """ + descriptor_id: int = 0 + try: + descriptor_id = id(model.get_descriptor()) + except AttributeError: + pass + try: fitting = model.get_fitting_net() - for name, child in fitting.named_children(): - if name not in _TASK_SPECIFIC_BUFFER_NAMES: - return id(child) + for _, child in fitting.named_children(): + return (descriptor_id, id(child)) except AttributeError: pass - return id(model) + return (descriptor_id, id(model)) # --------------------------------------------------------------------------- @@ -283,9 +315,10 @@ def _trace_and_compile( User-supplied inductor options. These are merged on top of the built-in defaults (user values take precedence). task_buffers : dict or None - Per-task fitting-net buffers (``bias_atom_e``, ``case_embd``) to - promote to explicit FX ``placeholder`` nodes so the compiled graph is - reusable across tasks that share the same structure. + Per-task buffers (e.g. ``bias_atom_e``, ``case_embd``, ``out_bias``, + ``out_std``) detected by ``_detect_task_buffers``. These are promoted + to explicit FX ``placeholder`` nodes so the compiled graph is reusable + across tasks that share the same structure key. Returns ------- @@ -1092,12 +1125,34 @@ def _compile_model(self, compile_opts: dict[str, Any]) -> None: else self.wrapper ) + from collections import ( + defaultdict, + ) + from deepmd.dpmodel.descriptor.dpa1 import DescrptDPA1 as DescrptDPA1DP + # Pre-pass: group tasks by structure key and auto-detect per-task buffers. + # Grouping is needed so _detect_task_buffers can diff buffer identities + # across all tasks that share the same compiled graph. + _key_for: dict[str, tuple[int, ...]] = {} + _groups: defaultdict[tuple[int, ...], list[str]] = defaultdict(list) + for task_key in self.model_keys: + sk = _get_model_structure_key(wrapper_mod.model[task_key]) + _key_for[task_key] = sk + _groups[sk].append(task_key) + + _task_bufs_for: dict[str, dict[str, torch.Tensor]] = {} + for sk, group_keys in _groups.items(): + group_models = [wrapper_mod.model[k] for k in group_keys] + for task_key in group_keys: + _task_bufs_for[task_key] = _detect_task_buffers( + wrapper_mod.model[task_key], group_models + ) + # structure_key -> (compiled_lower, task_buf_order) - # Shared-fitting tasks produce the same structure key so only the first - # task triggers make_fx + torch.compile; the rest reuse the result. - _compiled_by_structure: dict[int, tuple] = {} + # Tasks with the same structure key (same descriptor + shared fitting) + # reuse the compiled graph; different descriptor or fitting → distinct key. + _compiled_by_structure: dict[tuple[int, ...], tuple] = {} for task_key in self.model_keys: model = wrapper_mod.model[task_key] @@ -1124,8 +1179,8 @@ def _compile_model(self, compile_opts: dict[str, Any]) -> None: task_key, ) - structure_key = _get_model_structure_key(model) - task_bufs = _get_task_buffers(model) + structure_key = _key_for[task_key] + task_bufs = _task_bufs_for[task_key] if structure_key in _compiled_by_structure: # Shared structure: reuse the already-compiled graph. diff --git a/source/tests/pt_expt/test_training.py b/source/tests/pt_expt/test_training.py index 07bbf2c06a..eee425d80f 100644 --- a/source/tests/pt_expt/test_training.py +++ b/source/tests/pt_expt/test_training.py @@ -1298,5 +1298,233 @@ def test_compile_warns_dpa1_with_attention(self) -> None: self.assertIsInstance(trainer.wrapper.model["Default"], _CompiledModel) +class TestCompiledSharedFittingDifferentDescriptor(unittest.TestCase): + """Regression test: shared fitting with different descriptors gets distinct compiled graphs. + + Before the fix, ``_get_model_structure_key`` returned the id of the first + fitting-net child without including the descriptor. Two tasks sharing a + fitting net but using different descriptors (different rcut / sel, which + bake different smooth-cutoff constants into the traced graph) received the + same structure key — task_2 silently reused task_1's compiled graph and + produced wrong predictions. + + The fix includes ``id(descriptor)`` in the key so each task with a + distinct descriptor gets its own compiled graph. + """ + + @classmethod + def setUpClass(cls) -> None: + data_dir = os.path.join(EXAMPLE_DIR, "data") + if not os.path.isdir(data_dir): + raise unittest.SkipTest(f"Example data not found: {data_dir}") + cls.data_dir = data_dir + + def _make_config(self, enable_compile: bool) -> tuple[dict, object]: + """Multi-task config: shared fitting_net, DIFFERENT descriptors per task.""" + from deepmd.pt_expt.utils.multi_task import ( + preprocess_shared_params, + ) + + data_dir_0 = os.path.join(self.data_dir, "data_0") + config = { + "model": { + "shared_dict": { + "my_type_map": ["O", "H"], + "my_fitting": { + "neuron": [16, 16], + "resnet_dt": True, + "seed": 1, + "dim_case_embd": 2, + "precision": "float64", + }, + }, + "model_dict": { + "model_1": { + "type_map": "my_type_map", + "descriptor": { + "type": "se_e2_a", + "sel": [6, 12], + "rcut_smth": 0.50, + "rcut": 3.00, + "neuron": [8, 16], + "resnet_dt": False, + "axis_neuron": 4, + "type_one_side": True, + "precision": "float64", + "seed": 1, + }, + "fitting_net": "my_fitting", + "data_stat_nbatch": 1, + }, + "model_2": { + "type_map": "my_type_map", + "descriptor": { + "type": "se_e2_a", + "sel": [4, 8], + "rcut_smth": 0.30, + "rcut": 2.50, + "neuron": [8, 16], + "resnet_dt": False, + "axis_neuron": 4, + "type_one_side": True, + "precision": "float64", + "seed": 2, + }, + "fitting_net": "my_fitting", + "data_stat_nbatch": 1, + }, + }, + }, + "learning_rate": { + "type": "exp", + "decay_steps": 500, + "start_lr": 0.001, + "stop_lr": 3.51e-8, + }, + "loss_dict": { + "model_1": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "start_pref_v": 0, + "limit_pref_v": 0, + }, + "model_2": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "start_pref_v": 0, + "limit_pref_v": 0, + }, + }, + "training": { + "model_prob": {"model_1": 0.5, "model_2": 0.5}, + "data_dict": { + "model_1": { + "stat_file": "./stat_files/model_1", + "training_data": {"systems": [data_dir_0], "batch_size": 1}, + "validation_data": { + "systems": [data_dir_0], + "batch_size": 1, + "numb_btch": 1, + }, + }, + "model_2": { + "stat_file": "./stat_files/model_2", + "training_data": {"systems": [data_dir_0], "batch_size": 1}, + "validation_data": { + "systems": [data_dir_0], + "batch_size": 1, + "numb_btch": 1, + }, + }, + }, + "numb_steps": 1, + "seed": 10, + "disp_file": "lcurve.out", + "disp_freq": 1, + "save_freq": 1, + }, + } + if enable_compile: + config["training"]["enable_compile"] = True + config["model"], shared_links = preprocess_shared_params(config["model"]) + config = update_deepmd_input(config, warning=False) + config = normalize(config, multi_task=True) + return config, shared_links + + def test_compiled_matches_eager_per_task(self) -> None: + """Compiled output for each task must match its own eager output. + + With different descriptors, tasks must get separate compiled graphs. + Before the fix, task_2 reused task_1's compiled graph (rcut=3.0 baked + in), yielding wrong predictions for task_2 (rcut=2.5). + """ + from deepmd.pt_expt.train.training import ( + _CompiledModel, + _get_model_structure_key, + ) + + tmpdir = tempfile.mkdtemp(prefix="pt_expt_diff_desc_") + try: + old_cwd = os.getcwd() + os.chdir(tmpdir) + try: + config_uc, shared_links_uc = self._make_config(enable_compile=False) + config_c, shared_links_c = self._make_config(enable_compile=True) + + trainer_uc = get_trainer(config_uc, shared_links=shared_links_uc) + trainer_c = get_trainer(config_c, shared_links=shared_links_c) + + for mk in ("model_1", "model_2"): + self.assertIsInstance( + trainer_c.wrapper.model[mk], + _CompiledModel, + f"{mk} was not compiled", + ) + + # Different descriptors → different structure keys → separate graphs. + key_1 = _get_model_structure_key( + trainer_c.wrapper.model["model_1"].original_model + ) + key_2 = _get_model_structure_key( + trainer_c.wrapper.model["model_2"].original_model + ) + self.assertNotEqual( + key_1, + key_2, + "Tasks with different descriptors must get different structure keys", + ) + + # Sync weights so compiled and uncompiled start from the same state. + for mk in ("model_1", "model_2"): + trainer_c.wrapper.model[mk].original_model.load_state_dict( + trainer_uc.wrapper.model[mk].state_dict() + ) + + for mk in ("model_1", "model_2"): + inp_dict, label_dict = trainer_uc.get_data( + is_train=True, task_key=mk + ) + cur_lr = trainer_uc.scheduler.get_last_lr()[0] + + pred_uc, loss_uc, _ = trainer_uc.wrapper( + **inp_dict, + cur_lr=cur_lr, + label=label_dict, + task_key=mk, + ) + pred_c, loss_c, _ = trainer_c.wrapper( + **inp_dict, + cur_lr=cur_lr, + label=label_dict, + task_key=mk, + ) + + for key in ("atom_energy", "energy", "force"): + torch.testing.assert_close( + pred_c[key], + pred_uc[key], + atol=1e-10, + rtol=1e-10, + msg=f"{mk}/{key}: compiled vs eager mismatch", + ) + torch.testing.assert_close( + loss_c, + loss_uc, + atol=1e-10, + rtol=1e-10, + msg=f"{mk}/loss: compiled vs eager mismatch", + ) + finally: + os.chdir(old_cwd) + finally: + shutil.rmtree(tmpdir, ignore_errors=True) + + if __name__ == "__main__": unittest.main() From 7499321b07e0f3b650a4d18f76bc373bb31a7e61 Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Fri, 29 May 2026 18:18:10 +0800 Subject: [PATCH 08/14] chore: clean up --- deepmd/pt_expt/train/training.py | 27 +++++++++++++++++++++------ 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/deepmd/pt_expt/train/training.py b/deepmd/pt_expt/train/training.py index 97f4d75fd5..117fdb2871 100644 --- a/deepmd/pt_expt/train/training.py +++ b/deepmd/pt_expt/train/training.py @@ -143,9 +143,20 @@ def _get_model_structure_key(model: torch.nn.Module) -> tuple[int, ...]: Python objects across tasks, so ``id(first_child)`` is equal for all shared tasks and unique across unrelated models. """ + # Use the first shared parameter tensor's id rather than the descriptor + # object's id: share_params makes descriptor *parameters* the same Python + # objects across tasks while the descriptor modules remain distinct. + # Two descriptors sharing params therefore collapse to the same key here, + # which is exactly what we want (same compiled graph). Truly independent + # descriptors have distinct param objects and get distinct keys. descriptor_id: int = 0 try: - descriptor_id = id(model.get_descriptor()) + desc = model.get_descriptor() + for _, p in desc.named_parameters(): + descriptor_id = id(p) + break + else: + descriptor_id = id(desc) except AttributeError: pass @@ -374,7 +385,7 @@ def fn( # fitting-net buffers. originals: dict[str, torch.Tensor | None] = {} if task_buf_order: - for name, val in zip(task_buf_order, task_buf_vals): + for name, val in zip(task_buf_order, task_buf_vals, strict=True): if name.startswith(_AM_PREFIX): actual = name[len(_AM_PREFIX) :] if _atomic_model is not None: @@ -571,8 +582,12 @@ def forward( else: _vals.append(getattr(_fitting, _name)) task_buf_vals: tuple = tuple(_vals) - except AttributeError: - task_buf_vals = () + except AttributeError as exc: + raise RuntimeError( + f"Compiled graph expects task buffers {self._task_buf_order!r} " + "but they could not be retrieved from the model. " + "This is a bug in the compile path." + ) from exc else: task_buf_vals = () result = self.compiled_forward_lower( @@ -1142,7 +1157,7 @@ def _compile_model(self, compile_opts: dict[str, Any]) -> None: _groups[sk].append(task_key) _task_bufs_for: dict[str, dict[str, torch.Tensor]] = {} - for sk, group_keys in _groups.items(): + for group_keys in _groups.values(): group_models = [wrapper_mod.model[k] for k in group_keys] for task_key in group_keys: _task_bufs_for[task_key] = _detect_task_buffers( @@ -1406,7 +1421,7 @@ def run(self) -> None: input_dict, label_dict = self.get_data(is_train=True, task_key=task_key) cur_lr_sched = self.scheduler.get_last_lr()[0] - model_pred, loss, more_loss = self.wrapper( + _model_pred, loss, more_loss = self.wrapper( **input_dict, cur_lr=cur_lr_sched, label=label_dict, From 764d0a605bb2efb20c8749ac0ce2a5943ff20982 Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Mon, 1 Jun 2026 13:15:04 +0800 Subject: [PATCH 09/14] fix: update model structure key to use frozenset for descriptor identity --- deepmd/pt_expt/train/training.py | 32 +++++++++++++++++--------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/deepmd/pt_expt/train/training.py b/deepmd/pt_expt/train/training.py index 117fdb2871..2ae0915617 100644 --- a/deepmd/pt_expt/train/training.py +++ b/deepmd/pt_expt/train/training.py @@ -131,7 +131,7 @@ def _detect_task_buffers( return result -def _get_model_structure_key(model: torch.nn.Module) -> tuple[int, ...]: +def _get_model_structure_key(model: torch.nn.Module) -> tuple: """Return a key that is identical iff two tasks can safely share a compiled graph. The key captures both the descriptor identity and the fitting-net @@ -139,34 +139,36 @@ def _get_model_structure_key(model: torch.nn.Module) -> tuple[int, ...]: descriptors (which bake distinct descriptor constants into the traced graph) are never assigned the same compiled graph. + Descriptor identity is determined by the frozenset of all parameter + tensor ids rather than the descriptor module id. ``share_params`` leaves + descriptor modules as distinct Python objects while making their parameter + tensors the same objects. Using the full frozenset (not just the first + parameter) correctly handles partial sharing: e.g. shared_level==1 on + DPA1/DPA2/DPA3/SE_T_TEBD shares only the type-embedding while leaving + the main block (se_atten, repinit/repformers, repflows, se_ttebd, …) + task-local. Two partially-shared descriptors therefore produce different + frozensets and receive separate compiled graphs, whereas fully-shared + descriptors produce identical frozensets and safely reuse one graph. + After ``share_params``, the fitting net's child sub-modules are the same Python objects across tasks, so ``id(first_child)`` is equal for all shared tasks and unique across unrelated models. """ - # Use the first shared parameter tensor's id rather than the descriptor - # object's id: share_params makes descriptor *parameters* the same Python - # objects across tasks while the descriptor modules remain distinct. - # Two descriptors sharing params therefore collapse to the same key here, - # which is exactly what we want (same compiled graph). Truly independent - # descriptors have distinct param objects and get distinct keys. - descriptor_id: int = 0 + descriptor_key: frozenset | int = 0 try: desc = model.get_descriptor() - for _, p in desc.named_parameters(): - descriptor_id = id(p) - break - else: - descriptor_id = id(desc) + param_ids = frozenset(id(p) for p in desc.parameters()) + descriptor_key = param_ids if param_ids else id(desc) except AttributeError: pass try: fitting = model.get_fitting_net() for _, child in fitting.named_children(): - return (descriptor_id, id(child)) + return (descriptor_key, id(child)) except AttributeError: pass - return (descriptor_id, id(model)) + return (descriptor_key, id(model)) # --------------------------------------------------------------------------- From 5e7072918c646d77eac13916f938e6619c059263 Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Mon, 1 Jun 2026 16:35:12 +0800 Subject: [PATCH 10/14] fix: nccl timeout --- deepmd/pt_expt/train/training.py | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/deepmd/pt_expt/train/training.py b/deepmd/pt_expt/train/training.py index 2ae0915617..15251284d0 100644 --- a/deepmd/pt_expt/train/training.py +++ b/deepmd/pt_expt/train/training.py @@ -139,28 +139,30 @@ def _get_model_structure_key(model: torch.nn.Module) -> tuple: descriptors (which bake distinct descriptor constants into the traced graph) are never assigned the same compiled graph. - Descriptor identity is determined by the frozenset of all parameter - tensor ids rather than the descriptor module id. ``share_params`` leaves - descriptor modules as distinct Python objects while making their parameter - tensors the same objects. Using the full frozenset (not just the first - parameter) correctly handles partial sharing: e.g. shared_level==1 on - DPA1/DPA2/DPA3/SE_T_TEBD shares only the type-embedding while leaving - the main block (se_atten, repinit/repformers, repflows, se_ttebd, …) - task-local. Two partially-shared descriptors therefore produce different - frozensets and receive separate compiled graphs, whereas fully-shared - descriptors produce identical frozensets and safely reuse one graph. + Descriptor identity is determined by the tuple of ids of the descriptor's + direct child modules. ``share_params`` replaces submodule references + in-place (``self._modules[k] = base._modules[k]``), so after full sharing + (level 0) all direct children of two descriptor instances are the same + Python objects → same id tuple → same structure key. After partial + sharing (level 1, type-embedding only) the main block (se_atten, + repflows, repinit/repformers, …) is a different object → different tuple + → separate compiled graph. Using child module ids rather than parameter + ids avoids iterating thousands of tensors while remaining correct for all + pt_expt descriptor types. After ``share_params``, the fitting net's child sub-modules are the same Python objects across tasks, so ``id(first_child)`` is equal for all shared tasks and unique across unrelated models. """ - descriptor_key: frozenset | int = 0 + descriptor_key: tuple try: desc = model.get_descriptor() - param_ids = frozenset(id(p) for p in desc.parameters()) - descriptor_key = param_ids if param_ids else id(desc) + # Tuple of direct-child module ids: same for fully-shared descriptors, + # different for partially-shared or independent descriptors. + child_ids = tuple(id(m) for _, m in desc.named_children()) + descriptor_key = child_ids if child_ids else (id(desc),) except AttributeError: - pass + descriptor_key = () try: fitting = model.get_fitting_net() From 33a39be7b75c0eb8ce34c62a723c1ea167a44bd8 Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Mon, 1 Jun 2026 17:23:44 +0800 Subject: [PATCH 11/14] chore: revert changes + add warning --- deepmd/pt_expt/train/training.py | 72 +++++++++++++++++++++++--------- 1 file changed, 53 insertions(+), 19 deletions(-) diff --git a/deepmd/pt_expt/train/training.py b/deepmd/pt_expt/train/training.py index 15251284d0..e5182a722f 100644 --- a/deepmd/pt_expt/train/training.py +++ b/deepmd/pt_expt/train/training.py @@ -131,7 +131,7 @@ def _detect_task_buffers( return result -def _get_model_structure_key(model: torch.nn.Module) -> tuple: +def _get_model_structure_key(model: torch.nn.Module) -> tuple[int, ...]: """Return a key that is identical iff two tasks can safely share a compiled graph. The key captures both the descriptor identity and the fitting-net @@ -139,38 +139,36 @@ def _get_model_structure_key(model: torch.nn.Module) -> tuple: descriptors (which bake distinct descriptor constants into the traced graph) are never assigned the same compiled graph. - Descriptor identity is determined by the tuple of ids of the descriptor's - direct child modules. ``share_params`` replaces submodule references - in-place (``self._modules[k] = base._modules[k]``), so after full sharing - (level 0) all direct children of two descriptor instances are the same - Python objects → same id tuple → same structure key. After partial - sharing (level 1, type-embedding only) the main block (se_atten, - repflows, repinit/repformers, …) is a different object → different tuple - → separate compiled graph. Using child module ids rather than parameter - ids avoids iterating thousands of tensors while remaining correct for all - pt_expt descriptor types. + Descriptor identity uses the id of the first shared parameter tensor. + ``share_params`` makes descriptor *parameters* the same Python objects + across tasks while the descriptor modules remain distinct. Two + descriptors sharing params therefore collapse to the same key here. + Partial sharing (shared_level=1, type-embedding only) is detected in + ``_compile_model`` and raises an explicit error rather than silently + producing a wrong compiled graph. After ``share_params``, the fitting net's child sub-modules are the same Python objects across tasks, so ``id(first_child)`` is equal for all shared tasks and unique across unrelated models. """ - descriptor_key: tuple + descriptor_id: int = 0 try: desc = model.get_descriptor() - # Tuple of direct-child module ids: same for fully-shared descriptors, - # different for partially-shared or independent descriptors. - child_ids = tuple(id(m) for _, m in desc.named_children()) - descriptor_key = child_ids if child_ids else (id(desc),) + for _, p in desc.named_parameters(): + descriptor_id = id(p) + break + else: + descriptor_id = id(desc) except AttributeError: - descriptor_key = () + pass try: fitting = model.get_fitting_net() for _, child in fitting.named_children(): - return (descriptor_key, id(child)) + return (descriptor_id, id(child)) except AttributeError: pass - return (descriptor_key, id(model)) + return (descriptor_id, id(model)) # --------------------------------------------------------------------------- @@ -1160,6 +1158,42 @@ def _compile_model(self, compile_opts: dict[str, Any]) -> None: _key_for[task_key] = sk _groups[sk].append(task_key) + # Warn if tasks share a compiled graph but have partially shared + # descriptors (e.g. shared_level=1: type_embedding shared, main block + # task-local). The structure key uses the first descriptor parameter + # id; when that parameter comes from the shared type_embedding, partial + # sharing is indistinguishable from full sharing here, and the compiled + # graph will bake the first task's main-block constants for all tasks. + # This combination is unsupported — use shared_level=0 or disable compile. + for group_keys in _groups.values(): + if len(group_keys) < 2: + continue + try: + base_ids = frozenset( + id(p) + for p in wrapper_mod.model[group_keys[0]].get_descriptor().parameters() + ) + except AttributeError: + continue + for other_key in group_keys[1:]: + try: + other_ids = frozenset( + id(p) + for p in wrapper_mod.model[other_key].get_descriptor().parameters() + ) + except AttributeError: + continue + if base_ids != other_ids: + log.warning( + "Tasks %r and %r share a compiled graph but have partially " + "shared descriptors (e.g. shared_level=1). The compiled graph " + "bakes the first task's descriptor constants and will produce " + "wrong results for subsequent tasks. " + "Use shared_level=0 or set 'enable_compile: false'.", + group_keys[0], + other_key, + ) + _task_bufs_for: dict[str, dict[str, torch.Tensor]] = {} for group_keys in _groups.values(): group_models = [wrapper_mod.model[k] for k in group_keys] From c86fe9db031c9e80f1f27e99cad07d613ad1fffa Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 1 Jun 2026 09:34:32 +0000 Subject: [PATCH 12/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/pt_expt/train/training.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/deepmd/pt_expt/train/training.py b/deepmd/pt_expt/train/training.py index e5182a722f..a4e576f00d 100644 --- a/deepmd/pt_expt/train/training.py +++ b/deepmd/pt_expt/train/training.py @@ -1171,7 +1171,9 @@ def _compile_model(self, compile_opts: dict[str, Any]) -> None: try: base_ids = frozenset( id(p) - for p in wrapper_mod.model[group_keys[0]].get_descriptor().parameters() + for p in wrapper_mod.model[group_keys[0]] + .get_descriptor() + .parameters() ) except AttributeError: continue @@ -1179,7 +1181,9 @@ def _compile_model(self, compile_opts: dict[str, Any]) -> None: try: other_ids = frozenset( id(p) - for p in wrapper_mod.model[other_key].get_descriptor().parameters() + for p in wrapper_mod.model[other_key] + .get_descriptor() + .parameters() ) except AttributeError: continue From e50766725e23846a4acc1ab9e904b0a3183fc936 Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Tue, 2 Jun 2026 10:31:52 +0800 Subject: [PATCH 13/14] fix: partial share check --- deepmd/pt_expt/descriptor/dpa2.py | 2 ++ deepmd/pt_expt/descriptor/dpa3.py | 4 +++ deepmd/pt_expt/train/training.py | 59 +++++++++++-------------------- 3 files changed, 26 insertions(+), 39 deletions(-) diff --git a/deepmd/pt_expt/descriptor/dpa2.py b/deepmd/pt_expt/descriptor/dpa2.py index 0c8535ab00..01a7dfde32 100644 --- a/deepmd/pt_expt/descriptor/dpa2.py +++ b/deepmd/pt_expt/descriptor/dpa2.py @@ -69,6 +69,8 @@ def share_params( "g1_shape_tranform" ] self._modules["repformers"] = base_class._modules["repformers"] + if "tebd_transform" in base_class._modules: + self._modules["tebd_transform"] = base_class._modules["tebd_transform"] elif shared_level == 1: self._modules["type_embedding"] = base_class._modules["type_embedding"] else: diff --git a/deepmd/pt_expt/descriptor/dpa3.py b/deepmd/pt_expt/descriptor/dpa3.py index fec047e3fd..2492ba54b2 100644 --- a/deepmd/pt_expt/descriptor/dpa3.py +++ b/deepmd/pt_expt/descriptor/dpa3.py @@ -40,6 +40,10 @@ def share_params( if not resume: merge_env_stat(base_class.repflows, self.repflows, model_prob) self._modules["repflows"] = base_class._modules["repflows"] + if self.add_chg_spin_ebd: + for key in ("chg_embedding", "spin_embedding", "mix_cs_mlp"): + if key in base_class._modules: + self._modules[key] = base_class._modules[key] elif shared_level == 1: self._modules["type_embedding"] = base_class._modules["type_embedding"] else: diff --git a/deepmd/pt_expt/train/training.py b/deepmd/pt_expt/train/training.py index a4e576f00d..34e3ae9e93 100644 --- a/deepmd/pt_expt/train/training.py +++ b/deepmd/pt_expt/train/training.py @@ -852,6 +852,7 @@ def _make_sample( self.start_step = 0 # Shared params (multi-task) ------------------------------------------ + self._shared_links = shared_links if shared_links is not None: _data_stat_protect = np.array( [ @@ -1158,45 +1159,25 @@ def _compile_model(self, compile_opts: dict[str, Any]) -> None: _key_for[task_key] = sk _groups[sk].append(task_key) - # Warn if tasks share a compiled graph but have partially shared - # descriptors (e.g. shared_level=1: type_embedding shared, main block - # task-local). The structure key uses the first descriptor parameter - # id; when that parameter comes from the shared type_embedding, partial - # sharing is indistinguishable from full sharing here, and the compiled - # graph will bake the first task's main-block constants for all tasks. - # This combination is unsupported — use shared_level=0 or disable compile. - for group_keys in _groups.values(): - if len(group_keys) < 2: - continue - try: - base_ids = frozenset( - id(p) - for p in wrapper_mod.model[group_keys[0]] - .get_descriptor() - .parameters() - ) - except AttributeError: - continue - for other_key in group_keys[1:]: - try: - other_ids = frozenset( - id(p) - for p in wrapper_mod.model[other_key] - .get_descriptor() - .parameters() - ) - except AttributeError: - continue - if base_ids != other_ids: - log.warning( - "Tasks %r and %r share a compiled graph but have partially " - "shared descriptors (e.g. shared_level=1). The compiled graph " - "bakes the first task's descriptor constants and will produce " - "wrong results for subsequent tasks. " - "Use shared_level=0 or set 'enable_compile: false'.", - group_keys[0], - other_key, - ) + # Reject partial descriptor sharing (shared_level > 0) with torch.compile. + # The compiled graph bakes the first task's descriptor constants, so tasks + # sharing a graph must have identical descriptor parameters. partial sharing + # (e.g. shared_level=1, type_embedding shared but main block task-local) + # violates this invariant. Check directly from the config rather than + # via parameter-identity heuristics. + if self._shared_links is not None: + for info in self._shared_links.values(): + for link_item in info["links"]: + if "descriptor" in link_item["shared_type"] and int( + link_item["shared_level"] + ) > 0: + raise RuntimeError( + f"torch.compile is incompatible with partial descriptor " + f"sharing (task {link_item['model_key']!r}, " + f"shared_level={link_item['shared_level']}). " + f"Use shared_level=0 for all descriptors, " + f"or set 'enable_compile: false'." + ) _task_bufs_for: dict[str, dict[str, torch.Tensor]] = {} for group_keys in _groups.values(): From 492d16ae389063495e0ece0ec1a767a5e47d5fc7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 2 Jun 2026 02:32:44 +0000 Subject: [PATCH 14/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/pt_expt/train/training.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/deepmd/pt_expt/train/training.py b/deepmd/pt_expt/train/training.py index 34e3ae9e93..5b7dc15154 100644 --- a/deepmd/pt_expt/train/training.py +++ b/deepmd/pt_expt/train/training.py @@ -1168,9 +1168,10 @@ def _compile_model(self, compile_opts: dict[str, Any]) -> None: if self._shared_links is not None: for info in self._shared_links.values(): for link_item in info["links"]: - if "descriptor" in link_item["shared_type"] and int( - link_item["shared_level"] - ) > 0: + if ( + "descriptor" in link_item["shared_type"] + and int(link_item["shared_level"]) > 0 + ): raise RuntimeError( f"torch.compile is incompatible with partial descriptor " f"sharing (task {link_item['model_key']!r}, "