Skip to content

Commit 8ab8e99

Browse files
committed
improve how to get xp
1 parent a2507f7 commit 8ab8e99

1 file changed

Lines changed: 26 additions & 3 deletions

File tree

deepmd/dpmodel/utils/learning_rate.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,15 @@
33
ABC,
44
abstractmethod,
55
)
6+
from types import ModuleType
67
from typing import (
78
Any,
9+
overload,
10+
override,
811
)
912

13+
import array_api_compat
14+
import array_api_compat
1015
import numpy as np
1116

1217
from deepmd.common import (
@@ -45,10 +50,26 @@ def __init__(
4550
self.stop_steps = stop_steps
4651

4752
@abstractmethod
48-
def value(self, step: int, xp: Any = np) -> Array:
53+
def value(self, step: int | Array) -> Array:
4954
"""Get the learning rate at the given step."""
55+
# in optax, step will be a jnp.ndarray passed in JIT mode
5056
pass
5157

58+
@overload
59+
def array_namespace(self, step: int) -> ModuleType: ...
60+
@overload
61+
def array_namespace(self, step: Array) -> Any: ...
62+
63+
def array_namespace(self, step: int | Array) -> Any:
64+
"""Get the array API namespace based on the type of step.
65+
66+
If the step is int, use NumPy.
67+
"""
68+
if array_api_compat.is_array_api_obj(step):
69+
xp = array_api_compat.array_namespace(step)
70+
return xp
71+
return np
72+
5273

5374
@BaseLR.register("exp")
5475
class LearningRateExp(BaseLR):
@@ -94,8 +115,9 @@ def __init__(
94115
self.decay_rate = decay_rate
95116
self.min_lr = self.stop_lr
96117

97-
def value(self, step: int, xp: Any = np) -> Array:
118+
def value(self, step: int | Array) -> Array:
98119
"""Get the learning rate at the given step."""
120+
xp = self.array_namespace(step)
99121
step_lr = self.start_lr * xp.pow(
100122
xp.asarray(self.decay_rate), step // self.decay_steps
101123
)
@@ -132,7 +154,8 @@ def __init__(
132154
super().__init__(start_lr, stop_lr, stop_steps, **kwargs)
133155
self.lr_min_factor = stop_lr / start_lr
134156

135-
def value(self, step: int, xp: Any = np) -> Array:
157+
def value(self, step: int | Array) -> Array:
158+
xp = self.array_namespace(step)
136159
min_lr = self.start_lr * self.lr_min_factor
137160
step_lr = self.start_lr * (
138161
self.lr_min_factor

0 commit comments

Comments
 (0)