Skip to content

Commit 689d357

Browse files
committed
fix
1 parent f271bf8 commit 689d357

File tree

2 files changed

+15
-22
lines changed

2 files changed

+15
-22
lines changed

deepmd/dpmodel/utils/learning_rate.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -511,18 +511,10 @@ def __init__(
511511
# === Derive stable and decay phase lengths ===
512512
self.decay_phase_ratio = decay_phase_ratio
513513
self.decay_type = decay_type
514-
self.decay_phase_steps = int(self.decay_phase_ratio * self.num_steps)
515-
if self.decay_phase_steps <= 0:
516-
raise ValueError(
517-
"decay_phase_ratio results in zero decay steps. "
518-
"Increase num_steps or decay_phase_ratio."
519-
)
520-
if self.decay_phase_steps > self.decay_num_steps:
521-
raise ValueError(
522-
"decay phase steps must not exceed the post-warmup steps. "
523-
f"Got decay_phase_steps={self.decay_phase_steps}, "
524-
f"post_warmup_steps={self.decay_num_steps}."
525-
)
514+
# Clamp decay_phase_steps to valid range [1, decay_num_steps]
515+
self.decay_phase_steps = max(
516+
1, min(int(self.decay_phase_ratio * self.num_steps), self.decay_num_steps)
517+
)
526518
self.stable_steps = self.decay_num_steps - self.decay_phase_steps
527519

528520
def _decay_value(self, step: int | Array) -> Array:
@@ -556,7 +548,6 @@ def _decay_value(self, step: int | Array) -> Array:
556548
stop_lr = xp.asarray(self.stop_lr, dtype=step_dtype)
557549
stable_steps = xp.asarray(self.stable_steps, dtype=step_dtype)
558550
decay_phase_steps = xp.asarray(self.decay_phase_steps, dtype=step_dtype)
559-
decay_num_steps = xp.asarray(self.decay_num_steps, dtype=step_dtype)
560551

561552
# === Step 2. Keep a constant learning rate in the stable phase ===
562553
decay_progress = (typed_step - stable_steps) / decay_phase_steps

source/tests/universal/dpmodel/utils/test_learning_rate.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -164,15 +164,17 @@ def test_invalid_decay_phase_ratio(self) -> None:
164164
)
165165

166166
def test_decay_phase_exceeds_post_warmup_steps(self) -> None:
167-
"""Test WSD rejects decay phases longer than post-warmup steps."""
168-
with self.assertRaises(ValueError):
169-
LearningRateWSD(
170-
start_lr=1e-3,
171-
stop_lr=1e-5,
172-
num_steps=10,
173-
warmup_steps=9,
174-
decay_phase_ratio=0.2,
175-
)
167+
"""Test WSD clamps decay_phase_steps to post-warmup steps when ratio is too large."""
168+
lr = LearningRateWSD(
169+
start_lr=1e-3,
170+
stop_lr=1e-5,
171+
num_steps=10,
172+
warmup_steps=9,
173+
decay_phase_ratio=0.2,
174+
)
175+
# decay_num_steps = 1, so decay_phase_steps should be clamped to 1
176+
self.assertEqual(lr.decay_phase_steps, 1)
177+
self.assertEqual(lr.stable_steps, 0)
176178

177179

178180
class TestLearningRateWarmup(unittest.TestCase):

0 commit comments

Comments
 (0)