4343from mp_actors import close_proxy , move_to_child_process
4444
4545from .. import dev
46+ from .._backend_training import (
47+ aggregate_rl_training_metrics ,
48+ build_rl_train_configs ,
49+ )
4650from ..backend import AnyTrainableModel , Backend
4751from ..metrics_taxonomy import (
4852 TRAIN_GRADIENT_STEPS_KEY ,
49- average_metric_samples ,
5053 build_training_summary_metrics ,
5154 summarize_trajectory_groups ,
5255)
@@ -598,45 +601,36 @@ async def train( # type: ignore[override]
598601 """
599602 groups_list = list (trajectory_groups )
600603
601- # Build config objects from explicit kwargs
602- config = TrainConfig (
603- learning_rate = learning_rate , kl_penalty_coef = kl_penalty_coef
604- )
605- dev_config : dev .TrainConfig = {
606- "advantage_balance" : advantage_balance ,
607- "allow_training_without_logprobs" : allow_training_without_logprobs ,
608- "importance_sampling_level" : importance_sampling_level ,
609- "kl_penalty_coef" : kl_penalty_coef ,
610- "mask_prob_ratio" : mask_prob_ratio ,
611- "plot_tensors" : plot_tensors ,
612- "ppo" : ppo ,
613- "precalculate_logprobs" : precalculate_logprobs ,
614- "scale_learning_rate_by_reward_std_dev" : scale_learning_rate_by_reward_std_dev ,
615- "scale_rewards" : scale_rewards ,
616- "logprob_calculation_chunk_size" : logprob_calculation_chunk_size ,
617- "num_trajectories_learning_rate_multiplier_power" : num_trajectories_learning_rate_multiplier_power ,
618- }
619- # Only include optional fields if they're set
620- if epsilon is not None :
621- dev_config ["epsilon" ] = epsilon
622- if epsilon_high is not None :
623- dev_config ["epsilon_high" ] = epsilon_high
624- if max_negative_advantage_importance_sampling_weight is not None :
625- dev_config ["max_negative_advantage_importance_sampling_weight" ] = (
626- max_negative_advantage_importance_sampling_weight
627- )
628- if kimi_k2_tau is not None :
629- dev_config ["kimi_k2_tau" ] = kimi_k2_tau
630- if truncated_importance_sampling is not None :
631- dev_config ["truncated_importance_sampling" ] = truncated_importance_sampling
632- if kl_ref_adapter_path is not None :
633- dev_config ["kl_ref_adapter_path" ] = kl_ref_adapter_path
634- elif kl_penalty_reference_step is not None :
635- ref_checkpoint_dir = get_step_checkpoint_dir (
604+ resolved_kl_ref_adapter_path = kl_ref_adapter_path
605+ if (
606+ resolved_kl_ref_adapter_path is None
607+ and kl_penalty_reference_step is not None
608+ ):
609+ resolved_kl_ref_adapter_path = get_step_checkpoint_dir (
636610 get_model_dir (model = model , art_path = self ._path ),
637611 kl_penalty_reference_step ,
638612 )
639- dev_config ["kl_ref_adapter_path" ] = ref_checkpoint_dir
613+ config , dev_config = build_rl_train_configs (
614+ learning_rate = learning_rate ,
615+ advantage_balance = advantage_balance ,
616+ scale_rewards = scale_rewards ,
617+ importance_sampling_level = importance_sampling_level ,
618+ mask_prob_ratio = mask_prob_ratio ,
619+ ppo = ppo ,
620+ precalculate_logprobs = precalculate_logprobs ,
621+ epsilon = epsilon ,
622+ epsilon_high = epsilon_high ,
623+ max_negative_advantage_importance_sampling_weight = max_negative_advantage_importance_sampling_weight ,
624+ kimi_k2_tau = kimi_k2_tau ,
625+ kl_penalty_coef = kl_penalty_coef ,
626+ allow_training_without_logprobs = allow_training_without_logprobs ,
627+ plot_tensors = plot_tensors ,
628+ truncated_importance_sampling = truncated_importance_sampling ,
629+ scale_learning_rate_by_reward_std_dev = scale_learning_rate_by_reward_std_dev ,
630+ logprob_calculation_chunk_size = logprob_calculation_chunk_size ,
631+ num_trajectories_learning_rate_multiplier_power = num_trajectories_learning_rate_multiplier_power ,
632+ kl_ref_adapter_path = resolved_kl_ref_adapter_path ,
633+ )
640634
641635 # Collect metrics from training
642636 training_metrics : list [dict [str , float ]] = []
@@ -646,21 +640,10 @@ async def train( # type: ignore[override]
646640 ):
647641 training_metrics .append (metrics )
648642
649- # Aggregate metrics
650- avg_metrics = average_metric_samples (training_metrics )
651- summary = summarize_trajectory_groups (groups_list )
652- avg_metrics .setdefault (
653- "time/step_trainer_s" , time .monotonic () - trainer_started
654- )
655- avg_metrics .update (
656- {
657- key : value
658- for key , value in build_training_summary_metrics (
659- summary ,
660- include_trainable_groups = True ,
661- ).items ()
662- if key not in avg_metrics
663- }
643+ avg_metrics = aggregate_rl_training_metrics (
644+ training_metrics = training_metrics ,
645+ trajectory_groups = groups_list ,
646+ trainer_started = trainer_started ,
664647 )
665648
666649 # Get step and checkpoint path
0 commit comments