
Commit 45e21f8

Merge pull request #3444 from AI-Hypercomputer:atwigg/log_lr_hparam
PiperOrigin-RevId: 888845972
2 parents e6cd443 + 19eebe9 commit 45e21f8

2 files changed

Lines changed: 64 additions & 19 deletions

File tree

src/maxtext/trainers/post_train/rl/utils_rl.py

Lines changed: 27 additions & 19 deletions
@@ -478,32 +478,40 @@ def extract_hash_answer(text: str) -> str | None:
 
 def get_optimizer(tmvp_config, max_train_steps):
   """Function to obtain an optax optimizer, currently we use adamw."""
-  optimizer = optax.adamw(
-      learning_rate=optax.schedules.warmup_cosine_decay_schedule(
-          init_value=0.0,
-          peak_value=tmvp_config.learning_rate,
-          # Linearly increase learning rate from 0. to learning_rate in the first
-          # warmup_steps_fraction training steps, and then gradually decrease the
-          # learning rate to 0 using cosine scheduler.
-          warmup_steps=int(tmvp_config.warmup_steps_fraction * max_train_steps),
-          decay_steps=max_train_steps,
-          end_value=0.0,
-      ),
-      b1=tmvp_config.adam_b1,
-      b2=tmvp_config.adam_b2,
-      weight_decay=tmvp_config.adam_weight_decay,
+  schedule = optax.schedules.warmup_cosine_decay_schedule(
+      init_value=0.0,
+      peak_value=tmvp_config.learning_rate,
+      # Linearly increase learning rate from 0. to learning_rate in the first
+      # warmup_steps_fraction training steps, and then gradually decrease the
+      # learning rate to 0 using cosine scheduler.
+      warmup_steps=int(tmvp_config.warmup_steps_fraction * max_train_steps),
+      decay_steps=max_train_steps,
+      end_value=0.0,
   )
 
   # TODO: @mazumdera: try optimizer offloading with adamw
   # Add gradient clipping if specified
   # Grad clipping to prevent large gradients. We find this
   # important to keep KL divergence in check.
-  if tmvp_config.gradient_clipping_threshold > 0:
-    optimizer = optax.chain(
-        optax.clip_by_global_norm(max_norm=tmvp_config.gradient_clipping_threshold),
-        optimizer,
+  def make_optimizer(learning_rate):
+    transforms = []
+    if tmvp_config.gradient_clipping_threshold > 0:
+      transforms.append(optax.clip_by_global_norm(max_norm=tmvp_config.gradient_clipping_threshold))
+    transforms.append(
+        optax.adamw(
+            learning_rate=learning_rate,
+            b1=tmvp_config.adam_b1,
+            b2=tmvp_config.adam_b2,
+            weight_decay=tmvp_config.adam_weight_decay,
+        )
     )
-  return optimizer
+    return optax.chain(*transforms)
+
+  # Wrap the entire optimizer (including gradient clipping) with
+  # inject_hyperparams so opt_state.hyperparams['learning_rate'] is at the
+  # top level of the state tree. This is required for tunix's peft_trainer to
+  # automatically read and log the per-step learning rate.
+  return optax.inject_hyperparams(make_optimizer)(learning_rate=schedule)
 
 
 def process_data(dataset_name, model_tokenizer, template_config, tmvp_config, x):
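
For context on why the optimizer is now built through optax.inject_hyperparams: the wrapper resolves the learning-rate schedule each step and stores the result in the optimizer state's hyperparams dict, where a trainer can read it without digging into nested states. Below is a minimal standalone sketch (not part of this commit; the schedule values, clipping norm, and adam settings are placeholders standing in for the tmvp_config fields):

import jax.numpy as jnp
import optax

# Placeholder schedule, standing in for the config-driven one above.
schedule = optax.schedules.warmup_cosine_decay_schedule(
    init_value=0.0, peak_value=1e-4, warmup_steps=10, decay_steps=100, end_value=0.0
)

def make_optimizer(learning_rate):
  # Clipping and adamw are chained inside the factory so inject_hyperparams
  # wraps the whole chain, not just adamw.
  return optax.chain(
      optax.clip_by_global_norm(max_norm=1.0),
      optax.adamw(learning_rate=learning_rate, b1=0.9, b2=0.999, weight_decay=0.01),
  )

opt = optax.inject_hyperparams(make_optimizer)(learning_rate=schedule)

params = {"w": jnp.ones(3)}
state = opt.init(params)
grads = {"w": jnp.ones(3)}
updates, state = opt.update(grads, state, params)

# The per-step learning rate sits at the top level of the state tree,
# which is what lets a trainer log it automatically.
print(state.hyperparams["learning_rate"])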

tests/post_training/unit/rl_utils_test.py

Lines changed: 37 additions & 0 deletions
@@ -332,5 +332,42 @@ def test_without_hash(self):
     self.assertIsNone(utils_rl.extract_hash_answer(""))
 
 
+class TestGetOptimizer(unittest.TestCase):
+  """Tests for utils_rl.get_optimizer."""
+
+  def _make_optimizer_config(self, gradient_clipping_threshold=0.0):
+    return SimpleNamespace(
+        learning_rate=1e-4,
+        warmup_steps_fraction=0.1,
+        gradient_clipping_threshold=gradient_clipping_threshold,
+        adam_b1=0.9,
+        adam_b2=0.999,
+        adam_weight_decay=0.01,
+    )
+
+  @pytest.mark.cpu_only
+  def test_returns_optimizer_without_clipping(self):
+    """get_optimizer returns an optax optimizer when gradient clipping is disabled."""
+    import jax.numpy as jnp  # pylint: disable=import-outside-toplevel
+
+    config = self._make_optimizer_config(gradient_clipping_threshold=0.0)
+    opt = utils_rl.get_optimizer(config, max_train_steps=100)
+    # Should be usable: init on a simple param tree
+    params = {"w": jnp.ones(3)}
+    state = opt.init(params)
+    self.assertIn("learning_rate", state.hyperparams)
+
+  @pytest.mark.cpu_only
+  def test_returns_optimizer_with_clipping(self):
+    """get_optimizer includes gradient clipping when threshold > 0."""
+    import jax.numpy as jnp  # pylint: disable=import-outside-toplevel
+
+    config = self._make_optimizer_config(gradient_clipping_threshold=1.0)
+    opt = utils_rl.get_optimizer(config, max_train_steps=100)
+    params = {"w": jnp.ones(3)}
+    state = opt.init(params)
+    self.assertIn("learning_rate", state.hyperparams)
+
+
 if __name__ == "__main__":
   unittest.main()
