diff --git a/deepmd/pt_expt/train/training.py b/deepmd/pt_expt/train/training.py
index 1059af0be6..c15415cac0 100644
--- a/deepmd/pt_expt/train/training.py
+++ b/deepmd/pt_expt/train/training.py
@@ -359,6 +359,23 @@ def __init__(
         self.original_model = original_model
         self.compiled_forward_lower = compiled_forward_lower
 
+    def __getattr__(self, name: str) -> Any:
+        # Delegate unknown lookups to original_model so that callers such as
+        # share_params (which calls .get_descriptor(), .atomic_model, etc.) and
+        # _compile_model (which calls .get_rcut(), .get_sel()) keep working
+        # transparently after compilation replaces the plain model with this
+        # wrapper. nn.Module.__getattr__ is tried first so registered
+        # submodules / parameters / buffers are never shadowed.
+        try:
+            return super().__getattr__(name)
+        except AttributeError:
+            # If original_model itself cannot be resolved (e.g. the instance
+            # was created via __new__ and __init__ has not run), delegating
+            # would re-enter this method without end; re-raise instead.
+            if name == "original_model":
+                raise
+            return getattr(self.original_model, name)
+
     def forward(
         self,
         coord: torch.Tensor,
@@ -1119,10 +1136,26 @@ def _broadcast_model_stat(model: torch.nn.Module) -> None:
 
     def save_checkpoint(self, step: int) -> None:
         self._unwrapped.train_infos["step"] = step
-        state = {
-            "model": self._unwrapped.state_dict(),
-            "optimizer": self.optimizer.state_dict(),
-        }
+        # When compiled, wrapper.model[key] is _CompiledModel whose state_dict
+        # uses keys like "original_model.*". Restart would load into a plain
+        # ModelWrapper expecting "model.{key}.*" keys → hard crash. Temporarily
+        # swap each _CompiledModel back to its original_model so the saved keys
+        # match what a fresh __init__ expects, then restore.
+        wrapper = self._unwrapped
+        compiled_backup: dict[str, _CompiledModel] = {}
+        for task_key in list(wrapper.model.keys()):
+            m = wrapper.model[task_key]
+            if isinstance(m, _CompiledModel):
+                compiled_backup[task_key] = m
+                wrapper.model[task_key] = m.original_model
+        try:
+            state = {
+                "model": wrapper.state_dict(),
+                "optimizer": self.optimizer.state_dict(),
+            }
+        finally:
+            for task_key, compiled in compiled_backup.items():
+                wrapper.model[task_key] = compiled
         ckpt_path = f"{self.save_ckpt}-{step}.pt"
         torch.save(state, ckpt_path)
         # symlink latest
diff --git a/source/tests/pt_expt/test_training.py b/source/tests/pt_expt/test_training.py
index 07bbf2c06a..642b7ad491 100644
--- a/source/tests/pt_expt/test_training.py
+++ b/source/tests/pt_expt/test_training.py
@@ -299,6 +299,70 @@ def test_training_loop_compiled_silu(self) -> None:
         self._run_training(config)
 
 
+class TestCompiledModelGetattr(unittest.TestCase):
+    """Unit tests for _CompiledModel attribute delegation.
+
+    These tests do not require example data or torch.compile — they use a
+    lightweight mock original_model and a no-op compiled_forward_lower to
+    verify that __getattr__ correctly forwards unknown attributes/methods to
+    the wrapped original model.
+    """
+
+    def _make_compiled_model(self):
+        from deepmd.pt_expt.train.training import (
+            _CompiledModel,
+        )
+
+        class _FakeForwardLower(torch.nn.Module):
+            def forward(self, *a, **kw):
+                pass
+
+        class _FakeModel(torch.nn.Module):
+            def get_rcut(self):
+                return 3.0
+
+            def get_type_map(self):
+                return ["O", "H"]
+
+            @property
+            def atomic_model(self):
+                return self
+
+            def get_descriptor(self):
+                return self
+
+        return _CompiledModel(_FakeModel(), _FakeForwardLower())
+
+    def test_delegates_method(self) -> None:
+        """Unknown method calls are forwarded to original_model."""
+        cm = self._make_compiled_model()
+        self.assertAlmostEqual(cm.get_rcut(), 3.0)
+
+    def test_delegates_method_returning_list(self) -> None:
+        """Methods returning non-scalar values are forwarded correctly."""
+        cm = self._make_compiled_model()
+        self.assertEqual(cm.get_type_map(), ["O", "H"])
+
+    def test_delegates_property(self) -> None:
+        """Property access is forwarded to original_model."""
+        cm = self._make_compiled_model()
+        self.assertIsNotNone(cm.atomic_model)
+
+    def test_own_attrs_not_delegated(self) -> None:
+        """Attributes owned by _CompiledModel itself are NOT delegated."""
+        cm = self._make_compiled_model()
+        # original_model and compiled_forward_lower are registered submodules
+        # of _CompiledModel — they must not fall through to delegation.
+        self.assertIsInstance(cm.original_model, torch.nn.Module)
+        self.assertIsInstance(cm.compiled_forward_lower, torch.nn.Module)
+
+    def test_missing_attr_raises(self) -> None:
+        """Accessing an attribute missing from both wrapper and original raises."""
+        cm = self._make_compiled_model()
+        with self.assertRaises(AttributeError):
+            _ = cm.nonexistent_attribute_xyz
+
+
 class TestCompiledDynamicShapes(unittest.TestCase):
     """Test that _CompiledModel handles varying nall via dynamic shapes."""
 
@@ -716,6 +780,100 @@ def test_init_model(self) -> None:
         finally:
             shutil.rmtree(tmpdir, ignore_errors=True)
 
+    def test_restart_from_compiled_checkpoint(self) -> None:
+        """Train WITH compile enabled, restart from the compiled checkpoint.
+
+        Regression test for the state-dict key mismatch bug: save_checkpoint
+        previously saved _CompiledModel keys ("original_model.*"), which made
+        load_state_dict fail on restart because a fresh ModelWrapper expects
+        plain model keys ("model.Default.*").
+        """
+        from deepmd.pt_expt.train.training import (
+            _CompiledModel,
+        )
+
+        tmpdir = tempfile.mkdtemp(prefix="pt_expt_compiled_restart_")
+        try:
+            old_cwd = os.getcwd()
+            os.chdir(tmpdir)
+            try:
+                # Phase 1: train 5 steps WITH compile enabled.
+                # Do NOT use _train_and_get_ckpt here — we need the trainer
+                # object to assert that compilation actually happened.
+                # Without this check the test is vacuous: if torch.compile
+                # silently falls back to eager the checkpoint keys are already
+                # plain and the bug path is never exercised.
+                config = _make_config(self.data_dir, numb_steps=5)
+                config["training"]["enable_compile"] = True
+                config = update_deepmd_input(config, warning=False)
+                config = normalize(config)
+                trainer1 = get_trainer(config)
+                self.assertIsInstance(
+                    trainer1.wrapper.model["Default"],
+                    _CompiledModel,
+                    "Phase-1 trainer did not produce a _CompiledModel; "
+                    "torch.compile may have silently fallen back to eager — "
+                    "the bug path is not exercised.",
+                )
+                trainer1.run()
+
+                ckpt_path = os.path.join(tmpdir, "model.ckpt.pt")
+                self.assertTrue(os.path.exists(ckpt_path), "Checkpoint not created")
+
+                # Primary assertion: the checkpoint must NOT contain
+                # "_CompiledModel wrapper" keys ("original_model.*").
+                raw = torch.load(ckpt_path, map_location="cpu", weights_only=True)
+                ckpt_keys = list(raw["model"].keys())
+                self.assertFalse(
+                    any("original_model" in k for k in ckpt_keys),
+                    f"Checkpoint has _CompiledModel wrapper keys — "
+                    f"save_checkpoint must unwrap before serialising: "
+                    f"{[k for k in ckpt_keys if 'original_model' in k]}",
+                )
+
+                # Secondary assertion: the saved state dict must load cleanly
+                # into a fresh *uncompiled* ModelWrapper (strict=False so we
+                # can distinguish unexpected vs missing keys clearly).
+                config_nc = _make_config(self.data_dir, numb_steps=10)
+                config_nc = update_deepmd_input(config_nc, warning=False)
+                config_nc = normalize(config_nc)
+                fresh = get_trainer(config_nc)
+                missing, unexpected = fresh._unwrapped.load_state_dict(
+                    raw["model"], strict=False
+                )
+                self.assertEqual(
+                    unexpected,
+                    [],
+                    f"Unexpected keys when loading compiled ckpt into plain "
+                    f"model (indicates _CompiledModel wrapper keys leaked): "
+                    f"{unexpected}",
+                )
+                self.assertEqual(
+                    missing,
+                    [],
+                    f"Missing keys when loading compiled ckpt into plain "
+                    f"model: {missing}",
+                )
+
+                # Phase 2: restart from the compiled checkpoint
+                config2 = _make_config(self.data_dir, numb_steps=10)
+                config2["training"]["enable_compile"] = True
+                config2 = update_deepmd_input(config2, warning=False)
+                config2 = normalize(config2)
+                trainer2 = get_trainer(config2, restart_model=ckpt_path)
+
+                self.assertEqual(trainer2.start_step, 5)
+                self.assertIsInstance(trainer2.wrapper.model["Default"], _CompiledModel)
+                trainer2.run()
+
+                with open(os.path.join(tmpdir, "lcurve.out")) as f:
+                    lines = [ln for ln in f.readlines() if not ln.startswith("#")]
+                self.assertGreater(len(lines), 0)
+            finally:
+                os.chdir(old_cwd)
+        finally:
+            shutil.rmtree(tmpdir, ignore_errors=True)
+
     def test_restart_with_compile(self) -> None:
         """Train uncompiled, restart with compile enabled."""
         from deepmd.pt_expt.train.training import (