
Commit 2b1b592

add tests for get_optimizer
1 parent b922c80

1 file changed: 37 additions & 0 deletions

tests/post_training/unit/rl_utils_test.py

@@ -332,5 +332,42 @@ def test_without_hash(self):
         self.assertIsNone(utils_rl.extract_hash_answer(""))
 
 
+class TestGetOptimizer(unittest.TestCase):
+    """Tests for utils_rl.get_optimizer."""
+
+    def _make_optimizer_config(self, gradient_clipping_threshold=0.0):
+        return SimpleNamespace(
+            learning_rate=1e-4,
+            warmup_steps_fraction=0.1,
+            gradient_clipping_threshold=gradient_clipping_threshold,
+            adam_b1=0.9,
+            adam_b2=0.999,
+            adam_weight_decay=0.01,
+        )
+
+    @pytest.mark.cpu_only
+    def test_returns_optimizer_without_clipping(self):
+        """get_optimizer returns an optax optimizer when gradient clipping is disabled."""
+        import jax.numpy as jnp  # pylint: disable=import-outside-toplevel
+
+        config = self._make_optimizer_config(gradient_clipping_threshold=0.0)
+        opt = utils_rl.get_optimizer(config, max_train_steps=100)
+        # Should be usable: init on a simple param tree
+        params = {"w": jnp.ones(3)}
+        state = opt.init(params)
+        self.assertIn("learning_rate", state.hyperparams)
+
+    @pytest.mark.cpu_only
+    def test_returns_optimizer_with_clipping(self):
+        """get_optimizer includes gradient clipping when threshold > 0."""
+        import jax.numpy as jnp  # pylint: disable=import-outside-toplevel
+
+        config = self._make_optimizer_config(gradient_clipping_threshold=1.0)
+        opt = utils_rl.get_optimizer(config, max_train_steps=100)
+        params = {"w": jnp.ones(3)}
+        state = opt.init(params)
+        self.assertIn("learning_rate", state.hyperparams)
+
+
 if __name__ == "__main__":
     unittest.main()
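
A note on what these tests imply about the implementation: asserting on state.hyperparams means get_optimizer must wrap its transformation in something that records hyperparameters in the optimizer state. Below is a minimal sketch of a get_optimizer consistent with both tests. It is an assumption, not the repository's actual code: it presumes optax as the optimizer library, a linear warmup schedule derived from warmup_steps_fraction and max_train_steps, and optax.inject_hyperparams to expose state.hyperparams["learning_rate"].

# Hypothetical sketch only; utils_rl.get_optimizer may differ.
import optax


def get_optimizer(config, max_train_steps):
    """Builds an optax optimizer from the config fields the tests use."""
    warmup_steps = max(1, int(config.warmup_steps_fraction * max_train_steps))
    # Assumed schedule: linear warmup from 0 to the configured learning rate.
    schedule = optax.linear_schedule(
        init_value=0.0,
        end_value=config.learning_rate,
        transition_steps=warmup_steps,
    )

    def _tx(learning_rate):
        adamw = optax.adamw(
            learning_rate=learning_rate,
            b1=config.adam_b1,
            b2=config.adam_b2,
            weight_decay=config.adam_weight_decay,
        )
        if config.gradient_clipping_threshold > 0.0:
            # Clip gradients by global norm before the AdamW update.
            return optax.chain(
                optax.clip_by_global_norm(config.gradient_clipping_threshold),
                adamw,
            )
        return adamw

    # inject_hyperparams records learning_rate in state.hyperparams (a
    # callable is treated as a schedule), which is what both tests check.
    return optax.inject_hyperparams(_tx)(learning_rate=schedule)

Placing the clipping chain inside the injected factory is deliberate: if optax.chain wrapped inject_hyperparams from the outside, opt.init(params) would return a tuple of states, and the state.hyperparams lookup in test_returns_optimizer_with_clipping would fail.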
