-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathlr_classes.py
More file actions
46 lines (33 loc) · 1.27 KB
/
lr_classes.py
File metadata and controls
46 lines (33 loc) · 1.27 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
"""
Learning rate functions.
Define your own here if necessary, but keep them simple: in most cases even a constant rate suffices.
"""
from __future__ import division

from abc import ABCMeta, abstractmethod
# ``ABCMeta`` is applied through an explicitly constructed base class (rather
# than ``metaclass=ABCMeta`` or ``abc.ABC``) so the module stays importable on
# both Python 2 and Python 3. Without ABCMeta, ``@abstractmethod`` is a no-op.
_LRFuncABCBase = ABCMeta('_LRFuncABCBase', (object,), {})


class LRFunc(_LRFuncABCBase):
    """Base class for learning-rate schedules.

    Subclasses must implement :meth:`get_lr`. Because the class uses
    ``ABCMeta``, instantiating ``LRFunc`` itself, or a subclass that does not
    override ``get_lr``, raises ``TypeError``.
    """

    def __init__(self, *args, **kwargs):
        """Accept and ignore arbitrary arguments.

        Subclasses forward their extra ``*args``/``**kwargs`` here; the base
        class stores nothing. Concrete (not abstract) so subclasses are not
        forced to define their own ``__init__``.
        """
        pass

    @abstractmethod
    def get_lr(self, *args, **kwargs):
        """Return the learning rate; must be overridden by subclasses."""
        pass
class LRFuncConstant(LRFunc):
    """Learning-rate schedule that always returns the same value."""

    def __init__(self, lr, *args, **kwargs):
        """Store the fixed learning rate.

        :param lr: value returned by every call to :meth:`get_lr`.

        Any extra positional/keyword arguments are forwarded to
        ``LRFunc.__init__``.
        """
        self.lr = lr
        # Unpack the extras instead of passing the tuple and dict as two
        # positional values.
        super(LRFuncConstant, self).__init__(*args, **kwargs)

    def get_lr(self, *args, **kwargs):
        """Return the constant learning rate; all arguments are ignored."""
        return self.lr
class LRFuncExpDecay(LRFunc):
    """Exponentially decay the learning rate from ``start_lr`` to ``finish_lr``.

    ``get_lr(step) = start_lr * (finish_lr / start_lr) ** (step / decay_steps)``,
    so the rate equals ``start_lr`` at step 0 and ``finish_lr`` at step
    ``decay_steps``. The result is not clamped: for steps beyond
    ``decay_steps`` the rate keeps decaying below ``finish_lr``.
    """

    def __init__(self, start_lr, finish_lr, decay_steps, *args, **kwargs):
        """Validate and store the schedule parameters.

        :param start_lr: learning rate at step 0.
        :param finish_lr: learning rate reached at step ``decay_steps``; must
            satisfy ``0 < finish_lr <= start_lr``.
        :param decay_steps: number of steps over which ``start_lr`` decays to
            ``finish_lr``; must be positive.
        :raises AssertionError: if the constraints above are violated (note
            that asserts are skipped when Python runs with ``-O``).

        Any extra positional/keyword arguments are forwarded to
        ``LRFunc.__init__``.
        """
        assert 0 < finish_lr <= start_lr, (
            "start_lr must be >= finish_lr and both must be positive")
        assert decay_steps > 0, "decay_steps must be positive"
        self.starter_learning_rate = start_lr
        self.finish_learning_rate = finish_lr
        # Multiplicative factor applied per ``decay_steps`` steps; the assert
        # above guarantees it lies in (0, 1]. Relies on true division.
        self.decay_rate = finish_lr / start_lr
        self.decay_steps = decay_steps
        # Unpack the extras instead of passing the tuple and dict as two
        # positional values.
        super(LRFuncExpDecay, self).__init__(*args, **kwargs)

    def get_lr(self, global_step, *args, **kwargs):
        """Return the decayed learning rate at ``global_step``.

        Extra arguments are ignored.
        """
        # A fractional exponent gives a smooth (not staircase) decay; this
        # relies on true division of global_step / decay_steps.
        return self.starter_learning_rate * pow(self.decay_rate, (global_step / self.decay_steps))
# Lookup table from a schedule name to the LRFunc subclass implementing it.
lrfunc_classes = dict(
    constant=LRFuncConstant,
    expdecay=LRFuncExpDecay,
)