Skip to content

Commit 78bffde

Browse files
fix(pt): base LambdaLR on configured start_lr (deepmodeling#5434)
- Fix PyTorch LR scheduler construction to use the configured `start_lr` as `LambdaLR`'s base LR.
- Reset optimizer param group `initial_lr` after loading restart checkpoints so stale checkpoint values do not scale resumed training.
- Avoid double-counting `start_step` by letting `last_epoch` handle scheduler resume position.

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

* **Refactor**
  * Improved internal learning-rate scheduler construction for more robust and maintainable training behavior.
* **Bug Fixes**
  * Added runtime validation to reject invalid starting learning-rate values, preventing training misconfiguration.
* **Tests**
  * Added a resume/restart learning-rate test and a unit test ensuring invalid start values are rejected; fixed a corrupted test.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 8e4ea33 commit 78bffde

4 files changed

Lines changed: 103 additions & 14 deletions

File tree

deepmd/dpmodel/utils/learning_rate.py

Lines changed: 11 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -70,7 +70,11 @@ def __init__(
7070
The warmup learning rate starts from warmup_start_factor * start_lr.
7171
Default is 0.0.
7272
"""
73-
# === Step 1. Validate stop_lr and stop_lr_ratio (runtime check) ===
73+
# === Step 1. Validate start_lr (runtime check) ===
74+
if start_lr <= 0 or not np.isfinite(start_lr):
75+
raise ValueError(f"start_lr ({start_lr}) must be positive and finite.")
76+
77+
# === Step 2. Validate stop_lr and stop_lr_ratio (runtime check) ===
7478
has_stop_lr = stop_lr is not None
7579
has_stop_lr_ratio = stop_lr_ratio is not None
7680

@@ -85,13 +89,13 @@ def __init__(
8589
"Got stop_lr=None, stop_lr_ratio=None"
8690
)
8791

88-
# === Step 2. Compute stop_lr from stop_lr_ratio if needed ===
92+
# === Step 3. Compute stop_lr from stop_lr_ratio if needed ===
8993
if stop_lr_ratio is not None:
9094
self.stop_lr = start_lr * stop_lr_ratio
9195
else:
9296
self.stop_lr = stop_lr
9397

94-
# === Step 3. Validate warmup_steps and warmup_ratio (runtime check) ===
98+
# === Step 4. Validate warmup_steps and warmup_ratio (runtime check) ===
9599
has_warmup_steps = warmup_steps != 0
96100
has_warmup_ratio = warmup_ratio is not None
97101

@@ -101,13 +105,13 @@ def __init__(
101105
f"Got warmup_steps={warmup_steps}, warmup_ratio={warmup_ratio}"
102106
)
103107

104-
# === Step 4. Compute warmup_steps from warmup_ratio if needed ===
108+
# === Step 5. Compute warmup_steps from warmup_ratio if needed ===
105109
if warmup_ratio is not None:
106110
self.warmup_steps = int(warmup_ratio * num_steps)
107111
else:
108112
self.warmup_steps = warmup_steps
109113

110-
# === Step 5. Validate step ranges (runtime check) ===
114+
# === Step 6. Validate step ranges (runtime check) ===
111115
if num_steps < 0:
112116
raise ValueError("num_steps must be non-negative")
113117
if self.warmup_steps < 0:
@@ -117,10 +121,10 @@ def __init__(
117121
if num_steps == 0 and self.warmup_steps != 0:
118122
raise ValueError("warmup_steps must be 0 when num_steps is 0")
119123

120-
# === Step 6. Compute warmup_start_lr ===
124+
# === Step 7. Compute warmup_start_lr ===
121125
self.warmup_start_lr = warmup_start_factor * start_lr
122126

123-
# === Step 7. Store core parameters ===
127+
# === Step 8. Store core parameters ===
124128
self._start_lr = start_lr
125129
self.num_steps = num_steps
126130
# Decay phase covers (num_steps - warmup_steps) steps
@@ -493,8 +497,6 @@ def __init__(
493497
)
494498

495499
# === Validate WSD-specific invariants ===
496-
if self._start_lr <= 0:
497-
raise ValueError(f"start_lr ({self._start_lr}) must be positive.")
498500
if self.stop_lr <= 0:
499501
raise ValueError(f"stop_lr ({self.stop_lr}) must be positive.")
500502
if decay_phase_ratio <= 0 or decay_phase_ratio > 1:

deepmd/pt/train/training.py

Lines changed: 18 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -939,12 +939,10 @@ def single_model_finetune(
939939
**extra,
940940
)
941941
self._load_optimizer_state(optimizer_state_dict)
942-
self.scheduler = torch.optim.lr_scheduler.LambdaLR(
942+
self.scheduler = self._create_lr_scheduler(
943943
self.optimizer,
944-
lambda step: (
945-
self.lr_schedule.value(step + self.start_step) / initial_lr
946-
),
947-
last_epoch=self.start_step - 1,
944+
self.lr_schedule,
945+
self.start_step,
948946
)
949947

950948
if self.zero_stage > 0 and self.rank == 0:
@@ -975,6 +973,21 @@ def single_model_finetune(
975973
if self.rank == 0:
976974
self._log_parameter_count()
977975

976+
@staticmethod
977+
def _create_lr_scheduler(
978+
optimizer: torch.optim.Optimizer,
979+
lr_schedule: BaseLR,
980+
start_step: int,
981+
) -> torch.optim.lr_scheduler.LambdaLR:
982+
base_lr = float(lr_schedule.start_lr)
983+
for group in optimizer.param_groups:
984+
group["initial_lr"] = base_lr
985+
return torch.optim.lr_scheduler.LambdaLR(
986+
optimizer,
987+
lambda step: lr_schedule.value(step) / base_lr,
988+
last_epoch=start_step - 1,
989+
)
990+
978991
def _create_full_validator(
979992
self,
980993
*,

source/tests/pt/test_training.py

Lines changed: 62 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -771,6 +771,68 @@ def test_fitting_stat_consistency(self) -> None:
771771
)
772772

773773

774+
class TestLearningRateRestart(unittest.TestCase):
775+
def setUp(self) -> None:
776+
self._cwd = os.getcwd()
777+
self._tmpdir = tempfile.TemporaryDirectory()
778+
os.chdir(self._tmpdir.name)
779+
input_json = str(Path(__file__).parent / "water/se_atten.json")
780+
with open(input_json) as f:
781+
self.config = json.load(f)
782+
self.config = convert_optimizer_v31_to_v32(self.config, warning=False)
783+
data_file = [str(Path(__file__).parent / "water/data/data_0")]
784+
self.config["training"]["training_data"]["systems"] = data_file
785+
self.config["training"]["validation_data"]["systems"] = data_file
786+
self.config["model"] = deepcopy(model_se_e2_a)
787+
self.config["learning_rate"] = {
788+
"type": "wsd",
789+
"start_lr": 5e-4,
790+
"stop_lr": 1e-6,
791+
"warmup_steps": 2,
792+
"warmup_start_factor": 0.2,
793+
"decay_phase_ratio": 0.5,
794+
"decay_type": "cosine",
795+
}
796+
self.config["training"]["numb_steps"] = 3
797+
self.config["training"]["save_freq"] = 3
798+
self.config["training"]["disp_freq"] = 1
799+
self.config["training"]["disp_training"] = False
800+
self.config["training"]["time_training"] = False
801+
802+
def tearDown(self) -> None:
803+
os.chdir(self._cwd)
804+
self._tmpdir.cleanup()
805+
806+
def test_restart_scheduler_matches_lr_schedule(self) -> None:
807+
trainer = get_trainer(deepcopy(self.config))
808+
trainer.run()
809+
restart_model = Path("model-3.pt")
810+
checkpoint = torch.load(restart_model, map_location="cpu", weights_only=True)
811+
stale_initial_lr = trainer.lr_schedule.value(0)
812+
for group in checkpoint["optimizer"]["param_groups"]:
813+
group["initial_lr"] = stale_initial_lr
814+
torch.save(checkpoint, restart_model)
815+
816+
restart_config = deepcopy(self.config)
817+
restart_config["training"]["numb_steps"] = 5
818+
restart_trainer = get_trainer(
819+
restart_config,
820+
restart_model=str(restart_model),
821+
)
822+
823+
np.testing.assert_allclose(
824+
restart_trainer.scheduler.get_last_lr()[0],
825+
restart_trainer.lr_schedule.value(restart_trainer.start_step),
826+
rtol=1e-12,
827+
)
828+
restart_trainer.run()
829+
np.testing.assert_allclose(
830+
restart_trainer.scheduler.get_last_lr()[0],
831+
restart_trainer.lr_schedule.value(restart_config["training"]["numb_steps"]),
832+
rtol=1e-12,
833+
)
834+
835+
774836
class TestFullValidation(unittest.TestCase):
775837
def setUp(self) -> None:
776838
self._cwd = os.getcwd()

source/tests/universal/dpmodel/utils/test_learning_rate.py

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -50,6 +50,18 @@ def test_decay_rate_override(self) -> None:
5050
self.assertEqual(lr.decay_rate, 0.9)
5151
np.testing.assert_allclose(lr.value(1000), 1e-3 * 0.9, rtol=1e-10)
5252

53+
def test_rejects_nonpositive_or_nonfinite_start_lr(self) -> None:
54+
"""Test invalid start_lr values are rejected by the base schedule."""
55+
for start_lr in (0.0, -1e-3, np.inf, np.nan):
56+
with self.subTest(start_lr=start_lr):
57+
with self.assertRaisesRegex(ValueError, "start_lr"):
58+
LearningRateExp(
59+
start_lr=start_lr,
60+
stop_lr=1e-5,
61+
num_steps=10000,
62+
decay_steps=5000,
63+
)
64+
5365

5466
class TestLearningRateCosineBasic(unittest.TestCase):
5567
"""Test basic cosine annealing learning rate functionality."""

0 commit comments

Comments
 (0)