Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 32 additions & 4 deletions deepmd/pt_expt/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,18 @@ def __init__(
self.original_model = original_model
self.compiled_forward_lower = compiled_forward_lower

def __getattr__(self, name: str) -> Any:
# Delegate unknown lookups to original_model so that callers such as
# share_params (which calls .get_descriptor(), .atomic_model, etc.) and
# _compile_model (which calls .get_rcut(), .get_sel()) keep working
# transparently after compilation replaces the plain model with this
# wrapper. nn.Module.__getattr__ is tried first so registered
# submodules / parameters / buffers are never shadowed.
try:
return super().__getattr__(name)
except AttributeError:
return getattr(self.original_model, name)

def forward(
self,
coord: torch.Tensor,
Expand Down Expand Up @@ -1119,10 +1131,26 @@ def _broadcast_model_stat(model: torch.nn.Module) -> None:

def save_checkpoint(self, step: int) -> None:
self._unwrapped.train_infos["step"] = step
state = {
"model": self._unwrapped.state_dict(),
"optimizer": self.optimizer.state_dict(),
}
# When compiled, wrapper.model[key] is _CompiledModel whose state_dict
# uses keys like "original_model.*". Restart would load into a plain
# ModelWrapper expecting "model.{key}.*" keys → hard crash. Temporarily
# swap each _CompiledModel back to its original_model so the saved keys
# match what a fresh __init__ expects, then restore.
wrapper = self._unwrapped
compiled_backup: dict[str, _CompiledModel] = {}
for task_key in list(wrapper.model.keys()):
m = wrapper.model[task_key]
if isinstance(m, _CompiledModel):
compiled_backup[task_key] = m
Comment thread
anyangml marked this conversation as resolved.
wrapper.model[task_key] = m.original_model
try:
state = {
"model": wrapper.state_dict(),
"optimizer": self.optimizer.state_dict(),
}
finally:
for task_key, compiled in compiled_backup.items():
wrapper.model[task_key] = compiled
ckpt_path = f"{self.save_ckpt}-{step}.pt"
torch.save(state, ckpt_path)
# symlink latest
Expand Down
158 changes: 158 additions & 0 deletions source/tests/pt_expt/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,70 @@ def test_training_loop_compiled_silu(self) -> None:
self._run_training(config)


class TestCompiledModelGetattr(unittest.TestCase):
"""Unit tests for _CompiledModel attribute delegation.

These tests do not require example data or torch.compile — they use a
lightweight mock original_model and a no-op compiled_forward_lower to
verify that __getattr__ correctly forwards unknown attributes/methods to
the wrapped original model.
"""

def _make_compiled_model(self):
from deepmd.pt_expt.train.training import (
_CompiledModel,
)

class _FakeForwardLower(torch.nn.Module):
def forward(self, *a, **kw):
pass

class _FakeModel(torch.nn.Module):
def get_rcut(self):
return 3.0

def get_type_map(self):
return ["O", "H"]

@property
def atomic_model(self):
return self

def get_descriptor(self):
return self

return _CompiledModel(_FakeModel(), _FakeForwardLower())

def test_delegates_method(self) -> None:
"""Unknown method calls are forwarded to original_model."""
cm = self._make_compiled_model()
self.assertAlmostEqual(cm.get_rcut(), 3.0)

def test_delegates_method_returning_list(self) -> None:
"""Methods returning non-scalar values are forwarded correctly."""
cm = self._make_compiled_model()
self.assertEqual(cm.get_type_map(), ["O", "H"])

def test_delegates_property(self) -> None:
"""Property access is forwarded to original_model."""
cm = self._make_compiled_model()
self.assertIsNotNone(cm.atomic_model)

def test_own_attrs_not_delegated(self) -> None:
"""Attributes owned by _CompiledModel itself are NOT delegated."""
cm = self._make_compiled_model()
# original_model and compiled_forward_lower are registered submodules
# of _CompiledModel — they must not fall through to delegation.
self.assertIsInstance(cm.original_model, torch.nn.Module)
self.assertIsInstance(cm.compiled_forward_lower, torch.nn.Module)

def test_missing_attr_raises(self) -> None:
"""Accessing an attribute missing from both wrapper and original raises."""
cm = self._make_compiled_model()
with self.assertRaises(AttributeError):
_ = cm.nonexistent_attribute_xyz


class TestCompiledDynamicShapes(unittest.TestCase):
"""Test that _CompiledModel handles varying nall via dynamic shapes."""

Expand Down Expand Up @@ -716,6 +780,100 @@ def test_init_model(self) -> None:
finally:
shutil.rmtree(tmpdir, ignore_errors=True)

def test_restart_from_compiled_checkpoint(self) -> None:
"""Train WITH compile enabled, restart from the compiled checkpoint.

Regression test for the state-dict key mismatch bug: save_checkpoint
previously saved _CompiledModel keys ("original_model.*"), which made
load_state_dict fail on restart because a fresh ModelWrapper expects
plain model keys ("model.Default.*").
"""
from deepmd.pt_expt.train.training import (
_CompiledModel,
)

tmpdir = tempfile.mkdtemp(prefix="pt_expt_compiled_restart_")
try:
old_cwd = os.getcwd()
os.chdir(tmpdir)
try:
# Phase 1: train 5 steps WITH compile enabled.
# Do NOT use _train_and_get_ckpt here — we need the trainer
# object to assert that compilation actually happened.
# Without this check the test is vacuous: if torch.compile
# silently falls back to eager the checkpoint keys are already
# plain and the bug path is never exercised.
config = _make_config(self.data_dir, numb_steps=5)
config["training"]["enable_compile"] = True
config = update_deepmd_input(config, warning=False)
config = normalize(config)
trainer1 = get_trainer(config)
self.assertIsInstance(
trainer1.wrapper.model["Default"],
_CompiledModel,
"Phase-1 trainer did not produce a _CompiledModel; "
"torch.compile may have silently fallen back to eager — "
"the bug path is not exercised.",
)
trainer1.run()

ckpt_path = os.path.join(tmpdir, "model.ckpt.pt")
self.assertTrue(os.path.exists(ckpt_path), "Checkpoint not created")

# Primary assertion: the checkpoint must NOT contain
# "_CompiledModel wrapper" keys ("original_model.*").
raw = torch.load(ckpt_path, map_location="cpu", weights_only=True)
ckpt_keys = list(raw["model"].keys())
self.assertFalse(
any("original_model" in k for k in ckpt_keys),
f"Checkpoint has _CompiledModel wrapper keys — "
f"save_checkpoint must unwrap before serialising: "
f"{[k for k in ckpt_keys if 'original_model' in k]}",
)

# Secondary assertion: the saved state dict must load cleanly
# into a fresh *uncompiled* ModelWrapper (strict=False so we
# can distinguish unexpected vs missing keys clearly).
config_nc = _make_config(self.data_dir, numb_steps=10)
config_nc = update_deepmd_input(config_nc, warning=False)
config_nc = normalize(config_nc)
fresh = get_trainer(config_nc)
missing, unexpected = fresh._unwrapped.load_state_dict(
raw["model"], strict=False
)
self.assertEqual(
unexpected,
[],
f"Unexpected keys when loading compiled ckpt into plain "
f"model (indicates _CompiledModel wrapper keys leaked): "
f"{unexpected}",
)
self.assertEqual(
missing,
[],
f"Missing keys when loading compiled ckpt into plain "
f"model: {missing}",
)

# Phase 2: restart from the compiled checkpoint
config2 = _make_config(self.data_dir, numb_steps=10)
config2["training"]["enable_compile"] = True
config2 = update_deepmd_input(config2, warning=False)
config2 = normalize(config2)
trainer2 = get_trainer(config2, restart_model=ckpt_path)

self.assertEqual(trainer2.start_step, 5)
self.assertIsInstance(trainer2.wrapper.model["Default"], _CompiledModel)
trainer2.run()

with open(os.path.join(tmpdir, "lcurve.out")) as f:
lines = [ln for ln in f.readlines() if not ln.startswith("#")]
self.assertGreater(len(lines), 0)
finally:
os.chdir(old_cwd)
finally:
shutil.rmtree(tmpdir, ignore_errors=True)

def test_restart_with_compile(self) -> None:
"""Train uncompiled, restart with compile enabled."""
from deepmd.pt_expt.train.training import (
Expand Down
Loading