Skip to content

Commit 947e8a6

Browse files
committed
fix
1 parent 2ee5021 commit 947e8a6

2 files changed

Lines changed: 7 additions & 5 deletions

File tree

deepmd/tf/utils/learning_rate.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
BaseLR,
1414
)
1515
from deepmd.tf.env import (
16+
GLOBAL_TF_FLOAT_PRECISION,
1617
tf,
1718
)
1819

@@ -93,14 +94,15 @@ def build(self, global_step: tf.Tensor, num_steps: int) -> tf.Tensor:
9394
base_lr = self._base_lr
9495

9596
def _lr_value(step: np.ndarray) -> np.ndarray:
96-
# Use float32 for learning rate, consistent with PyTorch/Paddle backends
97+
# Use GLOBAL_TF_FLOAT_PRECISION (float64) for learning rate,
98+
# consistent with energy precision in TF backend
9799
return np.asarray(
98100
base_lr.value(step),
99-
dtype=np.float32,
101+
dtype=GLOBAL_TF_FLOAT_PRECISION.as_numpy_dtype,
100102
)
101103

102104
lr = tf.numpy_function(
103-
_lr_value, [global_step], Tout=tf.float32, name="lr_schedule"
105+
_lr_value, [global_step], Tout=GLOBAL_TF_FLOAT_PRECISION, name="lr_schedule"
104106
)
105107
lr.set_shape(global_step.get_shape())
106108
return lr

source/tests/tf/test_lr.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,13 @@ class TestLearningRateScheduleBuild(unittest.TestCase):
4242
"""Test TF tensor building and integration."""
4343

4444
def test_build_returns_tensor(self) -> None:
45-
"""Test that build() returns a float32 TF tensor (consistent with PT/PD backends)."""
45+
"""Test that build() returns a float64 TF tensor (consistent with GLOBAL_TF_FLOAT_PRECISION)."""
4646
lr_schedule = LearningRateSchedule({"start_lr": 1e-3, "stop_lr": 1e-5})
4747
global_step = tf.constant(0, dtype=tf.int64)
4848
lr_tensor = lr_schedule.build(global_step, num_steps=10000)
4949

5050
self.assertIsInstance(lr_tensor, tf.Tensor)
51-
self.assertEqual(lr_tensor.dtype, tf.float32)
51+
self.assertEqual(lr_tensor.dtype, tf.float64)
5252

5353
def test_default_type_exp(self) -> None:
5454
"""Test that default type is 'exp' when not specified."""

0 commit comments

Comments
 (0)