Skip to content

Commit 40d3065

Browse files
committed
support optimizer for LambdaDecay
1 parent 02ef364 commit 40d3065

2 files changed

Lines changed: 37 additions & 0 deletions

File tree

python/paddle/optimizer/lr.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1509,6 +1509,25 @@ class LambdaDecay(LRScheduler):
15091509

15101510
lr_lambda: Callable[[int], float]
15111511

1512+
@overload
1513+
def __init__(
1514+
self,
1515+
learning_rate: float,
1516+
lr_lambda: Callable[[int], float],
1517+
last_epoch: int = -1,
1518+
verbose: bool = False,
1519+
): ...
1520+
1521+
@overload
1522+
def __init__(
1523+
self,
1524+
optimizer: paddle.optimizer.Optimizer,
1525+
lr_lambda: Callable[[int], float],
1526+
last_epoch: int = -1,
1527+
verbose: bool = False,
1528+
): ...
1529+
1530+
@lr_scheduler_decorator()
15121531
def __init__(
15131532
self,
15141533
learning_rate: float,

test/legacy_test/test_lr_scheduler.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1463,6 +1463,24 @@ def test_step_decay(self):
14631463
)
14641464
paddle.enable_static()
14651465

1466+
def test_lambda_decay(self):
1467+
paddle.disable_static()
1468+
linear = paddle.nn.Linear(10, 10)
1469+
base_lr = 0.5
1470+
lr_lambda = lambda epoch: 0.95**epoch
1471+
sgd = paddle.optimizer.SGD(
1472+
learning_rate=base_lr, parameters=linear.parameters()
1473+
)
1474+
scheduler = paddle.optimizer.lr.LambdaDecay(
1475+
optimizer=sgd, lr_lambda=lr_lambda
1476+
)
1477+
self.assertEqual(scheduler.base_lr, sgd.get_lr())
1478+
self.assertIs(sgd._learning_rate, scheduler)
1479+
lrs = self._test_network(linear, sgd, scheduler)
1480+
for i in range(len(lrs)):
1481+
np.testing.assert_allclose(lrs[i], base_lr * lr_lambda(i))
1482+
paddle.enable_static()
1483+
14661484

14671485
if __name__ == '__main__':
14681486
paddle.enable_static()

0 commit comments

Comments
 (0)