Skip to content

Commit c356980

Browse files
committed
use built-in optimizer
1 parent 431aecd commit c356980

3 files changed

Lines changed: 29 additions & 151 deletions

File tree

examples/cosmos/optimizer_utils.py

Lines changed: 0 additions & 116 deletions
This file was deleted.

examples/cosmos/train_cosmos_predict25_lora.py

Lines changed: 16 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525

2626
import diffusers
2727
from diffusers import Cosmos2_5_PredictBasePipeline
28-
from diffusers.optimization import get_scheduler
28+
from diffusers.optimization import get_linear_schedule_with_warmup, get_scheduler
2929
from diffusers.training_utils import cast_training_params
3030
from diffusers.utils.torch_utils import is_compiled_module
3131
from diffusers.utils import (
@@ -239,37 +239,26 @@ def parse_args():
239239
parser.add_argument(
240240
"--scheduler_warm_up_steps",
241241
type=int,
242-
nargs="+",
243-
default=[1000],
244-
help="Warm-up steps per cycle for the LambdaLinearScheduler.",
242+
default=1000,
243+
help="Number of warmup steps for the linear LR scheduler.",
245244
)
246245
parser.add_argument(
247-
"--scheduler_cycle_lengths",
246+
"--num_training_steps",
248247
type=int,
249-
nargs="+",
250-
default=[100000],
251-
help="Cycle lengths for the LambdaLinearScheduler.",
252-
)
253-
parser.add_argument(
254-
"--scheduler_f_start",
255-
type=float,
256-
nargs="+",
257-
default=[1e-6],
258-
help="LR multiplier at the start of each warm-up cycle.",
248+
default=100000,
249+
help="Total number of training steps for the LR scheduler.",
259250
)
260251
parser.add_argument(
261252
"--scheduler_f_max",
262253
type=float,
263-
nargs="+",
264-
default=[0.5],
265-
help="Maximum LR multiplier reached after warm-up.",
254+
default=0.5,
255+
help="Maximum LR multiplier (peak after warmup) for the linear scheduler.",
266256
)
267257
parser.add_argument(
268258
"--scheduler_f_min",
269259
type=float,
270-
nargs="+",
271-
default=[0.2],
272-
help="Minimum LR multiplier at the end of each cycle.",
260+
default=0.2,
261+
help="Minimum LR multiplier (floor of linear decay) for the linear scheduler.",
273262
)
274263
parser.add_argument(
275264
"--do_final_eval",
@@ -585,16 +574,13 @@ def main():
585574
if args.allow_tf32:
586575
torch.backends.cuda.matmul.allow_tf32 = True
587576

588-
from optimizer_utils import build_optimizer_and_scheduler
589-
optimizer, lr_scheduler = build_optimizer_and_scheduler(
590-
lora_params,
591-
lr=args.learning_rate,
592-
weight_decay=args.weight_decay,
593-
warm_up_steps=args.scheduler_warm_up_steps,
594-
cycle_lengths=args.scheduler_cycle_lengths,
595-
f_start=args.scheduler_f_start,
596-
f_max=args.scheduler_f_max,
577+
optimizer = torch.optim.AdamW(lora_params, lr=args.learning_rate, weight_decay=args.weight_decay)
578+
lr_scheduler = get_linear_schedule_with_warmup(
579+
optimizer,
580+
num_warmup_steps=args.scheduler_warm_up_steps,
581+
num_training_steps=args.num_training_steps,
597582
f_min=args.scheduler_f_min,
583+
f_max=args.scheduler_f_max,
598584
)
599585

600586
train_dataloader = build_dataloader(args)

src/diffusers/optimization.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,12 @@ def rule_func(steps: int) -> float:
120120

121121

122122
def get_linear_schedule_with_warmup(
123-
optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, last_epoch: int = -1
123+
optimizer: Optimizer,
124+
num_warmup_steps: int,
125+
num_training_steps: int,
126+
last_epoch: int = -1,
127+
f_min: float = 0.0,
128+
f_max: float = 1.0,
124129
) -> LambdaLR:
125130
"""
126131
Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
@@ -135,17 +140,20 @@ def get_linear_schedule_with_warmup(
135140
The total number of training steps.
136141
last_epoch (`int`, *optional*, defaults to -1):
137142
The index of the last epoch when resuming training.
143+
f_min (`float`, *optional*, defaults to 0.0):
144+
Minimum lr multiplier (floor of the linear decay). The lr will not fall below `f_min * initial_lr`.
145+
f_max (`float`, *optional*, defaults to 1.0):
146+
Maximum lr multiplier (peak reached after warmup). The lr peaks at `f_max * initial_lr`.
138147
139148
Return:
140149
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
141150
"""
142151

143152
def lr_lambda(current_step: int):
144153
if current_step < num_warmup_steps:
145-
return float(current_step) / float(max(1, num_warmup_steps))
146-
return max(
147-
0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
148-
)
154+
return f_max * float(current_step) / float(max(1, num_warmup_steps))
155+
progress = float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
156+
return f_min + (f_max - f_min) * max(0.0, progress)
149157

150158
return LambdaLR(optimizer, lr_lambda, last_epoch)
151159

0 commit comments

Comments
 (0)