|
16 | 16 | from deepmd.dpmodel.array_api import ( |
17 | 17 | Array, |
18 | 18 | ) |
19 | | -from deepmd.env import ( |
20 | | - GLOBAL_NP_FLOAT_PRECISION, |
| 19 | +from deepmd.dpmodel.common import ( |
| 20 | + get_xp_precision, |
21 | 21 | ) |
22 | 22 | from deepmd.utils.plugin import ( |
23 | 23 | PluginVariant, |
@@ -183,8 +183,8 @@ def value(self, step: int | Array) -> Array | float: |
183 | 183 | # Use input dtype for floating point, or default to GLOBAL_NP_FLOAT_PRECISION for integers |
184 | 184 | step_dtype = ( |
185 | 185 | step.dtype |
186 | | - if np.issubdtype(step.dtype, np.floating) |
187 | | - else GLOBAL_NP_FLOAT_PRECISION |
| 186 | + if xp.isdtype(step.dtype, "real floating") |
| 187 | + else get_xp_precision(xp, "global") |
188 | 188 | ) |
189 | 189 | if self.warmup_steps == 0: |
190 | 190 | lr = self._decay_value(xp.astype(step, step_dtype)) |
@@ -377,8 +377,8 @@ def _decay_value(self, step: int | Array) -> Array: |
377 | 377 | # Use input dtype for floating point, or default to GLOBAL_NP_FLOAT_PRECISION for integers |
378 | 378 | step_dtype = ( |
379 | 379 | step.dtype |
380 | | - if np.issubdtype(step.dtype, np.floating) |
381 | | - else GLOBAL_NP_FLOAT_PRECISION |
| 380 | + if xp.isdtype(step.dtype, "real floating") |
| 381 | + else get_xp_precision(xp, "global") |
382 | 382 | ) |
383 | 383 | if self.smooth: |
384 | 384 | exponent = xp.astype(step, step_dtype) / self.decay_steps |
@@ -493,8 +493,8 @@ def _decay_value(self, step: int | Array) -> Array: |
493 | 493 | # Use input dtype for floating point, or default to GLOBAL_NP_FLOAT_PRECISION for integers |
494 | 494 | step_dtype = ( |
495 | 495 | step.dtype |
496 | | - if np.issubdtype(step.dtype, np.floating) |
497 | | - else GLOBAL_NP_FLOAT_PRECISION |
| 496 | + if xp.isdtype(step.dtype, "real floating") |
| 497 | + else get_xp_precision(xp, "global") |
498 | 498 | ) |
499 | 499 | # Handle decay_num_steps=0 (no training steps) - return start_lr |
500 | 500 | if self.decay_num_steps == 0: |
|
0 commit comments