|
16 | 16 | from deepmd.dpmodel.array_api import ( |
17 | 17 | Array, |
18 | 18 | ) |
| 19 | +from deepmd.env import ( |
| 20 | + GLOBAL_NP_FLOAT_PRECISION, |
| 21 | +) |
19 | 22 | from deepmd.utils.plugin import ( |
20 | 23 | PluginVariant, |
21 | 24 | make_plugin_registry, |
@@ -177,8 +180,12 @@ def value(self, step: int | Array) -> Array | float: |
177 | 180 | xp = array_api_compat.array_namespace(step) |
178 | 181 |
|
179 | 182 | # === Step 1. Handle no-warmup case directly === |
180 | | - # Use input dtype to avoid type mismatch with TensorFlow/PyTorch |
181 | | - step_dtype = step.dtype |
| 183 | + # Use input dtype for floating point, or default to GLOBAL_NP_FLOAT_PRECISION for integers |
| 184 | + step_dtype = ( |
| 185 | + step.dtype |
| 186 | + if np.issubdtype(step.dtype, np.floating) |
| 187 | + else GLOBAL_NP_FLOAT_PRECISION |
| 188 | + ) |
182 | 189 | if self.warmup_steps == 0: |
183 | 190 | lr = self._decay_value(xp.astype(step, step_dtype)) |
184 | 191 | else: |
@@ -367,8 +374,12 @@ def _decay_value(self, step: int | Array) -> Array: |
367 | 374 | step = np.asarray(step) |
368 | 375 | xp = array_api_compat.array_namespace(step) |
369 | 376 | # === Step 1. Compute exponent based on smooth mode === |
370 | | - # Use input dtype to avoid type mismatch with TensorFlow/PyTorch |
371 | | - step_dtype = step.dtype |
| 377 | + # Use input dtype for floating point, or default to GLOBAL_NP_FLOAT_PRECISION for integers |
| 378 | + step_dtype = ( |
| 379 | + step.dtype |
| 380 | + if np.issubdtype(step.dtype, np.floating) |
| 381 | + else GLOBAL_NP_FLOAT_PRECISION |
| 382 | + ) |
372 | 383 | if self.smooth: |
373 | 384 | exponent = xp.astype(step, step_dtype) / self.decay_steps |
374 | 385 | else: |
@@ -479,8 +490,12 @@ def _decay_value(self, step: int | Array) -> Array: |
479 | 490 | step = np.asarray(step) |
480 | 491 | xp = array_api_compat.array_namespace(step) |
481 | 492 | min_lr = self._start_lr * self.lr_min_factor |
482 | | - # Use input dtype to avoid type mismatch with TensorFlow/PyTorch |
483 | | - step_dtype = step.dtype |
| 493 | + # Use input dtype for floating point, or default to GLOBAL_NP_FLOAT_PRECISION for integers |
| 494 | + step_dtype = ( |
| 495 | + step.dtype |
| 496 | + if np.issubdtype(step.dtype, np.floating) |
| 497 | + else GLOBAL_NP_FLOAT_PRECISION |
| 498 | + ) |
484 | 499 | # Handle decay_num_steps=0 (no training steps) - return start_lr |
485 | 500 | if self.decay_num_steps == 0: |
486 | 501 | return xp.full_like(step, self._start_lr, dtype=step_dtype) |
|
0 commit comments