
Commit b19e94c

Minimize LocalBackend diff against main
1 parent d8d8e4d commit b19e94c

1 file changed: src/art/local/backend.py

Lines changed: 58 additions & 39 deletions
@@ -48,6 +48,7 @@
     build_rl_train_configs,
 )
 from ..backend import AnyTrainableModel, Backend
+from ..costs import build_cost_calculator, get_model_pricing
 from ..metrics_taxonomy import (
     TRAIN_GRADIENT_STEPS_KEY,
     build_training_summary_metrics,
@@ -185,26 +186,23 @@ async def close(self) -> None:
         """
         If running vLLM in a separate process, this will kill that process and close the communication threads.
         """
-        await self._aclose()
+        for service in self._services.values():
+            aclose = getattr(service, "aclose", None)
+            if aclose is None:
+                close = getattr(service, "close", None)
+                if close is not None:
+                    close()
+            else:
+                await aclose()
+            close_proxy(service)

     def _close(self) -> None:
-        for _, service in self._services.items():
+        for service in self._services.values():
             close = getattr(service, "close", None)
             if close is not None:
                 close()
             close_proxy(service)

-    async def _aclose(self) -> None:
-        for _, service in self._services.items():
-            aclose = getattr(service, "aclose", None)
-            if aclose is not None:
-                await aclose()
-            else:
-                close = getattr(service, "close", None)
-                if close is not None:
-                    close()
-            close_proxy(service)
-
     async def register(
         self,
         model: Model,
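Note: the hunk above inlines the old _aclose helper into close(), keeping the same duck-typed shutdown pattern: prefer an async aclose hook when the service exposes one, otherwise fall back to a synchronous close. A minimal standalone sketch of that pattern (the service classes here are hypothetical, for illustration only):

import asyncio


class AsyncService:
    # Hypothetical service exposing an async shutdown hook.
    async def aclose(self) -> None:
        print("async close")


class SyncService:
    # Hypothetical service exposing only a sync shutdown hook.
    def close(self) -> None:
        print("sync close")


async def shutdown_all(services: list[object]) -> None:
    # Same precedence as LocalBackend.close(): try aclose first,
    # then fall back to close() if the service only supports sync.
    for service in services:
        aclose = getattr(service, "aclose", None)
        if aclose is None:
            close = getattr(service, "close", None)
            if close is not None:
                close()
        else:
            await aclose()


asyncio.run(shutdown_all([AsyncService(), SyncService()]))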
@@ -231,6 +229,11 @@ async def register(
         # (wandb initialization is now handled by the model's _get_wandb_run method)
         if model.trainable and "WANDB_API_KEY" in os.environ:
             _ = model._get_wandb_run()
+        if model.trainable:
+            trainable_model = cast(TrainableModel, model)
+            pricing = get_model_pricing(trainable_model.base_model)
+            if pricing is not None:
+                trainable_model.set_cost_calculator(build_cost_calculator(pricing))

     def _model_inference_name(self, model: Model, step: int | None = None) -> str:
         """Return the inference name for a model checkpoint.
@@ -244,25 +247,27 @@ def _model_inference_name(self, model: Model, step: int | None = None) -> str:
                 If None, returns name for latest checkpoint (step 0 initially).
         """

-        # For LocalBackend, vLLM always serves LoRA adapters with @step suffix
-        # Default to step 0 when not specified (the initial checkpoint created at registration)
-        if step is not None:
-            actual_step = step
-        elif model.name in self._services and self._in_process:
-            # In dedicated mode the service tracks which adapter vLLM has
-            # actually loaded. Reading the filesystem would race: the
-            # checkpoint directory appears before the HTTP reload completes.
-            svc = self._services[model.name]
-            loaded_step = getattr(svc, "_latest_step", None)
-            actual_step = (
-                loaded_step if loaded_step is not None else self.__get_step(model)
-            )
-        else:
-            actual_step = self.__get_step(model)
-        name = f"{model.name}@{actual_step}"
+        requested_step = step
+
+        if step is None and isinstance(model, TrainableModel):
+            from ..dev.validate import is_dedicated_mode
+
+            service = self._services.get(model.name)
+            if service is not None and is_dedicated_mode(
+                model._internal_config or dev.InternalModelConfig()
+            ):
+                loaded_step = getattr(service, "_latest_step", None)
+                if isinstance(loaded_step, int):
+                    step = loaded_step
+
+        if step is None:
+            # The checkpoint directory is written before dedicated-mode
+            # vLLM finishes reloading the new adapter.
+            step = self.__get_step(model)
+        name = f"{model.name}@{step}"
         logger.debug(
-            f"[BACKEND] _model_inference_name: step_arg={step} "
-            f"actual_step={actual_step} -> {name}"
+            f"[BACKEND] _model_inference_name: step_arg={requested_step} "
+            f"actual_step={step} -> {name}"
         )
         return name
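The rewritten resolution order in this hunk is: an explicit step argument wins; otherwise, for a trainable model in dedicated mode, the step that vLLM has actually loaded (service._latest_step) wins; only then does the filesystem checkpoint step apply. A condensed, runnable sketch of that precedence (names simplified; latest_loaded_step and filesystem_step stand in for the service attribute and self.__get_step):

def resolve_step(
    explicit_step: int | None,
    latest_loaded_step: int | None,  # service._latest_step in dedicated mode
    filesystem_step: int,            # self.__get_step(model)
) -> int:
    # Explicit argument takes priority over everything.
    if explicit_step is not None:
        return explicit_step
    # Dedicated mode: trust what vLLM reports as loaded, since the
    # checkpoint directory can exist before the adapter reload finishes.
    if latest_loaded_step is not None:
        return latest_loaded_step
    return filesystem_step


# e.g. resolve_step(None, 3, 4) == 3 yields inference name "my-model@3"
assert resolve_step(None, 3, 4) == 3
assert resolve_step(7, 3, 4) == 7
assert resolve_step(None, None, 4) == 4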

@@ -527,13 +532,14 @@ async def train( # type: ignore[override]
         *,
         # Core training parameters
         learning_rate: float = 5e-6,
-        loss_fn: Literal["cispo", "ppo"] | None = None,
+        loss_fn: Literal["cispo", "ppo"] = "cispo",
+        loss_fn_config: dict | None = None,
+        normalize_advantages: bool = True,
+        adam_params: object | None = None,
         # KL-penalized advantage adjustment
         kl_penalty_coef: float = 0.0,
         kl_penalty_reference_step: int | None = None,
         kl_ref_adapter_path: str | None = None,
-        # RL algorithm settings
-        ppo: bool = False,
         epsilon: float | None = None,
         epsilon_high: float | None = None,
         # Advantage computation
@@ -570,6 +576,14 @@ async def train( # type: ignore[override]
             model: The trainable model to train.
             trajectory_groups: Batches of trajectories to train on.
             learning_rate: Learning rate for training. Defaults to 5e-6.
+            loss_fn: RL loss function. LocalBackend currently supports
+                "cispo" and "ppo".
+            loss_fn_config: Additional loss-function config. Not supported by
+                LocalBackend.
+            normalize_advantages: Whether to normalize advantages. LocalBackend
+                currently requires True.
+            adam_params: Custom optimizer params. Not supported by
+                LocalBackend.
             kl_penalty_coef: Coefficient for KL-penalized advantage adjustment.
                 Tokens diverging more from the reference get reduced advantages.
                 Defaults to 0.0 (disabled).
@@ -579,8 +593,7 @@ async def train( # type: ignore[override]
             kl_ref_adapter_path: Direct filesystem path to a LoRA adapter
                 checkpoint to use as the KL reference. Alternative to
                 kl_penalty_reference_step.
-            ppo: Whether to use PPO clipping. Defaults to False.
-            epsilon: Clip epsilon for importance sampling. Defaults based on ppo.
+            epsilon: Clip epsilon for importance sampling. Defaults based on loss_fn.
             epsilon_high: Asymmetric upper clip bound. Defaults to epsilon.
             advantage_balance: Balance between negative and positive advantages
                 in range [-1.0, 1.0]. Defaults to 0.0 (balanced).
@@ -623,8 +636,14 @@ async def train( # type: ignore[override]
         # await model.log(metrics=result.metrics, step=result.step)
         """
         groups_list = list(trajectory_groups)
-        if loss_fn is not None:
-            ppo = loss_fn == "ppo"
+        if loss_fn not in {"cispo", "ppo"}:
+            raise ValueError("LocalBackend only supports loss_fn='cispo' or 'ppo'.")
+        if loss_fn_config is not None:
+            raise ValueError("LocalBackend requires loss_fn_config=None.")
+        if not normalize_advantages:
+            raise ValueError("LocalBackend requires normalize_advantages=True.")
+        if adam_params is not None:
+            raise ValueError("LocalBackend requires adam_params=None.")

         resolved_kl_ref_adapter_path = kl_ref_adapter_path
         if (
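From the caller's side, these guards mean the new keyword arguments fail fast on LocalBackend instead of being silently ignored. A self-contained sketch of the same validation logic, pulled out as a plain function so the behavior can be exercised directly (validate_train_args is illustrative, not part of the backend's API):

def validate_train_args(
    loss_fn: str,
    loss_fn_config: dict | None,
    normalize_advantages: bool,
    adam_params: object | None,
) -> None:
    # Mirrors the guards added in the hunk above: unsupported options
    # raise immediately rather than being dropped.
    if loss_fn not in {"cispo", "ppo"}:
        raise ValueError("LocalBackend only supports loss_fn='cispo' or 'ppo'.")
    if loss_fn_config is not None:
        raise ValueError("LocalBackend requires loss_fn_config=None.")
    if not normalize_advantages:
        raise ValueError("LocalBackend requires normalize_advantages=True.")
    if adam_params is not None:
        raise ValueError("LocalBackend requires adam_params=None.")


validate_train_args("ppo", None, True, None)  # passes silently
try:
    validate_train_args("cispo", None, False, None)
except ValueError as err:
    print(err)  # LocalBackend requires normalize_advantages=True.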
@@ -641,7 +660,7 @@ async def train( # type: ignore[override]
             scale_rewards=scale_rewards,
             importance_sampling_level=importance_sampling_level,
             mask_prob_ratio=mask_prob_ratio,
-            ppo=ppo,
+            ppo=loss_fn == "ppo",
             precalculate_logprobs=precalculate_logprobs,
             epsilon=epsilon,
             epsilon_high=epsilon_high,
