Skip to content

Commit 57f870f

Browse files
feat(pt): add ema shadow model (#5420)
<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

* **New Features**
  * Added Exponential Moving Average (EMA) support: maintain EMA shadows, emit EMA-only checkpoints, resume/restore EMA state, and evaluate using EMA weights for dedicated validation and best-model selection.
* **Configuration**
  * New options to enable EMA, set decay and checkpoint retention, and opt into EMA full-validation, with constraints on usage based on training stage.
* **Validation**
  * Full-validation now configurable with customizable log paths, checkpoint naming, state keys, evaluation context, and supports separate EMA evaluation and logs.
* **Tests**
  * Added EMA-focused tests covering checkpoint rotation, cleanup, resume behavior, parameter propagation, and full-validation artifact management.
* **Chores**
  * Refactored checkpoint saving and retention for deterministic pruning and consistent handling of EMA checkpoints.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent eab3419 commit 57f870f

6 files changed

Lines changed: 862 additions & 74 deletions

File tree

deepmd/pt/train/ema.py

Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
#!/usr/bin/env python3
2+
# SPDX-License-Identifier: LGPL-3.0-or-later
3+
4+
from __future__ import (
5+
annotations,
6+
)
7+
8+
import logging
9+
from contextlib import (
10+
contextmanager,
11+
)
12+
from copy import (
13+
deepcopy,
14+
)
15+
from pathlib import (
16+
Path,
17+
)
18+
from typing import (
19+
TYPE_CHECKING,
20+
Any,
21+
)
22+
23+
import torch
24+
25+
if TYPE_CHECKING:
26+
from collections.abc import (
27+
Iterator,
28+
)
29+
30+
# NOTE(review): not referenced in the visible code; presumably the key under
# which the EMA state is nested inside a training checkpoint - confirm at the caller.
EMA_CHECKPOINT_KEY = "ema"
31+
# Key of the decay value inside the dict built by ModelEMA.state_dict().
EMA_DECAY_KEY = "decay"
32+
# Key of the name -> shadow-tensor mapping inside that state dict.
EMA_MODEL_STATE_KEY = "model"
33+
# Key of the validation-state dict carried alongside the shadow tensors.
EMA_VALIDATION_STATE_KEY = "validation_state"
34+
35+
# Module-level logger named after this module.
log = logging.getLogger(__name__)
36+
37+
38+
def _append_suffix(path_like: str | Path, suffix: str) -> Path:
39+
"""Append a suffix before the final file suffix when present."""
40+
path = Path(path_like)
41+
if path.suffix:
42+
return path.with_name(f"{path.stem}{suffix}{path.suffix}")
43+
return path.with_name(f"{path.name}{suffix}")
44+
45+
46+
def get_ema_checkpoint_prefix(save_ckpt: str | Path) -> str:
    """Return the EMA checkpoint prefix for a regular checkpoint prefix.

    The prefix is ``save_ckpt`` with ``"_ema"`` added by ``_append_suffix``,
    converted to a string.
    """
    ema_prefix = _append_suffix(save_ckpt, "_ema")
    return str(ema_prefix)
49+
50+
51+
def get_ema_validation_log_path(full_val_file: str | Path) -> Path:
    """Return the EMA validation log path for a regular validation log path.

    The result is ``full_val_file`` with ``"_ema"`` added by
    ``_append_suffix``, returned as a ``Path``.
    """
    ema_log_path = _append_suffix(full_val_file, "_ema")
    return ema_log_path
54+
55+
56+
class ModelEMA:
57+
"""Maintain an exponential moving average of model parameters.
58+
59+
This helper assumes DDP/ZeRO-1 style training where every rank owns the
60+
same full, consistently ordered model parameters. It is not a sharded
61+
parameter EMA implementation.
62+
"""
63+
64+
def __init__(
65+
self,
66+
model: torch.nn.Module | dict[str, torch.nn.Module],
67+
decay: float,
68+
state: dict[str, Any] | None = None,
69+
) -> None:
70+
self.decay = float(decay)
71+
self.shadow_params = self._clone_model_parameters(model)
72+
self.validation_state: dict[str, Any] = {}
73+
if state is not None:
74+
self.load_state_dict(state)
75+
76+
@staticmethod
77+
def _named_model_parameters(
78+
model: torch.nn.Module | dict[str, torch.nn.Module],
79+
) -> list[tuple[str, torch.nn.Parameter]]:
80+
"""Collect all floating-point model parameters in a deterministic order."""
81+
if isinstance(model, dict):
82+
named_parameters = []
83+
for model_key in sorted(model):
84+
named_parameters.extend(
85+
[
86+
(f"{model_key}.{name}", param)
87+
for name, param in model[model_key].named_parameters()
88+
if torch.is_floating_point(param)
89+
]
90+
)
91+
return named_parameters
92+
return [
93+
(name, param)
94+
for name, param in model.named_parameters()
95+
if torch.is_floating_point(param)
96+
]
97+
98+
def _clone_model_parameters(
99+
self,
100+
model: torch.nn.Module | dict[str, torch.nn.Module],
101+
) -> dict[str, torch.Tensor]:
102+
"""Clone model parameters to initialize the EMA shadow state."""
103+
with torch.no_grad():
104+
return {
105+
name: param.detach().clone()
106+
for name, param in self._named_model_parameters(model)
107+
}
108+
109+
def update(self, model: torch.nn.Module | dict[str, torch.nn.Module]) -> None:
110+
"""Update EMA shadow parameters from the current model parameters."""
111+
with torch.no_grad():
112+
for name, param in self._named_model_parameters(model):
113+
self.shadow_params[name].lerp_(param.detach(), weight=1.0 - self.decay)
114+
115+
def state_dict(self) -> dict[str, Any]:
116+
"""Serialize EMA state for restart."""
117+
return {
118+
EMA_DECAY_KEY: self.decay,
119+
EMA_MODEL_STATE_KEY: {
120+
name: tensor.detach().cpu().clone()
121+
for name, tensor in self.shadow_params.items()
122+
},
123+
EMA_VALIDATION_STATE_KEY: deepcopy(self.validation_state),
124+
}
125+
126+
def load_state_dict(self, state: dict[str, Any]) -> None:
127+
"""Restore EMA shadow parameters and validator state."""
128+
if EMA_DECAY_KEY in state:
129+
checkpoint_decay = float(state[EMA_DECAY_KEY])
130+
if checkpoint_decay != self.decay:
131+
log.warning(
132+
"Ignoring EMA checkpoint decay=%s because training.ema_decay=%s "
133+
"is configured.",
134+
checkpoint_decay,
135+
self.decay,
136+
)
137+
model_state = state.get(EMA_MODEL_STATE_KEY, {})
138+
if not isinstance(model_state, dict):
139+
raise TypeError("EMA checkpoint field `model` must be a dict.")
140+
141+
current_keys = set(self.shadow_params)
142+
loaded_keys = set(model_state)
143+
missing_keys = sorted(current_keys - loaded_keys)
144+
unexpected_keys = sorted(loaded_keys - current_keys)
145+
if missing_keys or unexpected_keys:
146+
raise KeyError(
147+
"EMA checkpoint parameter keys do not match the current model. "
148+
f"Missing keys: {missing_keys[:5]}, unexpected keys: {unexpected_keys[:5]}."
149+
)
150+
151+
with torch.no_grad():
152+
for name, shadow_param in self.shadow_params.items():
153+
loaded_param = model_state[name]
154+
if not isinstance(loaded_param, torch.Tensor):
155+
raise TypeError(
156+
f"EMA checkpoint tensor for {name!r} must be a torch.Tensor."
157+
)
158+
if loaded_param.shape != shadow_param.shape:
159+
raise ValueError(
160+
"EMA checkpoint parameter shape does not match the current "
161+
f"model for {name!r}: expected {tuple(shadow_param.shape)}, "
162+
f"got {tuple(loaded_param.shape)}."
163+
)
164+
shadow_param.copy_(
165+
loaded_param.to(
166+
device=shadow_param.device,
167+
dtype=shadow_param.dtype,
168+
)
169+
)
170+
171+
validation_state = state.get(EMA_VALIDATION_STATE_KEY, {})
172+
if validation_state is None:
173+
validation_state = {}
174+
if not isinstance(validation_state, dict):
175+
raise TypeError("EMA checkpoint field `validation_state` must be a dict.")
176+
self.validation_state = deepcopy(validation_state)
177+
178+
@contextmanager
179+
def apply_shadow(
180+
self,
181+
model: torch.nn.Module | dict[str, torch.nn.Module],
182+
) -> Iterator[None]:
183+
"""Temporarily replace model parameters with the EMA shadow state."""
184+
backups: dict[str, torch.Tensor] = {}
185+
try:
186+
with torch.no_grad():
187+
for name, param in self._named_model_parameters(model):
188+
backups[name] = param.detach().clone()
189+
param.copy_(
190+
self.shadow_params[name].to(
191+
device=param.device,
192+
dtype=param.dtype,
193+
)
194+
)
195+
yield
196+
finally:
197+
with torch.no_grad():
198+
for name, param in self._named_model_parameters(model):
199+
if name in backups:
200+
param.copy_(backups[name])

0 commit comments

Comments
 (0)