|
25 | 25 |
|
26 | 26 | import diffusers |
27 | 27 | from diffusers import Cosmos2_5_PredictBasePipeline |
28 | | -from diffusers.optimization import get_scheduler |
| 28 | +from diffusers.optimization import get_linear_schedule_with_warmup, get_scheduler |
29 | 29 | from diffusers.training_utils import cast_training_params |
30 | 30 | from diffusers.utils.torch_utils import is_compiled_module |
31 | 31 | from diffusers.utils import ( |
@@ -239,37 +239,26 @@ def parse_args(): |
239 | 239 | parser.add_argument( |
240 | 240 | "--scheduler_warm_up_steps", |
241 | 241 | type=int, |
242 | | - nargs="+", |
243 | | - default=[1000], |
244 | | - help="Warm-up steps per cycle for the LambdaLinearScheduler.", |
| 242 | + default=1000, |
| 243 | + help="Number of warmup steps for the linear LR scheduler.", |
245 | 244 | ) |
246 | 245 | parser.add_argument( |
247 | | - "--scheduler_cycle_lengths", |
| 246 | + "--num_training_steps", |
248 | 247 | type=int, |
249 | | - nargs="+", |
250 | | - default=[100000], |
251 | | - help="Cycle lengths for the LambdaLinearScheduler.", |
252 | | - ) |
253 | | - parser.add_argument( |
254 | | - "--scheduler_f_start", |
255 | | - type=float, |
256 | | - nargs="+", |
257 | | - default=[1e-6], |
258 | | - help="LR multiplier at the start of each warm-up cycle.", |
| 248 | + default=100000, |
| 249 | + help="Total number of training steps for the LR scheduler.", |
259 | 250 | ) |
260 | 251 | parser.add_argument( |
261 | 252 | "--scheduler_f_max", |
262 | 253 | type=float, |
263 | | - nargs="+", |
264 | | - default=[0.5], |
265 | | - help="Maximum LR multiplier reached after warm-up.", |
| 254 | + default=0.5, |
| 255 | + help="Maximum LR multiplier (peak after warmup) for the linear scheduler.", |
266 | 256 | ) |
267 | 257 | parser.add_argument( |
268 | 258 | "--scheduler_f_min", |
269 | 259 | type=float, |
270 | | - nargs="+", |
271 | | - default=[0.2], |
272 | | - help="Minimum LR multiplier at the end of each cycle.", |
| 260 | + default=0.2, |
| 261 | + help="Minimum LR multiplier (floor of linear decay) for the linear scheduler.", |
273 | 262 | ) |
274 | 263 | parser.add_argument( |
275 | 264 | "--do_final_eval", |
@@ -585,16 +574,13 @@ def main(): |
585 | 574 | if args.allow_tf32: |
586 | 575 | torch.backends.cuda.matmul.allow_tf32 = True |
587 | 576 |
|
588 | | - from optimizer_utils import build_optimizer_and_scheduler |
589 | | - optimizer, lr_scheduler = build_optimizer_and_scheduler( |
590 | | - lora_params, |
591 | | - lr=args.learning_rate, |
592 | | - weight_decay=args.weight_decay, |
593 | | - warm_up_steps=args.scheduler_warm_up_steps, |
594 | | - cycle_lengths=args.scheduler_cycle_lengths, |
595 | | - f_start=args.scheduler_f_start, |
596 | | - f_max=args.scheduler_f_max, |
| 577 | + optimizer = torch.optim.AdamW(lora_params, lr=args.learning_rate, weight_decay=args.weight_decay) |
| 578 | + lr_scheduler = get_linear_schedule_with_warmup( |
| 579 | + optimizer, |
| 580 | + num_warmup_steps=args.scheduler_warm_up_steps, |
| 581 | + num_training_steps=args.num_training_steps, |
597 | 582 | f_min=args.scheduler_f_min, |
| 583 | + f_max=args.scheduler_f_max, |
598 | 584 | ) |
599 | 585 |
|
600 | 586 | train_dataloader = build_dataloader(args) |
|
0 commit comments