Skip to content

Commit 65dc2d9

Browse files
committed
dtype
1 parent 337b334 commit 65dc2d9

1 file changed

Lines changed: 21 additions & 6 deletions

File tree

deepmd/dpmodel/utils/learning_rate.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616
from deepmd.dpmodel.array_api import (
1717
Array,
1818
)
19+
from deepmd.env import (
20+
GLOBAL_NP_FLOAT_PRECISION,
21+
)
1922
from deepmd.utils.plugin import (
2023
PluginVariant,
2124
make_plugin_registry,
@@ -177,8 +180,12 @@ def value(self, step: int | Array) -> Array | float:
177180
xp = array_api_compat.array_namespace(step)
178181

179182
# === Step 1. Handle no-warmup case directly ===
180-
# Use input dtype to avoid type mismatch with TensorFlow/PyTorch
181-
step_dtype = step.dtype
183+
# Use input dtype for floating point, or default to GLOBAL_NP_FLOAT_PRECISION for integers
184+
step_dtype = (
185+
step.dtype
186+
if np.issubdtype(step.dtype, np.floating)
187+
else GLOBAL_NP_FLOAT_PRECISION
188+
)
182189
if self.warmup_steps == 0:
183190
lr = self._decay_value(xp.astype(step, step_dtype))
184191
else:
@@ -367,8 +374,12 @@ def _decay_value(self, step: int | Array) -> Array:
367374
step = np.asarray(step)
368375
xp = array_api_compat.array_namespace(step)
369376
# === Step 1. Compute exponent based on smooth mode ===
370-
# Use input dtype to avoid type mismatch with TensorFlow/PyTorch
371-
step_dtype = step.dtype
377+
# Use input dtype for floating point, or default to GLOBAL_NP_FLOAT_PRECISION for integers
378+
step_dtype = (
379+
step.dtype
380+
if np.issubdtype(step.dtype, np.floating)
381+
else GLOBAL_NP_FLOAT_PRECISION
382+
)
372383
if self.smooth:
373384
exponent = xp.astype(step, step_dtype) / self.decay_steps
374385
else:
@@ -479,8 +490,12 @@ def _decay_value(self, step: int | Array) -> Array:
479490
step = np.asarray(step)
480491
xp = array_api_compat.array_namespace(step)
481492
min_lr = self._start_lr * self.lr_min_factor
482-
# Use input dtype to avoid type mismatch with TensorFlow/PyTorch
483-
step_dtype = step.dtype
493+
# Use input dtype for floating point, or default to GLOBAL_NP_FLOAT_PRECISION for integers
494+
step_dtype = (
495+
step.dtype
496+
if np.issubdtype(step.dtype, np.floating)
497+
else GLOBAL_NP_FLOAT_PRECISION
498+
)
484499
# Handle decay_num_steps=0 (no training steps) - return start_lr
485500
if self.decay_num_steps == 0:
486501
return xp.full_like(step, self._start_lr, dtype=step_dtype)

0 commit comments

Comments
 (0)