|
7 | 7 | Any, |
8 | 8 | ) |
9 | 9 |
|
| 10 | +import array_api_compat |
10 | 11 | import numpy as np |
11 | 12 |
|
12 | 13 | from deepmd.common import ( |
13 | 14 | j_get_type, |
14 | 15 | ) |
| 16 | +from deepmd.dpmodel.array_api import ( |
| 17 | + Array, |
| 18 | +) |
15 | 19 | from deepmd.utils.plugin import ( |
16 | 20 | PluginVariant, |
17 | 21 | make_plugin_registry, |
@@ -44,8 +48,9 @@ def __init__( |
44 | 48 | self.stop_steps = stop_steps |
45 | 49 |
|
46 | 50 | @abstractmethod |
47 | | - def value(self, step: int) -> np.float64: |
| 51 | + def value(self, step: int | Array) -> Array: |
48 | 52 | """Get the learning rate at the given step.""" |
|    | 53 | +        # in optax, `step` arrives as a jnp.ndarray when called under JIT |
49 | 54 | pass |
50 | 55 |
|
51 | 56 |
|
@@ -88,16 +93,23 @@ def __init__( |
88 | 93 | self.decay_steps = default_ds |
89 | 94 | self.decay_rate = np.exp( |
90 | 95 | np.log(stop_lr / self.start_lr) / (stop_steps / self.decay_steps) |
91 | | - ) |
| 96 | + ).item() |
92 | 97 | if decay_rate is not None: |
93 | 98 | self.decay_rate = decay_rate |
94 | 99 | self.min_lr = self.stop_lr |
95 | 100 |
|
96 | | - def value(self, step: int) -> np.float64: |
| 101 | + def value(self, step: int | Array) -> Array: |
97 | 102 | """Get the learning rate at the given step.""" |
98 | | - step_lr = self.start_lr * np.power(self.decay_rate, step // self.decay_steps) |
99 | | - if step_lr < self.min_lr: |
100 | | - step_lr = self.min_lr |
| 103 | + if not array_api_compat.is_array_api_obj(step): |
| 104 | + step = np.asarray(step) |
| 105 | + xp = array_api_compat.array_namespace(step) |
| 106 | + step_lr = self.start_lr * xp.pow( |
| 107 | + xp.asarray(self.decay_rate, device=array_api_compat.device(step)), |
| 108 | + xp.astype(step // self.decay_steps, xp.float64), |
| 109 | + ) |
|    | 110 | +        # the original implementation `if step_lr < self.min_lr:` |
|    | 111 | +        # would introduce data-dependent control flow, which jax.jit cannot trace |
| 112 | + step_lr = xp.clip(step_lr, self.min_lr, None) |
101 | 113 | return step_lr |
102 | 114 |
|
103 | 115 |
|
@@ -128,12 +140,24 @@ def __init__( |
128 | 140 | super().__init__(start_lr, stop_lr, stop_steps, **kwargs) |
129 | 141 | self.lr_min_factor = stop_lr / start_lr |
130 | 142 |
|
131 | | - def value(self, step: int) -> np.float64: |
132 | | - if step >= self.stop_steps: |
133 | | - return self.start_lr * self.lr_min_factor |
134 | | - return self.start_lr * ( |
| 143 | + def value(self, step: int | Array) -> Array: |
| 144 | + if not array_api_compat.is_array_api_obj(step): |
| 145 | + step = np.asarray(step) |
| 146 | + xp = array_api_compat.array_namespace(step) |
| 147 | + min_lr = self.start_lr * self.lr_min_factor |
| 148 | + step_lr = self.start_lr * ( |
135 | 149 | self.lr_min_factor |
136 | 150 | + 0.5 |
137 | 151 | * (1 - self.lr_min_factor) |
138 | | - * (1 + np.cos(np.pi * (step / self.stop_steps))) |
| 152 | + * ( |
| 153 | + 1 |
| 154 | + + xp.cos( |
| 155 | + xp.asarray( |
| 156 | + xp.pi * (xp.astype(step, xp.float64) / self.stop_steps), |
| 157 | + device=array_api_compat.device(step), |
| 158 | + ) |
| 159 | + ) |
| 160 | + ) |
139 | 161 | ) |
| 162 | + step_lr = xp.where(step >= self.stop_steps, min_lr, step_lr) |
| 163 | + return step_lr |
0 commit comments