@@ -460,32 +460,40 @@ def extract_hash_answer(text: str) -> str | None:
 
 def get_optimizer(tmvp_config, max_train_steps):
   """Function to obtain an optax optimizer; currently we use AdamW."""
-  optimizer = optax.adamw(
-      learning_rate=optax.schedules.warmup_cosine_decay_schedule(
-          init_value=0.0,
-          peak_value=tmvp_config.learning_rate,
-          # Linearly increase the learning rate from 0 to learning_rate over
-          # the first warmup_steps_fraction of training steps, then gradually
-          # decrease it to 0 with a cosine schedule.
-          warmup_steps=int(tmvp_config.warmup_steps_fraction * max_train_steps),
-          decay_steps=max_train_steps,
-          end_value=0.0,
-      ),
-      b1=tmvp_config.adam_b1,
-      b2=tmvp_config.adam_b2,
-      weight_decay=tmvp_config.adam_weight_decay,
+  schedule = optax.schedules.warmup_cosine_decay_schedule(
+      init_value=0.0,
+      peak_value=tmvp_config.learning_rate,
+      # Linearly increase the learning rate from 0 to learning_rate over
+      # the first warmup_steps_fraction of training steps, then gradually
+      # decrease it to 0 with a cosine schedule.
+      warmup_steps=int(tmvp_config.warmup_steps_fraction * max_train_steps),
+      decay_steps=max_train_steps,
+      end_value=0.0,
   )
 
   # TODO: @mazumdera: try optimizer offloading with adamw
   # Add gradient clipping if specified. Gradient clipping prevents large
   # gradients, which we find important to keep the KL divergence in check.
-  if tmvp_config.gradient_clipping_threshold > 0:
-    optimizer = optax.chain(
-        optax.clip_by_global_norm(max_norm=tmvp_config.gradient_clipping_threshold),
-        optimizer,
+  def make_optimizer(learning_rate):
+    transforms = []
+    if tmvp_config.gradient_clipping_threshold > 0:
+      transforms.append(optax.clip_by_global_norm(max_norm=tmvp_config.gradient_clipping_threshold))
+    transforms.append(
+        optax.adamw(
+            learning_rate=learning_rate,
+            b1=tmvp_config.adam_b1,
+            b2=tmvp_config.adam_b2,
+            weight_decay=tmvp_config.adam_weight_decay,
+        )
     )
-  return optimizer
+    return optax.chain(*transforms)
+
+  # Wrap the entire optimizer (including gradient clipping) with
+  # inject_hyperparams so opt_state.hyperparams['learning_rate'] is at the
+  # top level of the state tree. This is required for tunix's peft_trainer to
+  # automatically read and log the per-step learning rate.
+  return optax.inject_hyperparams(make_optimizer)(learning_rate=schedule)
 
 
 def process_data(dataset_name, model_tokenizer, template_config, tmvp_config, x):
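
For reference, a minimal runnable sketch of the inject_hyperparams pattern used above, with illustrative values rather than anything from this repo (the toy params/grads trees and the 3e-4 peak rate are placeholders):

import jax.numpy as jnp
import optax

# Illustrative schedule: warm up for 10 steps, then cosine-decay over 100.
schedule = optax.schedules.warmup_cosine_decay_schedule(
    init_value=0.0, peak_value=3e-4, warmup_steps=10, decay_steps=100
)

# Wrapping with inject_hyperparams moves the schedule's current value into
# the optimizer state rather than baking it into the update function.
optimizer = optax.inject_hyperparams(optax.adamw)(learning_rate=schedule)

params = {'w': jnp.ones((2, 2))}
opt_state = optimizer.init(params)
grads = {'w': jnp.ones((2, 2))}
updates, opt_state = optimizer.update(grads, opt_state, params)

# The per-step learning rate is now readable at the top level of the state
# tree, which is what lets a trainer log it without knowing the schedule.
print(opt_state.hyperparams['learning_rate'])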