
Commit 19eebe9

let tunix see LR hyperparam for logging
1 parent 086c50d commit 19eebe9

2 files changed: 64 additions & 19 deletions


src/maxtext/trainers/post_train/rl/utils_rl.py

Lines changed: 27 additions & 19 deletions
@@ -460,32 +460,40 @@ def extract_hash_answer(text: str) -> str | None:
 
 def get_optimizer(tmvp_config, max_train_steps):
   """Function to obtain an optax optimizer, currently we use adamw."""
-  optimizer = optax.adamw(
-      learning_rate=optax.schedules.warmup_cosine_decay_schedule(
-          init_value=0.0,
-          peak_value=tmvp_config.learning_rate,
-          # Linearly increase learning rate from 0. to learning_rate in the first
-          # warmup_steps_fraction training steps, and then gradually decrease the
-          # learning rate to 0 using cosine scheduler.
-          warmup_steps=int(tmvp_config.warmup_steps_fraction * max_train_steps),
-          decay_steps=max_train_steps,
-          end_value=0.0,
-      ),
-      b1=tmvp_config.adam_b1,
-      b2=tmvp_config.adam_b2,
-      weight_decay=tmvp_config.adam_weight_decay,
+  schedule = optax.schedules.warmup_cosine_decay_schedule(
+      init_value=0.0,
+      peak_value=tmvp_config.learning_rate,
+      # Linearly increase learning rate from 0. to learning_rate in the first
+      # warmup_steps_fraction training steps, and then gradually decrease the
+      # learning rate to 0 using cosine scheduler.
+      warmup_steps=int(tmvp_config.warmup_steps_fraction * max_train_steps),
+      decay_steps=max_train_steps,
+      end_value=0.0,
   )
 
   # TODO: @mazumdera: try optimizer offloading with adamw
   # Add gradient clipping if specified
   # Grad clipping to prevent large gradients. We find this
   # important to keep KL divergence in check.
-  if tmvp_config.gradient_clipping_threshold > 0:
-    optimizer = optax.chain(
-        optax.clip_by_global_norm(max_norm=tmvp_config.gradient_clipping_threshold),
-        optimizer,
+  def make_optimizer(learning_rate):
+    transforms = []
+    if tmvp_config.gradient_clipping_threshold > 0:
+      transforms.append(optax.clip_by_global_norm(max_norm=tmvp_config.gradient_clipping_threshold))
+    transforms.append(
+        optax.adamw(
+            learning_rate=learning_rate,
+            b1=tmvp_config.adam_b1,
+            b2=tmvp_config.adam_b2,
+            weight_decay=tmvp_config.adam_weight_decay,
+        )
     )
-  return optimizer
+    return optax.chain(*transforms)
+
+  # Wrap the entire optimizer (including gradient clipping) with
+  # inject_hyperparams so opt_state.hyperparams['learning_rate'] is at the
+  # top level of the state tree. This is required for tunix's peft_trainer to
+  # automatically read and log the per-step learning rate.
+  return optax.inject_hyperparams(make_optimizer)(learning_rate=schedule)
 
 
 def process_data(dataset_name, model_tokenizer, template_config, tmvp_config, x):
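Note (not part of the commit): the point of wrapping the whole chain, rather than only adamw, in optax.inject_hyperparams is where the hyperparams dict lands in the optimizer state. The sketch below uses placeholder schedule and clipping values to contrast the two layouts; only the second one exposes opt_state.hyperparams['learning_rate'] at the top level, which is where tunix's peft_trainer looks.

import jax.numpy as jnp
import optax

schedule = optax.schedules.warmup_cosine_decay_schedule(
    init_value=0.0, peak_value=1e-4, warmup_steps=10, decay_steps=100, end_value=0.0
)
params = {"w": jnp.ones(3)}

# Wrapping only adamw: the chain's state is a tuple, so the hyperparams dict is
# nested one level down and a generic trainer cannot find it at the top level.
nested = optax.chain(
    optax.clip_by_global_norm(max_norm=1.0),
    optax.inject_hyperparams(optax.adamw)(learning_rate=schedule),
)
nested_state = nested.init(params)
print(nested_state[1].hyperparams["learning_rate"])  # only reachable via an index

# Wrapping the entire chain (the approach taken in this commit): the hyperparams
# dict sits at the top of the state tree.
def make_optimizer(learning_rate):
  return optax.chain(
      optax.clip_by_global_norm(max_norm=1.0),
      optax.adamw(learning_rate=learning_rate),
  )

top_level = optax.inject_hyperparams(make_optimizer)(learning_rate=schedule)
state = top_level.init(params)
print(state.hyperparams["learning_rate"])  # directly at the top level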

tests/post_training/unit/rl_utils_test.py

Lines changed: 37 additions & 0 deletions
@@ -332,5 +332,42 @@ def test_without_hash(self):
     self.assertIsNone(utils_rl.extract_hash_answer(""))
 
 
+class TestGetOptimizer(unittest.TestCase):
+  """Tests for utils_rl.get_optimizer."""
+
+  def _make_optimizer_config(self, gradient_clipping_threshold=0.0):
+    return SimpleNamespace(
+        learning_rate=1e-4,
+        warmup_steps_fraction=0.1,
+        gradient_clipping_threshold=gradient_clipping_threshold,
+        adam_b1=0.9,
+        adam_b2=0.999,
+        adam_weight_decay=0.01,
+    )
+
+  @pytest.mark.cpu_only
+  def test_returns_optimizer_without_clipping(self):
+    """get_optimizer returns an optax optimizer when gradient clipping is disabled."""
+    import jax.numpy as jnp  # pylint: disable=import-outside-toplevel
+
+    config = self._make_optimizer_config(gradient_clipping_threshold=0.0)
+    opt = utils_rl.get_optimizer(config, max_train_steps=100)
+    # Should be usable: init on a simple param tree
+    params = {"w": jnp.ones(3)}
+    state = opt.init(params)
+    self.assertIn("learning_rate", state.hyperparams)
+
+  @pytest.mark.cpu_only
+  def test_returns_optimizer_with_clipping(self):
+    """get_optimizer includes gradient clipping when threshold > 0."""
+    import jax.numpy as jnp  # pylint: disable=import-outside-toplevel
+
+    config = self._make_optimizer_config(gradient_clipping_threshold=1.0)
+    opt = utils_rl.get_optimizer(config, max_train_steps=100)
+    params = {"w": jnp.ones(3)}
+    state = opt.init(params)
+    self.assertIn("learning_rate", state.hyperparams)
+
+
 if __name__ == "__main__":
   unittest.main()
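The assertion on state.hyperparams mirrors what happens at runtime. A minimal sketch (not part of the commit, with placeholder schedule, clipping, and gradient values): after each opt.update call, the current scheduled learning rate can be read back from the top-level hyperparams dict, which is the value a trainer such as tunix's peft_trainer would log per step.

import jax.numpy as jnp
import optax

schedule = optax.schedules.warmup_cosine_decay_schedule(
    init_value=0.0, peak_value=1e-4, warmup_steps=10, decay_steps=100, end_value=0.0
)

def make_optimizer(learning_rate):
  # Same shape as the new get_optimizer: clipping plus adamw inside one chain.
  return optax.chain(
      optax.clip_by_global_norm(max_norm=1.0),
      optax.adamw(learning_rate=learning_rate),
  )

opt = optax.inject_hyperparams(make_optimizer)(learning_rate=schedule)
params = {"w": jnp.ones(3)}
state = opt.init(params)
grads = {"w": jnp.full((3,), 0.5)}
for step in range(3):
  updates, state = opt.update(grads, state, params)
  params = optax.apply_updates(params, updates)
  # The value a trainer would log as the learning rate for this step.
  print(step, float(state.hyperparams["learning_rate"]))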
