Skip to content

Commit 9b1a711

Browse files
authored
Merge branch 'master' into 0505_reformat_chg_spin
2 parents f68fc18 + 57f870f commit 9b1a711

21 files changed

Lines changed: 2689 additions & 95 deletions

File tree

.github/workflows/test_cc.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ jobs:
4444
- run: python -m pip install uv
4545
- name: Install Python dependencies
4646
run: |
47-
source/install/uv_with_retry.sh pip install --system --group pin_tensorflow_cpu --group pin_pytorch_cpu --group pin_jax --torch-backend cpu
47+
source/install/uv_with_retry.sh pip install --system --group pin_tensorflow_cpu --group pin_pytorch_cpu --group pin_jax_cpu --torch-backend cpu
4848
export TENSORFLOW_ROOT=$(python -c 'import importlib.util,pathlib;print(pathlib.Path(importlib.util.find_spec("tensorflow").origin).parent)')
4949
source/install/uv_with_retry.sh pip install --system -e .[cpu,test,lmp,jax] mpi4py mpich
5050
- name: Convert models

.github/workflows/test_cuda.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ jobs:
4343
&& sudo apt-get -y install cuda-12-3 libcudnn8=8.9.5.*-1+cuda12.3
4444
if: false # skip as we use nvidia image
4545
- run: python -m pip install -U uv
46-
- run: source/install/uv_with_retry.sh pip install --system --group pin_tensorflow_gpu --group pin_pytorch_gpu --group pin_jax "jax[cuda12]"
46+
- run: source/install/uv_with_retry.sh pip install --system --group pin_tensorflow_gpu --group pin_pytorch_gpu --group pin_jax_gpu
4747
- run: |
4848
export PYTORCH_ROOT=$(python -c 'import torch;print(torch.__path__[0])')
4949
export TENSORFLOW_ROOT=$(python -c 'import importlib.util,pathlib;print(pathlib.Path(importlib.util.find_spec("tensorflow").origin).parent)')

.github/workflows/test_python.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ jobs:
3131
source/install/uv_with_retry.sh pip install --system openmpi --group pin_tensorflow_cpu --group pin_pytorch_cpu --torch-backend cpu
3232
export TENSORFLOW_ROOT=$(python -c 'import importlib.util,pathlib;print(pathlib.Path(importlib.util.find_spec("tensorflow").origin).parent)')
3333
export PYTORCH_ROOT=$(python -c 'import torch;print(torch.__path__[0])')
34-
source/install/uv_with_retry.sh pip install --system -e .[test,jax] mpi4py --group pin_jax
34+
source/install/uv_with_retry.sh pip install --system -e .[test,jax] mpi4py --group pin_jax_cpu
3535
source/install/uv_with_retry.sh pip install --system --find-links "https://www.paddlepaddle.org.cn/packages/nightly/cpu/paddlepaddle/" --index-url https://pypi.org/simple --trusted-host www.paddlepaddle.org.cn --trusted-host paddlepaddle.org.cn paddlepaddle==3.4.0.dev20260310
3636
env:
3737
# Please note that uv has some issues with finding

deepmd/dpmodel/utils/learning_rate.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,11 @@ def __init__(
7070
The warmup learning rate starts from warmup_start_factor * start_lr.
7171
Default is 0.0.
7272
"""
73-
# === Step 1. Validate stop_lr and stop_lr_ratio (runtime check) ===
73+
# === Step 1. Validate start_lr (runtime check) ===
74+
if start_lr <= 0 or not np.isfinite(start_lr):
75+
raise ValueError(f"start_lr ({start_lr}) must be positive and finite.")
76+
77+
# === Step 2. Validate stop_lr and stop_lr_ratio (runtime check) ===
7478
has_stop_lr = stop_lr is not None
7579
has_stop_lr_ratio = stop_lr_ratio is not None
7680

@@ -85,13 +89,13 @@ def __init__(
8589
"Got stop_lr=None, stop_lr_ratio=None"
8690
)
8791

88-
# === Step 2. Compute stop_lr from stop_lr_ratio if needed ===
92+
# === Step 3. Compute stop_lr from stop_lr_ratio if needed ===
8993
if stop_lr_ratio is not None:
9094
self.stop_lr = start_lr * stop_lr_ratio
9195
else:
9296
self.stop_lr = stop_lr
9397

94-
# === Step 3. Validate warmup_steps and warmup_ratio (runtime check) ===
98+
# === Step 4. Validate warmup_steps and warmup_ratio (runtime check) ===
9599
has_warmup_steps = warmup_steps != 0
96100
has_warmup_ratio = warmup_ratio is not None
97101

@@ -101,13 +105,13 @@ def __init__(
101105
f"Got warmup_steps={warmup_steps}, warmup_ratio={warmup_ratio}"
102106
)
103107

104-
# === Step 4. Compute warmup_steps from warmup_ratio if needed ===
108+
# === Step 5. Compute warmup_steps from warmup_ratio if needed ===
105109
if warmup_ratio is not None:
106110
self.warmup_steps = int(warmup_ratio * num_steps)
107111
else:
108112
self.warmup_steps = warmup_steps
109113

110-
# === Step 5. Validate step ranges (runtime check) ===
114+
# === Step 6. Validate step ranges (runtime check) ===
111115
if num_steps < 0:
112116
raise ValueError("num_steps must be non-negative")
113117
if self.warmup_steps < 0:
@@ -117,10 +121,10 @@ def __init__(
117121
if num_steps == 0 and self.warmup_steps != 0:
118122
raise ValueError("warmup_steps must be 0 when num_steps is 0")
119123

120-
# === Step 6. Compute warmup_start_lr ===
124+
# === Step 7. Compute warmup_start_lr ===
121125
self.warmup_start_lr = warmup_start_factor * start_lr
122126

123-
# === Step 7. Store core parameters ===
127+
# === Step 8. Store core parameters ===
124128
self._start_lr = start_lr
125129
self.num_steps = num_steps
126130
# Decay phase covers (num_steps - warmup_steps) steps
@@ -493,8 +497,6 @@ def __init__(
493497
)
494498

495499
# === Validate WSD-specific invariants ===
496-
if self._start_lr <= 0:
497-
raise ValueError(f"start_lr ({self._start_lr}) must be positive.")
498500
if self.stop_lr <= 0:
499501
raise ValueError(f"stop_lr ({self.stop_lr}) must be positive.")
500502
if decay_phase_ratio <= 0 or decay_phase_ratio > 1:

deepmd/pt/model/descriptor/repflows.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -496,8 +496,8 @@ def forward(
496496
a_sw = a_sw.masked_fill(~a_nlist_mask, 0.0)
497497
# set all padding positions to index of 0
498498
# if the a neighbor is real or not is indicated by nlist_mask
499-
nlist[nlist == -1] = 0
500-
a_nlist[a_nlist == -1] = 0
499+
nlist = torch.where(nlist == -1, 0, nlist)
500+
a_nlist = torch.where(a_nlist == -1, 0, a_nlist)
501501

502502
# get node embedding
503503
# [nframes, nloc, tebd_dim]

deepmd/pt/model/descriptor/repformers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -457,7 +457,7 @@ def forward(
457457

458458
# set all padding positions to index of 0
459459
# if the a neighbor is real or not is indicated by nlist_mask
460-
nlist[nlist == -1] = 0
460+
nlist = torch.where(nlist == -1, 0, nlist)
461461
# nb x nall x ng1
462462
if comm_dict is None:
463463
assert mapping is not None

deepmd/pt/train/ema.py

Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
#!/usr/bin/env python3
2+
# SPDX-License-Identifier: LGPL-3.0-or-later
3+
4+
from __future__ import (
5+
annotations,
6+
)
7+
8+
import logging
9+
from contextlib import (
10+
contextmanager,
11+
)
12+
from copy import (
13+
deepcopy,
14+
)
15+
from pathlib import (
16+
Path,
17+
)
18+
from typing import (
19+
TYPE_CHECKING,
20+
Any,
21+
)
22+
23+
import torch
24+
25+
if TYPE_CHECKING:
26+
from collections.abc import (
27+
Iterator,
28+
)
29+
30+
EMA_CHECKPOINT_KEY = "ema"
31+
EMA_DECAY_KEY = "decay"
32+
EMA_MODEL_STATE_KEY = "model"
33+
EMA_VALIDATION_STATE_KEY = "validation_state"
34+
35+
log = logging.getLogger(__name__)
36+
37+
38+
def _append_suffix(path_like: str | Path, suffix: str) -> Path:
39+
"""Append a suffix before the final file suffix when present."""
40+
path = Path(path_like)
41+
if path.suffix:
42+
return path.with_name(f"{path.stem}{suffix}{path.suffix}")
43+
return path.with_name(f"{path.name}{suffix}")
44+
45+
46+
def get_ema_checkpoint_prefix(save_ckpt: str | Path) -> str:
47+
"""Derive the EMA checkpoint prefix from the regular checkpoint prefix."""
48+
return str(_append_suffix(save_ckpt, "_ema"))
49+
50+
51+
def get_ema_validation_log_path(full_val_file: str | Path) -> Path:
52+
"""Derive the EMA validation log path from the regular validation log path."""
53+
return _append_suffix(full_val_file, "_ema")
54+
55+
56+
class ModelEMA:
57+
"""Maintain an exponential moving average of model parameters.
58+
59+
This helper assumes DDP/ZeRO-1 style training where every rank owns the
60+
same full, consistently ordered model parameters. It is not a sharded
61+
parameter EMA implementation.
62+
"""
63+
64+
def __init__(
65+
self,
66+
model: torch.nn.Module | dict[str, torch.nn.Module],
67+
decay: float,
68+
state: dict[str, Any] | None = None,
69+
) -> None:
70+
self.decay = float(decay)
71+
self.shadow_params = self._clone_model_parameters(model)
72+
self.validation_state: dict[str, Any] = {}
73+
if state is not None:
74+
self.load_state_dict(state)
75+
76+
@staticmethod
77+
def _named_model_parameters(
78+
model: torch.nn.Module | dict[str, torch.nn.Module],
79+
) -> list[tuple[str, torch.nn.Parameter]]:
80+
"""Collect all floating-point model parameters in a deterministic order."""
81+
if isinstance(model, dict):
82+
named_parameters = []
83+
for model_key in sorted(model):
84+
named_parameters.extend(
85+
[
86+
(f"{model_key}.{name}", param)
87+
for name, param in model[model_key].named_parameters()
88+
if torch.is_floating_point(param)
89+
]
90+
)
91+
return named_parameters
92+
return [
93+
(name, param)
94+
for name, param in model.named_parameters()
95+
if torch.is_floating_point(param)
96+
]
97+
98+
def _clone_model_parameters(
99+
self,
100+
model: torch.nn.Module | dict[str, torch.nn.Module],
101+
) -> dict[str, torch.Tensor]:
102+
"""Clone model parameters to initialize the EMA shadow state."""
103+
with torch.no_grad():
104+
return {
105+
name: param.detach().clone()
106+
for name, param in self._named_model_parameters(model)
107+
}
108+
109+
def update(self, model: torch.nn.Module | dict[str, torch.nn.Module]) -> None:
110+
"""Update EMA shadow parameters from the current model parameters."""
111+
with torch.no_grad():
112+
for name, param in self._named_model_parameters(model):
113+
self.shadow_params[name].lerp_(param.detach(), weight=1.0 - self.decay)
114+
115+
def state_dict(self) -> dict[str, Any]:
116+
"""Serialize EMA state for restart."""
117+
return {
118+
EMA_DECAY_KEY: self.decay,
119+
EMA_MODEL_STATE_KEY: {
120+
name: tensor.detach().cpu().clone()
121+
for name, tensor in self.shadow_params.items()
122+
},
123+
EMA_VALIDATION_STATE_KEY: deepcopy(self.validation_state),
124+
}
125+
126+
def load_state_dict(self, state: dict[str, Any]) -> None:
127+
"""Restore EMA shadow parameters and validator state."""
128+
if EMA_DECAY_KEY in state:
129+
checkpoint_decay = float(state[EMA_DECAY_KEY])
130+
if checkpoint_decay != self.decay:
131+
log.warning(
132+
"Ignoring EMA checkpoint decay=%s because training.ema_decay=%s "
133+
"is configured.",
134+
checkpoint_decay,
135+
self.decay,
136+
)
137+
model_state = state.get(EMA_MODEL_STATE_KEY, {})
138+
if not isinstance(model_state, dict):
139+
raise TypeError("EMA checkpoint field `model` must be a dict.")
140+
141+
current_keys = set(self.shadow_params)
142+
loaded_keys = set(model_state)
143+
missing_keys = sorted(current_keys - loaded_keys)
144+
unexpected_keys = sorted(loaded_keys - current_keys)
145+
if missing_keys or unexpected_keys:
146+
raise KeyError(
147+
"EMA checkpoint parameter keys do not match the current model. "
148+
f"Missing keys: {missing_keys[:5]}, unexpected keys: {unexpected_keys[:5]}."
149+
)
150+
151+
with torch.no_grad():
152+
for name, shadow_param in self.shadow_params.items():
153+
loaded_param = model_state[name]
154+
if not isinstance(loaded_param, torch.Tensor):
155+
raise TypeError(
156+
f"EMA checkpoint tensor for {name!r} must be a torch.Tensor."
157+
)
158+
if loaded_param.shape != shadow_param.shape:
159+
raise ValueError(
160+
"EMA checkpoint parameter shape does not match the current "
161+
f"model for {name!r}: expected {tuple(shadow_param.shape)}, "
162+
f"got {tuple(loaded_param.shape)}."
163+
)
164+
shadow_param.copy_(
165+
loaded_param.to(
166+
device=shadow_param.device,
167+
dtype=shadow_param.dtype,
168+
)
169+
)
170+
171+
validation_state = state.get(EMA_VALIDATION_STATE_KEY, {})
172+
if validation_state is None:
173+
validation_state = {}
174+
if not isinstance(validation_state, dict):
175+
raise TypeError("EMA checkpoint field `validation_state` must be a dict.")
176+
self.validation_state = deepcopy(validation_state)
177+
178+
@contextmanager
179+
def apply_shadow(
180+
self,
181+
model: torch.nn.Module | dict[str, torch.nn.Module],
182+
) -> Iterator[None]:
183+
"""Temporarily replace model parameters with the EMA shadow state."""
184+
backups: dict[str, torch.Tensor] = {}
185+
try:
186+
with torch.no_grad():
187+
for name, param in self._named_model_parameters(model):
188+
backups[name] = param.detach().clone()
189+
param.copy_(
190+
self.shadow_params[name].to(
191+
device=param.device,
192+
dtype=param.dtype,
193+
)
194+
)
195+
yield
196+
finally:
197+
with torch.no_grad():
198+
for name, param in self._named_model_parameters(model):
199+
if name in backups:
200+
param.copy_(backups[name])

0 commit comments

Comments
 (0)