File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -1509,6 +1509,25 @@ class LambdaDecay(LRScheduler):
15091509
15101510 lr_lambda : Callable [[int ], float ]
15111511
1512+ @overload
1513+ def __init__ (
1514+ self ,
1515+ learning_rate : float ,
1516+ lr_lambda : Callable [[int ], float ],
1517+ last_epoch : int = - 1 ,
1518+ verbose : bool = False ,
1519+ ): ...
1520+
1521+ @overload
1522+ def __init__ (
1523+ self ,
1524+ optimizer : paddle .optimizer .Optimizer ,
1525+ lr_lambda : Callable [[int ], float ],
1526+ last_epoch : int = - 1 ,
1527+ verbose : bool = False ,
1528+ ): ...
1529+
1530+ @lr_scheduler_decorator ()
15121531 def __init__ (
15131532 self ,
15141533 learning_rate : float ,
Original file line number Diff line number Diff line change @@ -1463,6 +1463,24 @@ def test_step_decay(self):
14631463 )
14641464 paddle .enable_static ()
14651465
1466+ def test_lambda_decay (self ):
1467+ paddle .disable_static ()
1468+ linear = paddle .nn .Linear (10 , 10 )
1469+ base_lr = 0.5
1470+ lr_lambda = lambda epoch : 0.95 ** epoch
1471+ sgd = paddle .optimizer .SGD (
1472+ learning_rate = base_lr , parameters = linear .parameters ()
1473+ )
1474+ scheduler = paddle .optimizer .lr .LambdaDecay (
1475+ optimizer = sgd , lr_lambda = lr_lambda
1476+ )
1477+ self .assertEqual (scheduler .base_lr , sgd .get_lr ())
1478+ self .assertIs (sgd ._learning_rate , scheduler )
1479+ lrs = self ._test_network (linear , sgd , scheduler )
1480+ for i in range (len (lrs )):
1481+ np .testing .assert_allclose (lrs [i ], base_lr * lr_lambda (i ))
1482+ paddle .enable_static ()
1483+
14661484
14671485if __name__ == '__main__' :
14681486 paddle .enable_static ()
You can’t perform that action at this time.
0 commit comments