4343from mp_actors import close_proxy , move_to_child_process
4444
4545from .. import dev
46+ from .._backend_training import (
47+ aggregate_rl_training_metrics ,
48+ build_rl_train_configs ,
49+ )
4650from ..backend import AnyTrainableModel , Backend
47- from ..costs import build_cost_calculator , get_model_pricing
4851from ..metrics_taxonomy import (
4952 TRAIN_GRADIENT_STEPS_KEY ,
50- average_metric_samples ,
5153 build_training_summary_metrics ,
5254 summarize_trajectory_groups ,
5355)
@@ -160,9 +162,6 @@ def _allocated_gpu_count(self, model: Model) -> int:
160162 def __enter__ (self ) -> Self :
161163 return self
162164
163- async def __aenter__ (self ) -> Self :
164- return self
165-
166165 def __exit__ (
167166 self ,
168167 exc_type : type [BaseException ] | None ,
@@ -171,30 +170,14 @@ def __exit__(
171170 ) -> None :
172171 self ._close ()
173172
174- async def __aexit__ (
175- self ,
176- exc_type : type [BaseException ] | None ,
177- exc : BaseException | None ,
178- tb : TracebackType | None ,
179- ) -> None :
180- await self .close ()
181-
182173 async def close (self ) -> None :
183174 """
184175 If running vLLM in a separate process, this will kill that process and close the communication threads.
185176 """
186- for service in self ._services .values ():
187- aclose = getattr (service , "aclose" , None )
188- if aclose is None :
189- close = getattr (service , "close" , None )
190- if close is not None :
191- close ()
192- else :
193- await aclose ()
194- close_proxy (service )
177+ self ._close ()
195178
196179 def _close (self ) -> None :
197- for service in self ._services .values ():
180+ for _ , service in self ._services .items ():
198181 close = getattr (service , "close" , None )
199182 if close is not None :
200183 close ()
@@ -226,11 +209,6 @@ async def register(
226209 # (wandb initialization is now handled by the model's _get_wandb_run method)
227210 if model .trainable and "WANDB_API_KEY" in os .environ :
228211 _ = model ._get_wandb_run ()
229- if model .trainable :
230- trainable_model = cast (TrainableModel , model )
231- pricing = get_model_pricing (trainable_model .base_model )
232- if pricing is not None :
233- trainable_model .set_cost_calculator (build_cost_calculator (pricing ))
234212
235213 def _model_inference_name (self , model : Model , step : int | None = None ) -> str :
236214 """Return the inference name for a model checkpoint.
@@ -244,27 +222,25 @@ def _model_inference_name(self, model: Model, step: int | None = None) -> str:
244222 If None, returns name for latest checkpoint (step 0 initially).
245223 """
246224
247- requested_step = step
248-
249- if step is None and isinstance (model , TrainableModel ):
250- from ..dev .validate import is_dedicated_mode
251-
252- service = self ._services .get (model .name )
253- if service is not None and is_dedicated_mode (
254- model ._internal_config or dev .InternalModelConfig ()
255- ):
256- loaded_step = getattr (service , "_latest_step" , None )
257- if isinstance (loaded_step , int ):
258- step = loaded_step
259-
260- if step is None :
261- # The checkpoint directory is written before dedicated-mode
262- # vLLM finishes reloading the new adapter.
263- step = self .__get_step (model )
264- name = f"{ model .name } @{ step } "
225+ # For LocalBackend, vLLM always serves LoRA adapters with an @step suffix.
226+ # When no step is specified, resolve to the latest step (initially step 0, the checkpoint created at registration).
227+ if step is not None :
228+ actual_step = step
229+ elif model .name in self ._services and self ._in_process :
230+ # In dedicated mode the service tracks which adapter vLLM has
231+ # actually loaded. Reading the filesystem would race: the
232+ # checkpoint directory appears before the HTTP reload completes.
233+ svc = self ._services [model .name ]
234+ loaded_step = getattr (svc , "_latest_step" , None )
235+ actual_step = (
236+ loaded_step if loaded_step is not None else self .__get_step (model )
237+ )
238+ else :
239+ actual_step = self .__get_step (model )
240+ name = f"{ model .name } @{ actual_step } "
265241 logger .debug (
266- f"[BACKEND] _model_inference_name: step_arg={ requested_step } "
267- f"actual_step={ step } -> { name } "
242+ f"[BACKEND] _model_inference_name: step_arg={ step } "
243+ f"actual_step={ actual_step } -> { name } "
268244 )
269245 return name
270246
@@ -529,14 +505,12 @@ async def train( # type: ignore[override]
529505 * ,
530506 # Core training parameters
531507 learning_rate : float = 5e-6 ,
532- loss_fn : Literal ["cispo" , "ppo" ] = "cispo" ,
533- loss_fn_config : dict | None = None ,
534- normalize_advantages : bool = True ,
535- adam_params : object | None = None ,
536508 # KL-penalized advantage adjustment
537509 kl_penalty_coef : float = 0.0 ,
538510 kl_penalty_reference_step : int | None = None ,
539511 kl_ref_adapter_path : str | None = None ,
512+ # RL algorithm settings
513+ ppo : bool = False ,
540514 epsilon : float | None = None ,
541515 epsilon_high : float | None = None ,
542516 # Advantage computation
@@ -573,14 +547,6 @@ async def train( # type: ignore[override]
573547 model: The trainable model to train.
574548 trajectory_groups: Batches of trajectories to train on.
575549 learning_rate: Learning rate for training. Defaults to 5e-6.
576- loss_fn: RL loss function. LocalBackend currently supports
577- "cispo" and "ppo".
578- loss_fn_config: Additional loss-function config. Not supported by
579- LocalBackend.
580- normalize_advantages: Whether to normalize advantages. LocalBackend
581- currently requires True.
582- adam_params: Custom optimizer params. Not supported by
583- LocalBackend.
584550 kl_penalty_coef: Coefficient for KL-penalized advantage adjustment.
585551 Tokens diverging more from the reference get reduced advantages.
586552 Defaults to 0.0 (disabled).
@@ -590,7 +556,8 @@ async def train( # type: ignore[override]
590556 kl_ref_adapter_path: Direct filesystem path to a LoRA adapter
591557 checkpoint to use as the KL reference. Alternative to
592558 kl_penalty_reference_step.
593- epsilon: Clip epsilon for importance sampling. Defaults based on loss_fn.
559+ ppo: Whether to use PPO clipping. Defaults to False.
560+ epsilon: Clip epsilon for importance sampling. Defaults based on ppo.
594561 epsilon_high: Asymmetric upper clip bound. Defaults to epsilon.
595562 advantage_balance: Balance between negative and positive advantages
596563 in range [-1.0, 1.0]. Defaults to 0.0 (balanced).
@@ -633,54 +600,37 @@ async def train( # type: ignore[override]
633600 # await model.log(metrics=result.metrics, step=result.step)
634601 """
635602 groups_list = list (trajectory_groups )
636- if loss_fn not in {"cispo" , "ppo" }:
637- raise ValueError ("LocalBackend only supports loss_fn='cispo' or 'ppo'." )
638- if loss_fn_config is not None :
639- raise ValueError ("LocalBackend requires loss_fn_config=None." )
640- if not normalize_advantages :
641- raise ValueError ("LocalBackend requires normalize_advantages=True." )
642- if adam_params is not None :
643- raise ValueError ("LocalBackend requires adam_params=None." )
644-
645- # Build config objects from explicit kwargs
646- config = TrainConfig (
647- learning_rate = learning_rate , kl_penalty_coef = kl_penalty_coef
648- )
649- dev_config : dev .TrainConfig = {
650- "advantage_balance" : advantage_balance ,
651- "allow_training_without_logprobs" : allow_training_without_logprobs ,
652- "importance_sampling_level" : importance_sampling_level ,
653- "kl_penalty_coef" : kl_penalty_coef ,
654- "mask_prob_ratio" : mask_prob_ratio ,
655- "plot_tensors" : plot_tensors ,
656- "ppo" : loss_fn == "ppo" ,
657- "precalculate_logprobs" : precalculate_logprobs ,
658- "scale_learning_rate_by_reward_std_dev" : scale_learning_rate_by_reward_std_dev ,
659- "scale_rewards" : scale_rewards ,
660- "logprob_calculation_chunk_size" : logprob_calculation_chunk_size ,
661- "num_trajectories_learning_rate_multiplier_power" : num_trajectories_learning_rate_multiplier_power ,
662- }
663- # Only include optional fields if they're set
664- if epsilon is not None :
665- dev_config ["epsilon" ] = epsilon
666- if epsilon_high is not None :
667- dev_config ["epsilon_high" ] = epsilon_high
668- if max_negative_advantage_importance_sampling_weight is not None :
669- dev_config ["max_negative_advantage_importance_sampling_weight" ] = (
670- max_negative_advantage_importance_sampling_weight
671- )
672- if kimi_k2_tau is not None :
673- dev_config ["kimi_k2_tau" ] = kimi_k2_tau
674- if truncated_importance_sampling is not None :
675- dev_config ["truncated_importance_sampling" ] = truncated_importance_sampling
676- if kl_ref_adapter_path is not None :
677- dev_config ["kl_ref_adapter_path" ] = kl_ref_adapter_path
678- elif kl_penalty_reference_step is not None :
679- ref_checkpoint_dir = get_step_checkpoint_dir (
603+
604+ resolved_kl_ref_adapter_path = kl_ref_adapter_path
605+ if (
606+ resolved_kl_ref_adapter_path is None
607+ and kl_penalty_reference_step is not None
608+ ):
609+ resolved_kl_ref_adapter_path = get_step_checkpoint_dir (
680610 get_model_dir (model = model , art_path = self ._path ),
681611 kl_penalty_reference_step ,
682612 )
683- dev_config ["kl_ref_adapter_path" ] = ref_checkpoint_dir
613+ config , dev_config = build_rl_train_configs (
614+ learning_rate = learning_rate ,
615+ advantage_balance = advantage_balance ,
616+ scale_rewards = scale_rewards ,
617+ importance_sampling_level = importance_sampling_level ,
618+ mask_prob_ratio = mask_prob_ratio ,
619+ ppo = ppo ,
620+ precalculate_logprobs = precalculate_logprobs ,
621+ epsilon = epsilon ,
622+ epsilon_high = epsilon_high ,
623+ max_negative_advantage_importance_sampling_weight = max_negative_advantage_importance_sampling_weight ,
624+ kimi_k2_tau = kimi_k2_tau ,
625+ kl_penalty_coef = kl_penalty_coef ,
626+ allow_training_without_logprobs = allow_training_without_logprobs ,
627+ plot_tensors = plot_tensors ,
628+ truncated_importance_sampling = truncated_importance_sampling ,
629+ scale_learning_rate_by_reward_std_dev = scale_learning_rate_by_reward_std_dev ,
630+ logprob_calculation_chunk_size = logprob_calculation_chunk_size ,
631+ num_trajectories_learning_rate_multiplier_power = num_trajectories_learning_rate_multiplier_power ,
632+ kl_ref_adapter_path = resolved_kl_ref_adapter_path ,
633+ )
684634
685635 # Collect metrics from training
686636 training_metrics : list [dict [str , float ]] = []
@@ -690,21 +640,10 @@ async def train( # type: ignore[override]
690640 ):
691641 training_metrics .append (metrics )
692642
693- # Aggregate metrics
694- avg_metrics = average_metric_samples (training_metrics )
695- summary = summarize_trajectory_groups (groups_list )
696- avg_metrics .setdefault (
697- "time/step_trainer_s" , time .monotonic () - trainer_started
698- )
699- avg_metrics .update (
700- {
701- key : value
702- for key , value in build_training_summary_metrics (
703- summary ,
704- include_trainable_groups = True ,
705- ).items ()
706- if key not in avg_metrics
707- }
643+ avg_metrics = aggregate_rl_training_metrics (
644+ training_metrics = training_metrics ,
645+ trajectory_groups = groups_list ,
646+ trainer_started = trainer_started ,
708647 )
709648
710649 # Get step and checkpoint path
0 commit comments