Commit 6052513

Pooya Moradi committed
get_optimizer: respect learning_rate_schedule_steps config knob
base.yml documents learning_rate_schedule_steps as the control for the LR schedule length ("By default the length of the schedule is set to the number of steps"), and allows setting it to a different value. The post_train RL get_optimizer ignored this knob and always used max_train_steps, silently dropping any non-default value.

This matters for GPU<->TPU recipe parity: when reproducing a GPU recipe on TPU with a different NUM_BATCHES, the LR schedule shape must stay the same (e.g., warmup=50, decay=500, like NeMo-RL's lr_warmup_iters/lr_decay_iters) regardless of how many TPU steps are run. Without this fix the schedule stretches with the run, so the integrated LR scales linearly with NUM_BATCHES.

Backward-compatible: the default learning_rate_schedule_steps=-1 (or unset) falls back to max_train_steps, identical to the old behavior.
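The effect described in the message can be checked with a small pure-Python sketch. It does not call optax; `warmup_cosine_lr` is a hypothetical stand-in that follows the same warmup-then-cosine shape, and `resolve_schedule_steps`, `lr_at`, the peak value, and the 0.1 warmup fraction are illustrative assumptions rather than names or values from the repository.

```python
import math


def resolve_schedule_steps(configured, max_train_steps):
  """Fallback described in the commit: None or <= 0 means use max_train_steps."""
  if configured is None or configured <= 0:
    return max_train_steps
  return configured


def warmup_cosine_lr(step, peak, warmup_steps, decay_steps):
  """Stand-in schedule: linear warmup from 0 to peak, then cosine decay to 0."""
  if step < warmup_steps:
    return peak * step / warmup_steps
  span = max(decay_steps - warmup_steps, 1)
  progress = min(step - warmup_steps, span) / span
  return peak * 0.5 * (1.0 + math.cos(math.pi * progress))


def lr_at(step, max_train_steps, configured_steps, peak=1e-6, warmup_fraction=0.1):
  """LR at `step` for a run of max_train_steps with the given knob value."""
  schedule_steps = resolve_schedule_steps(configured_steps, max_train_steps)
  warmup_steps = int(warmup_fraction * schedule_steps)
  return warmup_cosine_lr(step, peak, warmup_steps, schedule_steps)


# Knob set to 500: step 25 sees the same LR whether the run is 100 or 200 steps.
fixed_short = lr_at(25, max_train_steps=100, configured_steps=500)
fixed_long = lr_at(25, max_train_steps=200, configured_steps=500)

# Knob left at -1: the schedule stretches with the run, so step 25 differs.
default_short = lr_at(25, max_train_steps=100, configured_steps=-1)
default_long = lr_at(25, max_train_steps=200, configured_steps=-1)

print(fixed_short == fixed_long, default_short == default_long)
```

With the knob fixed, the per-step LR depends only on the step index, which is the property the GPU recipe comparison relies on.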
1 parent 493fba6 commit 6052513

1 file changed

Lines changed: 18 additions & 5 deletions

src/maxtext/trainers/post_train/rl/utils_rl.py

@@ -531,15 +531,28 @@ def check_correctness(extracted_response: str, acceptable_answers: list[str], tm
 
 
 def get_optimizer(tmvp_config: Any, max_train_steps: int) -> optax.GradientTransformation:
-  """Function to obtain an optax optimizer, currently we use adamw."""
+  """Function to obtain an optax optimizer, currently we use adamw.
+
+  Schedule shape is controlled by `learning_rate_schedule_steps` when set
+  (>0); this decouples warmup/decay shape from training length so the same
+  schedule can be applied across runs of different num_batches. Default
+  (-1) falls back to `max_train_steps` for backward compatibility — matches
+  the documented behavior of base.yml's `learning_rate_schedule_steps: -1`
+  ("By default the length of the schedule is set to the number of steps").
+  """
+  schedule_steps = getattr(tmvp_config, "learning_rate_schedule_steps", -1)
+  if schedule_steps is None or schedule_steps <= 0:
+    schedule_steps = max_train_steps
   schedule = optax.schedules.warmup_cosine_decay_schedule(
       init_value=0.0,
       peak_value=tmvp_config.learning_rate,
       # Linearly increase learning rate from 0. to learning_rate in the first
-      # warmup_steps_fraction training steps, and then gradually decrease the
-      # learning rate to 0 using cosine scheduler.
-      warmup_steps=int(tmvp_config.warmup_steps_fraction * max_train_steps),
-      decay_steps=max_train_steps,
+      # warmup_steps_fraction × schedule_steps steps, then cosine-decay to 0
+      # over the remaining schedule_steps. When schedule_steps > max_train_steps
+      # the run ends partway through the schedule (useful for matching a fixed
+      # GPU LR schedule across TPU runs with different num_batches).
+      warmup_steps=int(tmvp_config.warmup_steps_fraction * schedule_steps),
+      decay_steps=schedule_steps,
       end_value=0.0,
   )

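Two consequences of the hunk above can be illustrated numerically: the old coupling makes the summed LR proportional to run length, and a schedule longer than the run leaves the final LR well above zero. This is a rough sketch only; `schedule_lr` and `integrated_lr` are hypothetical helpers that approximate the warmup-plus-cosine shape in plain Python (no optax call), with peak 1.0 and warmup fraction 0.1 chosen for illustration.

```python
import math


def schedule_lr(step, schedule_steps, peak=1.0, warmup_fraction=0.1):
  """Stand-in warmup + cosine schedule; schedule_steps counts from step 0."""
  warmup_steps = int(warmup_fraction * schedule_steps)
  if step < warmup_steps:
    return peak * step / warmup_steps
  span = max(schedule_steps - warmup_steps, 1)
  progress = min(step - warmup_steps, span) / span
  return peak * 0.5 * (1.0 + math.cos(math.pi * progress))


def integrated_lr(max_train_steps, schedule_steps):
  """Sum of per-step learning rates over a run of max_train_steps."""
  return sum(schedule_lr(s, schedule_steps) for s in range(max_train_steps))


# Old behavior (schedule length == run length): doubling the run doubles the sum.
old_ratio = integrated_lr(200, 200) / integrated_lr(100, 100)

# Fixed 500-step schedule with a 100-step run: the run stops near the peak LR.
final_fixed = schedule_lr(99, 500)
# Default fallback (schedule length 100): the last step is almost at zero.
final_default = schedule_lr(99, 100)

print(old_ratio, final_fixed, final_default)
```

Ending a run partway through the schedule is the intended trade-off here: the LR never decays to zero, but each step matches the reference recipe.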