@@ -332,5 +332,42 @@ def test_without_hash(self):
         self.assertIsNone(utils_rl.extract_hash_answer(""))
 
 
+class TestGetOptimizer(unittest.TestCase):
+    """Tests for utils_rl.get_optimizer."""
+
+    def _make_optimizer_config(self, gradient_clipping_threshold=0.0):
+        return SimpleNamespace(
+            learning_rate=1e-4,
+            warmup_steps_fraction=0.1,
+            gradient_clipping_threshold=gradient_clipping_threshold,
+            adam_b1=0.9,
+            adam_b2=0.999,
+            adam_weight_decay=0.01,
+        )
+
+    @pytest.mark.cpu_only
+    def test_returns_optimizer_without_clipping(self):
+        """get_optimizer returns an optax optimizer when gradient clipping is disabled."""
+        import jax.numpy as jnp  # pylint: disable=import-outside-toplevel
+
+        config = self._make_optimizer_config(gradient_clipping_threshold=0.0)
+        opt = utils_rl.get_optimizer(config, max_train_steps=100)
+        # Should be usable: init on a simple param tree.
+        params = {"w": jnp.ones(3)}
+        state = opt.init(params)
+        self.assertIn("learning_rate", state.hyperparams)
+
+    @pytest.mark.cpu_only
+    def test_returns_optimizer_with_clipping(self):
+        """get_optimizer includes gradient clipping when threshold > 0."""
+        import jax.numpy as jnp  # pylint: disable=import-outside-toplevel
+
+        config = self._make_optimizer_config(gradient_clipping_threshold=1.0)
+        opt = utils_rl.get_optimizer(config, max_train_steps=100)
+        params = {"w": jnp.ones(3)}
+        state = opt.init(params)
+        self.assertIn("learning_rate", state.hyperparams)
+
+
 if __name__ == "__main__":
     unittest.main()
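Aside for reviewers: a minimal sketch of the `get_optimizer` shape these tests pin down. The real implementation lives in `utils_rl`, and the warmup-cosine schedule and exact keyword names below are assumptions. What the tests do force is that the whole chain, including the optional `clip_by_global_norm`, sits inside `optax.inject_hyperparams`: `opt.init` on a bare `optax.chain` returns a plain tuple of component states with no `hyperparams` attribute, so the `state.hyperparams` assertions would fail in the clipping case otherwise.

```python
import optax


# Hypothetical sketch only; utils_rl defines the real get_optimizer. The
# schedule choice and adamw keyword names are assumptions, while the
# inject_hyperparams wrapping is implied by the tests' hyperparams checks.
def get_optimizer(config, max_train_steps):
    def _build(learning_rate):
        transforms = []
        if config.gradient_clipping_threshold > 0.0:
            # Clip inside the injected factory so the combined state still
            # exposes .hyperparams.
            transforms.append(
                optax.clip_by_global_norm(config.gradient_clipping_threshold)
            )
        transforms.append(
            optax.adamw(
                learning_rate=learning_rate,
                b1=config.adam_b1,
                b2=config.adam_b2,
                weight_decay=config.adam_weight_decay,
            )
        )
        return optax.chain(*transforms)

    schedule = optax.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=config.learning_rate,
        warmup_steps=int(config.warmup_steps_fraction * max_train_steps),
        decay_steps=max_train_steps,
    )
    # inject_hyperparams treats callable kwargs as schedules, so
    # state.hyperparams["learning_rate"] tracks the current step's value.
    return optax.inject_hyperparams(_build)(learning_rate=schedule)
```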