diff --git a/deepmd/pt_expt/infer/deep_eval.py b/deepmd/pt_expt/infer/deep_eval.py index 8f40600ffc..60c1e72e4f 100644 --- a/deepmd/pt_expt/infer/deep_eval.py +++ b/deepmd/pt_expt/infer/deep_eval.py @@ -366,10 +366,15 @@ def _load_pt(self, model_file: str, head: str | None = None) -> None: # eager inference). Drop the latter and unwrap the former. cleaned: dict[str, Any] = {} compiled_marker = ".compiled_forward_lower." + # Per-task buffer copies registered on _CompiledModel (bias_atom_e, + # case_embd) — real values live on the original model's fitting net. + task_buf_marker = "._task_" wrapper_infix = ".original_model." for key, value in state_dict.items(): if compiled_marker in key: continue + if task_buf_marker in key: + continue if wrapper_infix in key: key = key.replace(wrapper_infix, ".", 1) cleaned[key] = value diff --git a/deepmd/pt_expt/train/training.py b/deepmd/pt_expt/train/training.py index 1059af0be6..97f4d75fd5 100644 --- a/deepmd/pt_expt/train/training.py +++ b/deepmd/pt_expt/train/training.py @@ -70,6 +70,93 @@ log = logging.getLogger(__name__) +# Buffer names in atomic_model that are per-task (energy/output statistics). +# These live one level above the fitting net and are not reached by +# fitting-net share_params. They are always promoted to FX placeholders +# because model_change_out_bias may replace them out-of-place after +# compilation, so the compiled forward must read them fresh each call. +_ATOMIC_MODEL_TASK_BUFFER_NAMES: tuple[str, ...] = ("out_bias", "out_std") + +# Prefix used in task_buf_order keys to distinguish atomic_model buffers +# from fitting-net buffers. +_AM_PREFIX = "am/" + + +def _detect_task_buffers( + model: torch.nn.Module, + group_models: list["torch.nn.Module"], +) -> dict[str, torch.Tensor]: + """Collect per-task buffers to promote to FX placeholders. + + Fitting-net buffers are auto-detected by identity diff across + *group_models* (all tasks that share this model's structure key after + ``share_params``). Any buffer that is a *different* Python object in at + least one other group member is task-specific and gets promoted. + + Atomic-model buffers listed in ``_ATOMIC_MODEL_TASK_BUFFER_NAMES`` are + always promoted because ``model_change_out_bias`` may replace them + out-of-place after compilation. + """ + result: dict[str, torch.Tensor] = {} + + # Auto-detect fitting-net task buffers by identity diff across the group. + try: + fitting = model.get_fitting_net() + for name, val in fitting._buffers.items(): + if val is None or not torch.is_tensor(val): + continue + for other in group_models: + if other is model: + continue + try: + other_val = other.get_fitting_net()._buffers.get(name) + if other_val is not val: + result[name] = val.detach().clone() + break + except AttributeError: + pass + except AttributeError: + pass + + # Atomic-model task buffers (always promote). + try: + am = model.atomic_model + for name in _ATOMIC_MODEL_TASK_BUFFER_NAMES: + val = am._buffers.get(name) + if val is not None and torch.is_tensor(val): + result[_AM_PREFIX + name] = val.detach().clone() + except AttributeError: + pass + + return result + + +def _get_model_structure_key(model: torch.nn.Module) -> tuple[int, ...]: + """Return a key that is identical iff two tasks can safely share a compiled graph. + + The key captures both the descriptor identity and the fitting-net + structure so that tasks sharing a fitting net but using *different* + descriptors (which bake distinct descriptor constants into the traced + graph) are never assigned the same compiled graph. + + After ``share_params``, the fitting net's child sub-modules are the same + Python objects across tasks, so ``id(first_child)`` is equal for all + shared tasks and unique across unrelated models. + """ + descriptor_id: int = 0 + try: + descriptor_id = id(model.get_descriptor()) + except AttributeError: + pass + + try: + fitting = model.get_fitting_net() + for _, child in fitting.named_children(): + return (descriptor_id, id(child)) + except AttributeError: + pass + return (descriptor_id, id(model)) + # --------------------------------------------------------------------------- # Helper: loss factory (reused from pt) @@ -214,7 +301,8 @@ def _trace_and_compile( aparam: torch.Tensor | None, compile_opts: dict[str, Any] | None = None, charge_spin: torch.Tensor | None = None, -) -> torch.nn.Module: + task_buffers: dict[str, torch.Tensor] | None = None, +) -> tuple[torch.nn.Module, tuple[str, ...]]: """Symbolic-trace ``forward_lower`` and compile with inductor + dynamic=True. Parameters @@ -226,11 +314,18 @@ def _trace_and_compile( compile_opts : dict or None User-supplied inductor options. These are merged on top of the built-in defaults (user values take precedence). + task_buffers : dict or None + Per-task buffers (e.g. ``bias_atom_e``, ``case_embd``, ``out_bias``, + ``out_std``) detected by ``_detect_task_buffers``. These are promoted + to explicit FX ``placeholder`` nodes so the compiled graph is reusable + across tasks that share the same structure key. Returns ------- - torch.nn.Module + compiled : torch.nn.Module The compiled ``forward_lower`` callable. + task_buf_order : tuple[str, ...] + Ordered names of the promoted buffers (empty when none). """ from torch.fx.experimental.proxy_tensor import ( make_fx, @@ -244,6 +339,24 @@ def _trace_and_compile( # backprop cannot reach the weights and force RMSE never decreases. model.train() + task_buf_order: tuple[str, ...] = tuple(task_buffers.keys()) if task_buffers else () + task_buf_vals_trace: tuple[torch.Tensor, ...] = ( + tuple(task_buffers[k] for k in task_buf_order) if task_buffers else () + ) + + # Resolve fitting net and atomic_model once for buffer patching inside fn. + _fitting: torch.nn.Module | None = None + _atomic_model: torch.nn.Module | None = None + if task_buf_order: + try: + _fitting = model.get_fitting_net() + except AttributeError: + pass # no fitting net → no fitting-net buffers to patch + try: + _atomic_model = model.atomic_model + except AttributeError: + pass # no atomic_model → no atomic-model buffers to patch + def fn( extended_coord: torch.Tensor, extended_atype: torch.Tensor, @@ -252,17 +365,44 @@ def fn( fparam: torch.Tensor | None, aparam: torch.Tensor | None, charge_spin: torch.Tensor | None, + *task_buf_vals: torch.Tensor, ) -> dict[str, torch.Tensor]: extended_coord = extended_coord.detach().requires_grad_(True) - return model.forward_lower( - extended_coord, - extended_atype, - nlist, - mapping, - fparam=fparam, - aparam=aparam, - charge_spin=charge_spin, - ) + # Temporarily patch task-specific buffers with the proxy tensors so + # make_fx records them as FX placeholders rather than baked-in constants. + # Keys prefixed with _AM_PREFIX are atomic_model buffers; the rest are + # fitting-net buffers. + originals: dict[str, torch.Tensor | None] = {} + if task_buf_order: + for name, val in zip(task_buf_order, task_buf_vals): + if name.startswith(_AM_PREFIX): + actual = name[len(_AM_PREFIX) :] + if _atomic_model is not None: + originals[name] = _atomic_model._buffers.get(actual) + _atomic_model._buffers[actual] = val + else: + if _fitting is not None: + originals[name] = _fitting._buffers.get(name) + _fitting._buffers[name] = val + try: + return model.forward_lower( + extended_coord, + extended_atype, + nlist, + mapping, + fparam=fparam, + aparam=aparam, + charge_spin=charge_spin, + ) + finally: + for name, orig in originals.items(): + if name.startswith(_AM_PREFIX): + actual = name[len(_AM_PREFIX) :] + if _atomic_model is not None: + _atomic_model._buffers[actual] = orig + else: + if _fitting is not None: + _fitting._buffers[name] = orig # Pick a trace-time nframes that's unlikely to collide with any other # tensor dim in the graph. The symbolic tracer merges symbols that @@ -309,7 +449,16 @@ def _expand(t: torch.Tensor | None) -> torch.Tensor | None: tracing_mode="symbolic", _allow_non_fake_inputs=True, decomposition_table=decomp_table, - )(ext_coord, ext_atype, nlist, mapping, fparam, aparam, charge_spin) + )( + ext_coord, + ext_atype, + nlist, + mapping, + fparam, + aparam, + charge_spin, + *task_buf_vals_trace, + ) # make_fx inserts aten.detach.default for saved tensors used in the # decomposed autograd.grad backward ops. These detach nodes break @@ -344,7 +493,7 @@ def _expand(t: torch.Tensor | None) -> torch.Tensor | None: backend="inductor", dynamic=True, options=inductor_options, - ) + ), task_buf_order class _CompiledModel(torch.nn.Module): @@ -354,10 +503,16 @@ def __init__( self, original_model: torch.nn.Module, compiled_forward_lower: torch.nn.Module, + task_buf_order: tuple[str, ...] = (), + task_buffers: dict[str, torch.Tensor] | None = None, ) -> None: super().__init__() self.original_model = original_model self.compiled_forward_lower = compiled_forward_lower + self._task_buf_order = task_buf_order + # task_buffers is intentionally not stored: buffers are read from + # original_model.get_fitting_net() at forward time so that weight + # updates (load_state_dict, optimiser steps) are always reflected. def forward( self, @@ -404,8 +559,31 @@ def forward( ext_coord = ext_coord.reshape(nframes, -1, 3) ext_coord = ext_coord.detach().requires_grad_(True) + if self._task_buf_order: + try: + _fitting = self.original_model.get_fitting_net() + _am = getattr(self.original_model, "atomic_model", None) + _vals: list[torch.Tensor] = [] + for _name in self._task_buf_order: + if _name.startswith(_AM_PREFIX): + _actual = _name[len(_AM_PREFIX) :] + _vals.append(_am._buffers[_actual]) + else: + _vals.append(getattr(_fitting, _name)) + task_buf_vals: tuple = tuple(_vals) + except AttributeError: + task_buf_vals = () + else: + task_buf_vals = () result = self.compiled_forward_lower( - ext_coord, ext_atype, nlist, mapping, fparam, aparam, charge_spin + ext_coord, + ext_atype, + nlist, + mapping, + fparam, + aparam, + charge_spin, + *task_buf_vals, ) # Translate forward_lower keys -> forward keys. @@ -947,6 +1125,35 @@ def _compile_model(self, compile_opts: dict[str, Any]) -> None: else self.wrapper ) + from collections import ( + defaultdict, + ) + + from deepmd.dpmodel.descriptor.dpa1 import DescrptDPA1 as DescrptDPA1DP + + # Pre-pass: group tasks by structure key and auto-detect per-task buffers. + # Grouping is needed so _detect_task_buffers can diff buffer identities + # across all tasks that share the same compiled graph. + _key_for: dict[str, tuple[int, ...]] = {} + _groups: defaultdict[tuple[int, ...], list[str]] = defaultdict(list) + for task_key in self.model_keys: + sk = _get_model_structure_key(wrapper_mod.model[task_key]) + _key_for[task_key] = sk + _groups[sk].append(task_key) + + _task_bufs_for: dict[str, dict[str, torch.Tensor]] = {} + for sk, group_keys in _groups.items(): + group_models = [wrapper_mod.model[k] for k in group_keys] + for task_key in group_keys: + _task_bufs_for[task_key] = _detect_task_buffers( + wrapper_mod.model[task_key], group_models + ) + + # structure_key -> (compiled_lower, task_buf_order) + # Tasks with the same structure key (same descriptor + shared fitting) + # reuse the compiled graph; different descriptor or fitting → distinct key. + _compiled_by_structure: dict[tuple[int, ...], tuple] = {} + for task_key in self.model_keys: model = wrapper_mod.model[task_key] @@ -957,8 +1164,6 @@ def _compile_model(self, compile_opts: dict[str, Any]) -> None: # is hardware-dependent. Warn but do not reject — energies # remain well within training tolerance and the user may # accept the trade-off for compile speed. - from deepmd.dpmodel.descriptor.dpa1 import DescrptDPA1 as DescrptDPA1DP - descriptor = model.get_descriptor() if isinstance(descriptor, DescrptDPA1DP): n_attn = descriptor.get_numb_attn_layer() @@ -974,54 +1179,71 @@ def _compile_model(self, compile_opts: dict[str, Any]) -> None: task_key, ) - inp, _ = self.get_data(is_train=True, task_key=task_key) - coord = inp["coord"].detach() - atype = inp["atype"].detach() - box = inp.get("box") - if box is not None: - box = box.detach() - - nframes, nloc = atype.shape[:2] - coord_3d = coord.reshape(nframes, nloc, 3) - box_flat = box.reshape(nframes, 9) if box is not None else None + structure_key = _key_for[task_key] + task_bufs = _task_bufs_for[task_key] - if box_flat is not None: - coord_norm = normalize_coord(coord_3d, box_flat.reshape(nframes, 3, 3)) + if structure_key in _compiled_by_structure: + # Shared structure: reuse the already-compiled graph. + compiled_lower, task_buf_order = _compiled_by_structure[structure_key] + log.info( + "Reusing compiled graph for task=%s (shared model structure).", + task_key, + ) else: - coord_norm = coord_3d - - ext_coord, ext_atype, mapping = extend_coord_with_ghosts( - coord_norm, atype, box_flat, model.get_rcut() - ) - nlist_t = build_neighbor_list( - ext_coord, - ext_atype, - nloc, - model.get_rcut(), - model.get_sel(), - distinguish_types=False, - ) - ext_coord = ext_coord.reshape(nframes, -1, 3) + inp, _ = self.get_data(is_train=True, task_key=task_key) + coord = inp["coord"].detach() + atype = inp["atype"].detach() + box = inp.get("box") + if box is not None: + box = box.detach() + + nframes, nloc = atype.shape[:2] + coord_3d = coord.reshape(nframes, nloc, 3) + box_flat = box.reshape(nframes, 9) if box is not None else None + + if box_flat is not None: + coord_norm = normalize_coord( + coord_3d, box_flat.reshape(nframes, 3, 3) + ) + else: + coord_norm = coord_3d - fparam = inp.get("fparam") - aparam = inp.get("aparam") - charge_spin = inp.get("charge_spin") + ext_coord, ext_atype, mapping = extend_coord_with_ghosts( + coord_norm, atype, box_flat, model.get_rcut() + ) + nlist_t = build_neighbor_list( + ext_coord, + ext_atype, + nloc, + model.get_rcut(), + model.get_sel(), + distinguish_types=False, + ) + ext_coord = ext_coord.reshape(nframes, -1, 3) + + fparam = inp.get("fparam") + aparam = inp.get("aparam") + charge_spin = inp.get("charge_spin") + + compiled_lower, task_buf_order = _trace_and_compile( + model, + ext_coord, + ext_atype, + nlist_t, + mapping, + fparam, + aparam, + charge_spin=charge_spin, + task_buffers=task_bufs if task_bufs else None, + compile_opts=compile_opts, + ) + _compiled_by_structure[structure_key] = (compiled_lower, task_buf_order) - compiled_lower = _trace_and_compile( - model, - ext_coord, - ext_atype, - nlist_t, - mapping, - fparam, - aparam, - charge_spin=charge_spin, - compile_opts=compile_opts, + wrapper_mod.model[task_key] = _CompiledModel( + model, compiled_lower, task_buf_order, task_bufs ) - - wrapper_mod.model[task_key] = _CompiledModel(model, compiled_lower) log.info( - "Model compiled (task=%s, tracing_mode=symbolic, " + "Model compiled/reused (task=%s, tracing_mode=symbolic, " "dynamic=True, backend=inductor).", task_key, ) @@ -1210,9 +1432,15 @@ def run(self) -> None: self.wrapper.eval() if self.rank == 0: + + def _to_float(v: Any) -> float: + return v.detach().item() if torch.is_tensor(v) else float(v) + if not self.multi_task: train_results = { - k: v for k, v in more_loss.items() if "l2_" not in k + k: _to_float(v) + for k, v in more_loss.items() + if "l2_" not in k } # validation @@ -1233,7 +1461,8 @@ def run(self) -> None: for k, v in _vmore.items(): if "l2_" not in k: valid_results[k] = ( - valid_results.get(k, 0.0) + v * natoms + valid_results.get(k, 0.0) + + _to_float(v) * natoms ) if sum_natoms > 0: valid_results = { @@ -1246,7 +1475,9 @@ def run(self) -> None: # current task already has loss train_results[task_key] = { - k: v for k, v in more_loss.items() if "l2_" not in k + k: _to_float(v) + for k, v in more_loss.items() + if "l2_" not in k } # compute loss for other tasks @@ -1261,7 +1492,9 @@ def run(self) -> None: task_key=_key, ) train_results[_key] = { - k: v for k, v in _more.items() if "l2_" not in k + k: _to_float(v) + for k, v in _more.items() + if "l2_" not in k } # validation for each task @@ -1285,7 +1518,10 @@ def run(self) -> None: _sum_natoms += natoms for k, v in _vmore.items(): if "l2_" not in k: - _vres[k] = _vres.get(k, 0.0) + v * natoms + _vres[k] = ( + _vres.get(k, 0.0) + + _to_float(v) * natoms + ) if _sum_natoms > 0: _vres = { k: v / _sum_natoms for k, v in _vres.items() diff --git a/source/tests/pt_expt/test_training.py b/source/tests/pt_expt/test_training.py index 07bbf2c06a..eee425d80f 100644 --- a/source/tests/pt_expt/test_training.py +++ b/source/tests/pt_expt/test_training.py @@ -1298,5 +1298,233 @@ def test_compile_warns_dpa1_with_attention(self) -> None: self.assertIsInstance(trainer.wrapper.model["Default"], _CompiledModel) +class TestCompiledSharedFittingDifferentDescriptor(unittest.TestCase): + """Regression test: shared fitting with different descriptors gets distinct compiled graphs. + + Before the fix, ``_get_model_structure_key`` returned the id of the first + fitting-net child without including the descriptor. Two tasks sharing a + fitting net but using different descriptors (different rcut / sel, which + bake different smooth-cutoff constants into the traced graph) received the + same structure key — task_2 silently reused task_1's compiled graph and + produced wrong predictions. + + The fix includes ``id(descriptor)`` in the key so each task with a + distinct descriptor gets its own compiled graph. + """ + + @classmethod + def setUpClass(cls) -> None: + data_dir = os.path.join(EXAMPLE_DIR, "data") + if not os.path.isdir(data_dir): + raise unittest.SkipTest(f"Example data not found: {data_dir}") + cls.data_dir = data_dir + + def _make_config(self, enable_compile: bool) -> tuple[dict, object]: + """Multi-task config: shared fitting_net, DIFFERENT descriptors per task.""" + from deepmd.pt_expt.utils.multi_task import ( + preprocess_shared_params, + ) + + data_dir_0 = os.path.join(self.data_dir, "data_0") + config = { + "model": { + "shared_dict": { + "my_type_map": ["O", "H"], + "my_fitting": { + "neuron": [16, 16], + "resnet_dt": True, + "seed": 1, + "dim_case_embd": 2, + "precision": "float64", + }, + }, + "model_dict": { + "model_1": { + "type_map": "my_type_map", + "descriptor": { + "type": "se_e2_a", + "sel": [6, 12], + "rcut_smth": 0.50, + "rcut": 3.00, + "neuron": [8, 16], + "resnet_dt": False, + "axis_neuron": 4, + "type_one_side": True, + "precision": "float64", + "seed": 1, + }, + "fitting_net": "my_fitting", + "data_stat_nbatch": 1, + }, + "model_2": { + "type_map": "my_type_map", + "descriptor": { + "type": "se_e2_a", + "sel": [4, 8], + "rcut_smth": 0.30, + "rcut": 2.50, + "neuron": [8, 16], + "resnet_dt": False, + "axis_neuron": 4, + "type_one_side": True, + "precision": "float64", + "seed": 2, + }, + "fitting_net": "my_fitting", + "data_stat_nbatch": 1, + }, + }, + }, + "learning_rate": { + "type": "exp", + "decay_steps": 500, + "start_lr": 0.001, + "stop_lr": 3.51e-8, + }, + "loss_dict": { + "model_1": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "start_pref_v": 0, + "limit_pref_v": 0, + }, + "model_2": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "start_pref_v": 0, + "limit_pref_v": 0, + }, + }, + "training": { + "model_prob": {"model_1": 0.5, "model_2": 0.5}, + "data_dict": { + "model_1": { + "stat_file": "./stat_files/model_1", + "training_data": {"systems": [data_dir_0], "batch_size": 1}, + "validation_data": { + "systems": [data_dir_0], + "batch_size": 1, + "numb_btch": 1, + }, + }, + "model_2": { + "stat_file": "./stat_files/model_2", + "training_data": {"systems": [data_dir_0], "batch_size": 1}, + "validation_data": { + "systems": [data_dir_0], + "batch_size": 1, + "numb_btch": 1, + }, + }, + }, + "numb_steps": 1, + "seed": 10, + "disp_file": "lcurve.out", + "disp_freq": 1, + "save_freq": 1, + }, + } + if enable_compile: + config["training"]["enable_compile"] = True + config["model"], shared_links = preprocess_shared_params(config["model"]) + config = update_deepmd_input(config, warning=False) + config = normalize(config, multi_task=True) + return config, shared_links + + def test_compiled_matches_eager_per_task(self) -> None: + """Compiled output for each task must match its own eager output. + + With different descriptors, tasks must get separate compiled graphs. + Before the fix, task_2 reused task_1's compiled graph (rcut=3.0 baked + in), yielding wrong predictions for task_2 (rcut=2.5). + """ + from deepmd.pt_expt.train.training import ( + _CompiledModel, + _get_model_structure_key, + ) + + tmpdir = tempfile.mkdtemp(prefix="pt_expt_diff_desc_") + try: + old_cwd = os.getcwd() + os.chdir(tmpdir) + try: + config_uc, shared_links_uc = self._make_config(enable_compile=False) + config_c, shared_links_c = self._make_config(enable_compile=True) + + trainer_uc = get_trainer(config_uc, shared_links=shared_links_uc) + trainer_c = get_trainer(config_c, shared_links=shared_links_c) + + for mk in ("model_1", "model_2"): + self.assertIsInstance( + trainer_c.wrapper.model[mk], + _CompiledModel, + f"{mk} was not compiled", + ) + + # Different descriptors → different structure keys → separate graphs. + key_1 = _get_model_structure_key( + trainer_c.wrapper.model["model_1"].original_model + ) + key_2 = _get_model_structure_key( + trainer_c.wrapper.model["model_2"].original_model + ) + self.assertNotEqual( + key_1, + key_2, + "Tasks with different descriptors must get different structure keys", + ) + + # Sync weights so compiled and uncompiled start from the same state. + for mk in ("model_1", "model_2"): + trainer_c.wrapper.model[mk].original_model.load_state_dict( + trainer_uc.wrapper.model[mk].state_dict() + ) + + for mk in ("model_1", "model_2"): + inp_dict, label_dict = trainer_uc.get_data( + is_train=True, task_key=mk + ) + cur_lr = trainer_uc.scheduler.get_last_lr()[0] + + pred_uc, loss_uc, _ = trainer_uc.wrapper( + **inp_dict, + cur_lr=cur_lr, + label=label_dict, + task_key=mk, + ) + pred_c, loss_c, _ = trainer_c.wrapper( + **inp_dict, + cur_lr=cur_lr, + label=label_dict, + task_key=mk, + ) + + for key in ("atom_energy", "energy", "force"): + torch.testing.assert_close( + pred_c[key], + pred_uc[key], + atol=1e-10, + rtol=1e-10, + msg=f"{mk}/{key}: compiled vs eager mismatch", + ) + torch.testing.assert_close( + loss_c, + loss_uc, + atol=1e-10, + rtol=1e-10, + msg=f"{mk}/loss: compiled vs eager mismatch", + ) + finally: + os.chdir(old_cwd) + finally: + shutil.rmtree(tmpdir, ignore_errors=True) + + if __name__ == "__main__": unittest.main()