|
3 | 3 | ABC, |
4 | 4 | abstractmethod, |
5 | 5 | ) |
| 6 | +from types import ModuleType |
6 | 7 | from typing import ( |
7 | 8 | Any, |
| 9 | + overload, |
| 10 | + override, |
8 | 11 | ) |
9 | 12 |
|
| 13 | +import array_api_compat |
| 14 | +import array_api_compat |
10 | 15 | import numpy as np |
11 | 16 |
|
12 | 17 | from deepmd.common import ( |
@@ -45,10 +50,26 @@ def __init__( |
45 | 50 | self.stop_steps = stop_steps |
46 | 51 |
|
47 | 52 | @abstractmethod |
48 | | - def value(self, step: int, xp: Any = np) -> Array: |
| 53 | + def value(self, step: int | Array) -> Array: |
49 | 54 | """Get the learning rate at the given step.""" |
| 55 | + # In optax, step is passed in as a jnp.ndarray under JIT, so it may be an array rather than a Python int. |
50 | 56 | pass |
51 | 57 |
|
| 58 | + @overload |
| 59 | + def array_namespace(self, step: int) -> ModuleType: ... |
| 60 | + @overload |
| 61 | + def array_namespace(self, step: Array) -> Any: ... |
| 62 | + |
| 63 | + def array_namespace(self, step: int | Array) -> Any: |
| 64 | + """Get the array API namespace based on the type of step. |
| 65 | +
|
| 66 | + If step is an array API object, return its array namespace; otherwise (e.g. a Python int), return NumPy. |
| 67 | + """ |
| 68 | + if array_api_compat.is_array_api_obj(step): |
| 69 | + xp = array_api_compat.array_namespace(step) |
| 70 | + return xp |
| 71 | + return np |
| 72 | + |
52 | 73 |
|
53 | 74 | @BaseLR.register("exp") |
54 | 75 | class LearningRateExp(BaseLR): |
@@ -94,8 +115,9 @@ def __init__( |
94 | 115 | self.decay_rate = decay_rate |
95 | 116 | self.min_lr = self.stop_lr |
96 | 117 |
|
97 | | - def value(self, step: int, xp: Any = np) -> Array: |
| 118 | + def value(self, step: int | Array) -> Array: |
98 | 119 | """Get the learning rate at the given step.""" |
| 120 | + xp = self.array_namespace(step) |
99 | 121 | step_lr = self.start_lr * xp.pow( |
100 | 122 | xp.asarray(self.decay_rate), step // self.decay_steps |
101 | 123 | ) |
@@ -132,7 +154,8 @@ def __init__( |
132 | 154 | super().__init__(start_lr, stop_lr, stop_steps, **kwargs) |
133 | 155 | self.lr_min_factor = stop_lr / start_lr |
134 | 156 |
|
135 | | - def value(self, step: int, xp: Any = np) -> Array: |
| 157 | + def value(self, step: int | Array) -> Array: |
| 158 | + xp = self.array_namespace(step) |
136 | 159 | min_lr = self.start_lr * self.lr_min_factor |
137 | 160 | step_lr = self.start_lr * ( |
138 | 161 | self.lr_min_factor |
|
0 commit comments