Skip to content

Commit 61f11df

Browse files
committed
feat: add lr_scheduler
1 parent 265c222 commit 61f11df

2 files changed

Lines changed: 58 additions & 1 deletion

File tree

profold2/command/trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from profold2.data.utils import (
2424
embedding_get_labels, tensor_to_numpy, weights_from_file
2525
)
26-
from profold2.model import accelerator, FeatureBuilder, MetricDict, ReturnValues
26+
from profold2.model import accelerator, optim, FeatureBuilder, MetricDict, ReturnValues
2727
from profold2.model.utils import CheckpointManager
2828
from profold2.utils import exists
2929

profold2/model/optim.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
"""lr_scheduler wrapper
2+
"""
3+
from enum import Enum
4+
import math
5+
from typing import Optional, Union
6+
7+
from torch.optim import Optimizer
8+
from torch.optim.lr_scheduler import LambdaLR
9+
10+
from profold2.utils import default, exists
11+
12+
13+
class SchedulerType(Enum):
14+
CONSTANT = 'constant'
15+
COSINE = 'cosine'
16+
LINEAR = 'linear'
17+
18+
19+
def get_scheduler(
20+
name: Union[str, SchedulerType],
21+
optimizer: Optimizer,
22+
num_warmup_steps: Optional[int] = None,
23+
num_training_steps: Optional[int] = None,
24+
eta_min: float = 0.0,
25+
last_epoch: int = -1,
26+
) -> LambdaLR:
27+
name = SchedulerType(name)
28+
29+
if name == SchedulerType.CONSTANT:
30+
31+
def lr_lambda(current_step: int) -> float:
32+
if exists(num_warmup_steps) and current_step < num_warmup_steps:
33+
return current_step / max(1.0, num_warmup_steps)
34+
return 1.0
35+
elif name == SchedulerType.COSINE:
36+
37+
def lr_lambda(current_step: int) -> float:
38+
if exists(num_warmup_steps) and current_step < num_warmup_steps:
39+
return current_step / max(1.0, num_warmup_steps)
40+
num_warmup_steps = default(num_warmup_steps, 0)
41+
progress = (
42+
(current_step - num_warmup_steps) / (num_training_steps - num_warmup_steps)
43+
)
44+
return 0.5 * (1.0 - eta_min) * (1.0 + math.cos(math.pi * progress)) + eta_min
45+
elif name == SchedulerType.LINEAR:
46+
47+
def lr_lambda(current_step: int) -> float:
48+
if exists(num_warmup_steps) and current_step < num_warmup_steps:
49+
return current_step / max(1.0, num_warmup_steps)
50+
51+
num_warmup_steps = default(num_warmup_steps, 0)
52+
progress = (
53+
(num_training_steps - current_step) / (num_training_steps - num_warmup_steps)
54+
)
55+
return (1.0 - eta_min) * progress + eta_min
56+
57+
return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)

0 commit comments

Comments
 (0)