Skip to content

Commit 1bbd620

Browse files
committed
refactor: code reformat
1 parent e78cb5f commit 1bbd620

2 files changed

Lines changed: 6 additions & 11 deletions

File tree

profold2/command/trainer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,8 @@ def _step(data_loader, it, writer, stage='train', batch_callback=None):
343343
optimizer.zero_grad(set_to_none=True)
344344

345345
logging.debug(
346-
'_step it: %d, loss_scaler: %f, lr: %s', it, loss_scaler, scheduler.get_lr()
346+
'_step it: %d, loss_scaler: %f, lr: %s', it, loss_scaler,
347+
scheduler.get_last_lr()
347348
)
348349

349350
running_loss = MetricDict()

profold2/model/optim.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,18 +29,14 @@ def get_scheduler(
2929

3030
if name == SchedulerType.CONSTANT:
3131

32-
def lr_lambda(
33-
current_step: int, num_warmup_steps: Optional[int] = None
34-
) -> float:
32+
def lr_lambda(current_step: int, num_warmup_steps: Optional[int] = None) -> float:
3533
current_step = current_step + last_global_step
3634
if exists(num_warmup_steps) and current_step < num_warmup_steps:
3735
return current_step / max(1.0, num_warmup_steps)
3836
return 1.0
3937
elif name == SchedulerType.COSINE:
4038

41-
def lr_lambda(
42-
current_step: int, num_warmup_steps: Optional[int] = None
43-
) -> float:
39+
def lr_lambda(current_step: int, num_warmup_steps: Optional[int] = None) -> float:
4440
current_step = current_step + last_global_step
4541
if exists(num_warmup_steps) and current_step < num_warmup_steps:
4642
return current_step / max(1.0, num_warmup_steps)
@@ -50,12 +46,10 @@ def lr_lambda(
5046
progress = (
5147
(current_step - num_warmup_steps) / (num_training_steps - num_warmup_steps)
5248
)
53-
return 0.5 * (1.0 - eta_min) * (1.0 + math.cos(math.pi * progress)) + eta_min
49+
return 0.5 * (1.0 - eta_min) * (1.0 + math.cos(math.pi * progress)) + eta_min
5450
elif name == SchedulerType.LINEAR:
5551

56-
def lr_lambda(
57-
current_step: int, num_warmup_steps: Optional[int] = None
58-
) -> float:
52+
def lr_lambda(current_step: int, num_warmup_steps: Optional[int] = None) -> float:
5953
current_step = current_step + last_global_step
6054
if exists(num_warmup_steps) and current_step < num_warmup_steps:
6155
return current_step / max(1.0, num_warmup_steps)

0 commit comments

Comments
 (0)