Skip to content

Commit 850f63d

Browse files
committed
fix: use get_xp_precision
1 parent 5b58041 commit 850f63d

1 file changed

Lines changed: 8 additions & 8 deletions

File tree

deepmd/dpmodel/utils/learning_rate.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
from deepmd.dpmodel.array_api import (
1717
Array,
1818
)
19-
from deepmd.env import (
20-
GLOBAL_NP_FLOAT_PRECISION,
19+
from deepmd.dpmodel.common import (
20+
get_xp_precision,
2121
)
2222
from deepmd.utils.plugin import (
2323
PluginVariant,
@@ -183,8 +183,8 @@ def value(self, step: int | Array) -> Array | float:
183183
# Use input dtype for floating point, or default to GLOBAL_NP_FLOAT_PRECISION for integers
184184
step_dtype = (
185185
step.dtype
186-
if np.issubdtype(step.dtype, np.floating)
187-
else GLOBAL_NP_FLOAT_PRECISION
186+
if xp.isdtype(step.dtype, "real floating")
187+
else get_xp_precision(xp, "global")
188188
)
189189
if self.warmup_steps == 0:
190190
lr = self._decay_value(xp.astype(step, step_dtype))
@@ -377,8 +377,8 @@ def _decay_value(self, step: int | Array) -> Array:
377377
# Use input dtype for floating point, or default to GLOBAL_NP_FLOAT_PRECISION for integers
378378
step_dtype = (
379379
step.dtype
380-
if np.issubdtype(step.dtype, np.floating)
381-
else GLOBAL_NP_FLOAT_PRECISION
380+
if xp.isdtype(step.dtype, "real floating")
381+
else get_xp_precision(xp, "global")
382382
)
383383
if self.smooth:
384384
exponent = xp.astype(step, step_dtype) / self.decay_steps
@@ -493,8 +493,8 @@ def _decay_value(self, step: int | Array) -> Array:
493493
# Use input dtype for floating point, or default to GLOBAL_NP_FLOAT_PRECISION for integers
494494
step_dtype = (
495495
step.dtype
496-
if np.issubdtype(step.dtype, np.floating)
497-
else GLOBAL_NP_FLOAT_PRECISION
496+
if xp.isdtype(step.dtype, "real floating")
497+
else get_xp_precision(xp, "global")
498498
)
499499
# Handle decay_num_steps=0 (no training steps) - return start_lr
500500
if self.decay_num_steps == 0:

0 commit comments

Comments
 (0)