Skip to content

Commit 854dca8

Browse files
feat(dpmodel): support Array API learning rate (#5143)
This is useful when the LR is within the JAX JIT compilation - the step is given as a `jnp.ndarray` and NumPy should not be used. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Learning-rate API now accepts array-like steps and computes schedules using backend-agnostic array operations; exponential and cosine schedulers support array inputs and preserve minimum learning-rate behavior. * **Tests** * Added cross-backend consistency tests to validate identical learning-rate outputs across NumPy, PyTorch, JAX and array-api backends. * **Chores** * Updated test dependency constraint and bumped default test API version used by strict array-api tests. <sub>✏️ Tip: You can customize this high-level summary in your review settings.</sub> <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 82a5f32 commit 854dca8

4 files changed

Lines changed: 122 additions & 13 deletions

File tree

deepmd/dpmodel/utils/learning_rate.py

Lines changed: 35 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,15 @@
77
Any,
88
)
99

10+
import array_api_compat
1011
import numpy as np
1112

1213
from deepmd.common import (
1314
j_get_type,
1415
)
16+
from deepmd.dpmodel.array_api import (
17+
Array,
18+
)
1519
from deepmd.utils.plugin import (
1620
PluginVariant,
1721
make_plugin_registry,
@@ -44,8 +48,9 @@ def __init__(
4448
self.stop_steps = stop_steps
4549

4650
@abstractmethod
def value(self, step: int | Array) -> Array:
    """Get the learning rate at the given step.

    Parameters
    ----------
    step : int | Array
        The current training step. May be a plain Python int or a
        backend array; in optax, step will be a jnp.ndarray passed
        in JIT mode, so implementations must use backend-agnostic
        (Array API) operations only.

    Returns
    -------
    Array
        The learning rate at the given step.
    """
    # in optax, step will be a jnp.ndarray passed in JIT mode
    pass
5055

5156

@@ -88,16 +93,23 @@ def __init__(
8893
self.decay_steps = default_ds
8994
self.decay_rate = np.exp(
9095
np.log(stop_lr / self.start_lr) / (stop_steps / self.decay_steps)
91-
)
96+
).item()
9297
if decay_rate is not None:
9398
self.decay_rate = decay_rate
9499
self.min_lr = self.stop_lr
95100

def value(self, step: int | Array) -> Array:
    """Get the exponentially decayed learning rate at the given step."""
    # Normalize plain Python ints (and other scalars) to a NumPy array
    # so a single Array-API code path handles every backend.
    if not array_api_compat.is_array_api_obj(step):
        step = np.asarray(step)
    xp = array_api_compat.array_namespace(step)
    # Number of completed decay intervals, as float64 for xp.pow.
    n_decays = xp.astype(step // self.decay_steps, xp.float64)
    rate = xp.asarray(self.decay_rate, device=array_api_compat.device(step))
    decayed = self.start_lr * xp.pow(rate, n_decays)
    # The original implementation `if step_lr < self.min_lr:` would
    # create a dynamic graph, which is unsupported in JAX JIT; clip
    # achieves the same floor without data-dependent control flow.
    return xp.clip(decayed, self.min_lr, None)
102114

103115

@@ -128,12 +140,24 @@ def __init__(
128140
super().__init__(start_lr, stop_lr, stop_steps, **kwargs)
129141
self.lr_min_factor = stop_lr / start_lr
130142

def value(self, step: int | Array) -> Array:
    """Get the cosine-annealed learning rate at the given step.

    Parameters
    ----------
    step : int | Array
        The current training step; may be a plain int or a backend
        array (e.g. a jnp.ndarray inside JAX JIT).

    Returns
    -------
    Array
        The learning rate, annealed from start_lr down to
        start_lr * lr_min_factor over stop_steps steps.
    """
    # Normalize scalars to a NumPy array so one Array-API path serves
    # all backends.
    if not array_api_compat.is_array_api_obj(step):
        step = np.asarray(step)
    xp = array_api_compat.array_namespace(step)
    min_lr = self.start_lr * self.lr_min_factor
    # Standard cosine annealing: interpolate between 1 and
    # lr_min_factor with a half cosine over [0, stop_steps].
    step_lr = self.start_lr * (
        self.lr_min_factor
        + 0.5
        * (1 - self.lr_min_factor)
        * (
            1
            + xp.cos(
                xp.asarray(
                    xp.pi * (xp.astype(step, xp.float64) / self.stop_steps),
                    device=array_api_compat.device(step),
                )
            )
        )
    )
    # `xp.where` instead of `if step >= stop_steps:` keeps the graph
    # static, which JAX JIT requires.
    step_lr = xp.where(step >= self.stop_steps, min_lr, step_lr)
    return step_lr

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,8 @@ test = [
8686
"pytest-sugar",
8787
"pytest-split",
8888
"dpgui",
89-
'array-api-strict>=2,!=2.1.1;python_version>="3.9"',
89+
# to support Array API 2024.12
90+
'array-api-strict>=2.2;python_version>="3.9"',
9091
]
9192
docs = [
9293
"sphinx>=3.1.1",

source/tests/array_api_strict/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,4 @@
55

66
# this is the default version in the latest array_api_strict,
77
# but in old versions it may be 2022.12
8-
array_api_strict.set_array_api_strict_flags(api_version="2023.12")
8+
array_api_strict.set_array_api_strict_flags(api_version="2024.12")
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
import sys
3+
import unittest
4+
5+
import numpy as np
6+
7+
from deepmd.dpmodel.array_api import (
8+
Array,
9+
)
10+
from deepmd.dpmodel.common import (
11+
to_numpy_array,
12+
)
13+
from deepmd.dpmodel.utils.learning_rate import (
14+
BaseLR,
15+
)
16+
17+
from .common import (
18+
INSTALLED_ARRAY_API_STRICT,
19+
INSTALLED_JAX,
20+
INSTALLED_PT,
21+
parameterized,
22+
)
23+
24+
if INSTALLED_PT:
25+
from deepmd.pt.utils.utils import (
26+
to_torch_tensor,
27+
)
28+
29+
if INSTALLED_JAX:
30+
from deepmd.jax.env import (
31+
jnp,
32+
)
33+
if INSTALLED_ARRAY_API_STRICT:
34+
import array_api_strict as xp
35+
36+
37+
@parameterized(
    (
        {
            "type": "exp",
            "start_lr": 1e-3,
            "stop_lr": 1e-8,
            "decay_steps": 1000,
            "stop_steps": 1000000,
        },
        {
            "type": "cosine",
            "start_lr": 1e-3,
            "stop_lr": 1e-8,
            "decay_steps": 1000,
            "stop_steps": 1000000,
        },
    ),
)
class TestLearningRateConsistent(unittest.TestCase):
    """Check that LR schedules return identical values across array backends."""

    def setUp(self) -> None:
        (lr_param,) = self.param
        self.lr = BaseLR(**lr_param)
        self.step = 500000
        # Reference value computed from a plain Python int step.
        self.ref = self.lr.value(self.step)

    def compare_test_with_ref(self, step: Array) -> None:
        """Evaluate the LR at *step* and compare against the int-step reference."""
        test = self.lr.value(step)
        np.testing.assert_allclose(self.ref, to_numpy_array(test), atol=1e-10)

    def compare_numpy_with_ref(self, step: int) -> None:
        """Evaluate the LR with a NumPy-array step and compare with the reference."""
        self.compare_test_with_ref(np.asarray(step))

    def test_np_consistent_with_ref(self) -> None:
        # Fix: compare_numpy_with_ref was defined but never called, so the
        # NumPy backend path was not actually exercised by any test.
        self.compare_numpy_with_ref(self.step)

    @unittest.skipUnless(INSTALLED_PT, "PyTorch is not installed")
    def test_pt_consistent_with_ref(self) -> None:
        self.compare_test_with_ref(to_torch_tensor(self.step))

    @unittest.skipUnless(
        INSTALLED_ARRAY_API_STRICT, "array_api_strict is not installed"
    )
    @unittest.skipUnless(
        sys.version_info >= (3, 9), "array_api_strict doesn't support Python<=3.8"
    )
    def test_array_api_strict(self) -> None:
        self.compare_test_with_ref(xp.asarray(self.step))

    @unittest.skipUnless(INSTALLED_JAX, "JAX is not installed")
    def test_jax_consistent_with_ref(self) -> None:
        self.compare_test_with_ref(jnp.array(self.step))

0 commit comments

Comments
 (0)