Skip to content

Commit 1e0cc4c

Browse files
committed
Refactor shared Megatron and Unsloth training code
1 parent 621e82b commit 1e0cc4c

File tree

10 files changed

+1332
-980
lines changed

10 files changed

+1332
-980
lines changed

src/art/_backend_training.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
from collections.abc import Iterable
2+
import time
3+
from typing import Literal
4+
5+
from . import dev
6+
from .metrics_taxonomy import (
7+
average_metric_samples,
8+
build_training_summary_metrics,
9+
summarize_trajectory_groups,
10+
)
11+
from .trajectories import TrajectoryGroup
12+
from .types import TrainConfig
13+
14+
15+
def build_rl_train_configs(
    *,
    learning_rate: float,
    advantage_balance: float = 0.0,
    scale_rewards: bool = True,
    importance_sampling_level: Literal[
        "token", "sequence", "average", "geometric_average"
    ] = "token",
    mask_prob_ratio: bool = False,
    ppo: bool = False,
    precalculate_logprobs: bool = False,
    epsilon: float | None = None,
    epsilon_high: float | None = None,
    max_negative_advantage_importance_sampling_weight: float | None = None,
    kimi_k2_tau: float | None = None,
    kl_penalty_coef: float = 0.0,
    allow_training_without_logprobs: bool | None = None,
    plot_tensors: bool | None = None,
    truncated_importance_sampling: float | None = None,
    scale_learning_rate_by_reward_std_dev: bool | None = None,
    logprob_calculation_chunk_size: int | None = None,
    num_trajectories_learning_rate_multiplier_power: float | None = None,
    kl_ref_adapter_path: str | None = None,
) -> tuple[TrainConfig, dev.TrainConfig]:
    """Build the ``(TrainConfig, dev.TrainConfig)`` pair for an RL train call.

    Shared by backends so each one maps the same keyword arguments to the
    same config objects.

    Args:
        learning_rate: Optimizer learning rate.
        kl_penalty_coef: KL-penalty coefficient; written into both configs.
        Remaining arguments mirror ``dev.TrainConfig`` fields. Arguments
        typed ``... | None`` default to ``None`` and are only written into
        the dev config when explicitly provided, preserving the "unset"
        semantics of absent keys.

    Returns:
        A tuple of the core ``TrainConfig`` and the ``dev.TrainConfig`` dict.
    """
    train_config = TrainConfig(
        learning_rate=learning_rate,
        kl_penalty_coef=kl_penalty_coef,
    )
    # Fields that always have a meaningful value go in unconditionally.
    dev_config: dev.TrainConfig = {
        "advantage_balance": advantage_balance,
        "importance_sampling_level": importance_sampling_level,
        "kl_penalty_coef": kl_penalty_coef,
        "mask_prob_ratio": mask_prob_ratio,
        "ppo": ppo,
        "precalculate_logprobs": precalculate_logprobs,
        "scale_rewards": scale_rewards,
    }
    # Optional fields: include a key only when the caller supplied a value,
    # so downstream code can distinguish "unset" from an explicit setting.
    optional_fields = {
        "allow_training_without_logprobs": allow_training_without_logprobs,
        "plot_tensors": plot_tensors,
        "truncated_importance_sampling": truncated_importance_sampling,
        "scale_learning_rate_by_reward_std_dev": scale_learning_rate_by_reward_std_dev,
        "logprob_calculation_chunk_size": logprob_calculation_chunk_size,
        "num_trajectories_learning_rate_multiplier_power": num_trajectories_learning_rate_multiplier_power,
        "epsilon": epsilon,
        "epsilon_high": epsilon_high,
        "max_negative_advantage_importance_sampling_weight": max_negative_advantage_importance_sampling_weight,
        "kimi_k2_tau": kimi_k2_tau,
        "kl_ref_adapter_path": kl_ref_adapter_path,
    }
    for field_name, field_value in optional_fields.items():
        if field_value is not None:
            dev_config[field_name] = field_value  # type: ignore[literal-required]
    return train_config, dev_config
87+
88+
89+
def aggregate_rl_training_metrics(
    *,
    training_metrics: list[dict[str, float]],
    trajectory_groups: Iterable[TrajectoryGroup],
    trainer_started: float,
) -> dict[str, float]:
    """Aggregate per-step training metrics with trajectory-group summaries.

    Args:
        training_metrics: Metric dicts collected from each training step.
        trajectory_groups: The trajectory groups the step trained on.
        trainer_started: ``time.monotonic()`` timestamp taken when the
            trainer began, used to derive elapsed trainer time.

    Returns:
        The averaged step metrics, augmented with summary metrics for any
        keys the trainer did not already report.
    """
    groups = list(trajectory_groups)
    metrics = average_metric_samples(training_metrics)
    # Only fill in elapsed time if the trainer didn't report its own value.
    metrics.setdefault("time/step_trainer_s", time.monotonic() - trainer_started)
    summary_metrics = build_training_summary_metrics(
        summarize_trajectory_groups(groups),
        include_trainable_groups=True,
    )
    # Averaged per-step metrics take precedence over summary-derived ones.
    for key, value in summary_metrics.items():
        metrics.setdefault(key, value)
    return metrics

src/art/local/backend.py

100644100755
Lines changed: 60 additions & 121 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,13 @@
4343
from mp_actors import close_proxy, move_to_child_process
4444

4545
from .. import dev
46+
from .._backend_training import (
47+
aggregate_rl_training_metrics,
48+
build_rl_train_configs,
49+
)
4650
from ..backend import AnyTrainableModel, Backend
47-
from ..costs import build_cost_calculator, get_model_pricing
4851
from ..metrics_taxonomy import (
4952
TRAIN_GRADIENT_STEPS_KEY,
50-
average_metric_samples,
5153
build_training_summary_metrics,
5254
summarize_trajectory_groups,
5355
)
@@ -160,9 +162,6 @@ def _allocated_gpu_count(self, model: Model) -> int:
160162
def __enter__(self) -> Self:
161163
return self
162164

163-
async def __aenter__(self) -> Self:
164-
return self
165-
166165
def __exit__(
167166
self,
168167
exc_type: type[BaseException] | None,
@@ -171,30 +170,14 @@ def __exit__(
171170
) -> None:
172171
self._close()
173172

174-
async def __aexit__(
175-
self,
176-
exc_type: type[BaseException] | None,
177-
exc: BaseException | None,
178-
tb: TracebackType | None,
179-
) -> None:
180-
await self.close()
181-
182173
async def close(self) -> None:
183174
"""
184175
If running vLLM in a separate process, this will kill that process and close the communication threads.
185176
"""
186-
for service in self._services.values():
187-
aclose = getattr(service, "aclose", None)
188-
if aclose is None:
189-
close = getattr(service, "close", None)
190-
if close is not None:
191-
close()
192-
else:
193-
await aclose()
194-
close_proxy(service)
177+
self._close()
195178

196179
def _close(self) -> None:
197-
for service in self._services.values():
180+
for _, service in self._services.items():
198181
close = getattr(service, "close", None)
199182
if close is not None:
200183
close()
@@ -226,11 +209,6 @@ async def register(
226209
# (wandb initialization is now handled by the model's _get_wandb_run method)
227210
if model.trainable and "WANDB_API_KEY" in os.environ:
228211
_ = model._get_wandb_run()
229-
if model.trainable:
230-
trainable_model = cast(TrainableModel, model)
231-
pricing = get_model_pricing(trainable_model.base_model)
232-
if pricing is not None:
233-
trainable_model.set_cost_calculator(build_cost_calculator(pricing))
234212

235213
def _model_inference_name(self, model: Model, step: int | None = None) -> str:
236214
"""Return the inference name for a model checkpoint.
@@ -244,27 +222,25 @@ def _model_inference_name(self, model: Model, step: int | None = None) -> str:
244222
If None, returns name for latest checkpoint (step 0 initially).
245223
"""
246224

247-
requested_step = step
248-
249-
if step is None and isinstance(model, TrainableModel):
250-
from ..dev.validate import is_dedicated_mode
251-
252-
service = self._services.get(model.name)
253-
if service is not None and is_dedicated_mode(
254-
model._internal_config or dev.InternalModelConfig()
255-
):
256-
loaded_step = getattr(service, "_latest_step", None)
257-
if isinstance(loaded_step, int):
258-
step = loaded_step
259-
260-
if step is None:
261-
# The checkpoint directory is written before dedicated-mode
262-
# vLLM finishes reloading the new adapter.
263-
step = self.__get_step(model)
264-
name = f"{model.name}@{step}"
225+
# For LocalBackend, vLLM always serves LoRA adapters with @step suffix
226+
# Default to step 0 when not specified (the initial checkpoint created at registration)
227+
if step is not None:
228+
actual_step = step
229+
elif model.name in self._services and self._in_process:
230+
# In dedicated mode the service tracks which adapter vLLM has
231+
# actually loaded. Reading the filesystem would race: the
232+
# checkpoint directory appears before the HTTP reload completes.
233+
svc = self._services[model.name]
234+
loaded_step = getattr(svc, "_latest_step", None)
235+
actual_step = (
236+
loaded_step if loaded_step is not None else self.__get_step(model)
237+
)
238+
else:
239+
actual_step = self.__get_step(model)
240+
name = f"{model.name}@{actual_step}"
265241
logger.debug(
266-
f"[BACKEND] _model_inference_name: step_arg={requested_step} "
267-
f"actual_step={step} -> {name}"
242+
f"[BACKEND] _model_inference_name: step_arg={step} "
243+
f"actual_step={actual_step} -> {name}"
268244
)
269245
return name
270246

@@ -529,14 +505,12 @@ async def train( # type: ignore[override]
529505
*,
530506
# Core training parameters
531507
learning_rate: float = 5e-6,
532-
loss_fn: Literal["cispo", "ppo"] = "cispo",
533-
loss_fn_config: dict | None = None,
534-
normalize_advantages: bool = True,
535-
adam_params: object | None = None,
536508
# KL-penalized advantage adjustment
537509
kl_penalty_coef: float = 0.0,
538510
kl_penalty_reference_step: int | None = None,
539511
kl_ref_adapter_path: str | None = None,
512+
# RL algorithm settings
513+
ppo: bool = False,
540514
epsilon: float | None = None,
541515
epsilon_high: float | None = None,
542516
# Advantage computation
@@ -573,14 +547,6 @@ async def train( # type: ignore[override]
573547
model: The trainable model to train.
574548
trajectory_groups: Batches of trajectories to train on.
575549
learning_rate: Learning rate for training. Defaults to 5e-6.
576-
loss_fn: RL loss function. LocalBackend currently supports
577-
"cispo" and "ppo".
578-
loss_fn_config: Additional loss-function config. Not supported by
579-
LocalBackend.
580-
normalize_advantages: Whether to normalize advantages. LocalBackend
581-
currently requires True.
582-
adam_params: Custom optimizer params. Not supported by
583-
LocalBackend.
584550
kl_penalty_coef: Coefficient for KL-penalized advantage adjustment.
585551
Tokens diverging more from the reference get reduced advantages.
586552
Defaults to 0.0 (disabled).
@@ -590,7 +556,8 @@ async def train( # type: ignore[override]
590556
kl_ref_adapter_path: Direct filesystem path to a LoRA adapter
591557
checkpoint to use as the KL reference. Alternative to
592558
kl_penalty_reference_step.
593-
epsilon: Clip epsilon for importance sampling. Defaults based on loss_fn.
559+
ppo: Whether to use PPO clipping. Defaults to False.
560+
epsilon: Clip epsilon for importance sampling. Defaults based on ppo.
594561
epsilon_high: Asymmetric upper clip bound. Defaults to epsilon.
595562
advantage_balance: Balance between negative and positive advantages
596563
in range [-1.0, 1.0]. Defaults to 0.0 (balanced).
@@ -633,54 +600,37 @@ async def train( # type: ignore[override]
633600
# await model.log(metrics=result.metrics, step=result.step)
634601
"""
635602
groups_list = list(trajectory_groups)
636-
if loss_fn not in {"cispo", "ppo"}:
637-
raise ValueError("LocalBackend only supports loss_fn='cispo' or 'ppo'.")
638-
if loss_fn_config is not None:
639-
raise ValueError("LocalBackend requires loss_fn_config=None.")
640-
if not normalize_advantages:
641-
raise ValueError("LocalBackend requires normalize_advantages=True.")
642-
if adam_params is not None:
643-
raise ValueError("LocalBackend requires adam_params=None.")
644-
645-
# Build config objects from explicit kwargs
646-
config = TrainConfig(
647-
learning_rate=learning_rate, kl_penalty_coef=kl_penalty_coef
648-
)
649-
dev_config: dev.TrainConfig = {
650-
"advantage_balance": advantage_balance,
651-
"allow_training_without_logprobs": allow_training_without_logprobs,
652-
"importance_sampling_level": importance_sampling_level,
653-
"kl_penalty_coef": kl_penalty_coef,
654-
"mask_prob_ratio": mask_prob_ratio,
655-
"plot_tensors": plot_tensors,
656-
"ppo": loss_fn == "ppo",
657-
"precalculate_logprobs": precalculate_logprobs,
658-
"scale_learning_rate_by_reward_std_dev": scale_learning_rate_by_reward_std_dev,
659-
"scale_rewards": scale_rewards,
660-
"logprob_calculation_chunk_size": logprob_calculation_chunk_size,
661-
"num_trajectories_learning_rate_multiplier_power": num_trajectories_learning_rate_multiplier_power,
662-
}
663-
# Only include optional fields if they're set
664-
if epsilon is not None:
665-
dev_config["epsilon"] = epsilon
666-
if epsilon_high is not None:
667-
dev_config["epsilon_high"] = epsilon_high
668-
if max_negative_advantage_importance_sampling_weight is not None:
669-
dev_config["max_negative_advantage_importance_sampling_weight"] = (
670-
max_negative_advantage_importance_sampling_weight
671-
)
672-
if kimi_k2_tau is not None:
673-
dev_config["kimi_k2_tau"] = kimi_k2_tau
674-
if truncated_importance_sampling is not None:
675-
dev_config["truncated_importance_sampling"] = truncated_importance_sampling
676-
if kl_ref_adapter_path is not None:
677-
dev_config["kl_ref_adapter_path"] = kl_ref_adapter_path
678-
elif kl_penalty_reference_step is not None:
679-
ref_checkpoint_dir = get_step_checkpoint_dir(
603+
604+
resolved_kl_ref_adapter_path = kl_ref_adapter_path
605+
if (
606+
resolved_kl_ref_adapter_path is None
607+
and kl_penalty_reference_step is not None
608+
):
609+
resolved_kl_ref_adapter_path = get_step_checkpoint_dir(
680610
get_model_dir(model=model, art_path=self._path),
681611
kl_penalty_reference_step,
682612
)
683-
dev_config["kl_ref_adapter_path"] = ref_checkpoint_dir
613+
config, dev_config = build_rl_train_configs(
614+
learning_rate=learning_rate,
615+
advantage_balance=advantage_balance,
616+
scale_rewards=scale_rewards,
617+
importance_sampling_level=importance_sampling_level,
618+
mask_prob_ratio=mask_prob_ratio,
619+
ppo=ppo,
620+
precalculate_logprobs=precalculate_logprobs,
621+
epsilon=epsilon,
622+
epsilon_high=epsilon_high,
623+
max_negative_advantage_importance_sampling_weight=max_negative_advantage_importance_sampling_weight,
624+
kimi_k2_tau=kimi_k2_tau,
625+
kl_penalty_coef=kl_penalty_coef,
626+
allow_training_without_logprobs=allow_training_without_logprobs,
627+
plot_tensors=plot_tensors,
628+
truncated_importance_sampling=truncated_importance_sampling,
629+
scale_learning_rate_by_reward_std_dev=scale_learning_rate_by_reward_std_dev,
630+
logprob_calculation_chunk_size=logprob_calculation_chunk_size,
631+
num_trajectories_learning_rate_multiplier_power=num_trajectories_learning_rate_multiplier_power,
632+
kl_ref_adapter_path=resolved_kl_ref_adapter_path,
633+
)
684634

685635
# Collect metrics from training
686636
training_metrics: list[dict[str, float]] = []
@@ -690,21 +640,10 @@ async def train( # type: ignore[override]
690640
):
691641
training_metrics.append(metrics)
692642

693-
# Aggregate metrics
694-
avg_metrics = average_metric_samples(training_metrics)
695-
summary = summarize_trajectory_groups(groups_list)
696-
avg_metrics.setdefault(
697-
"time/step_trainer_s", time.monotonic() - trainer_started
698-
)
699-
avg_metrics.update(
700-
{
701-
key: value
702-
for key, value in build_training_summary_metrics(
703-
summary,
704-
include_trainable_groups=True,
705-
).items()
706-
if key not in avg_metrics
707-
}
643+
avg_metrics = aggregate_rl_training_metrics(
644+
training_metrics=training_metrics,
645+
trajectory_groups=groups_list,
646+
trainer_started=trainer_started,
708647
)
709648

710649
# Get step and checkpoint path

0 commit comments

Comments
 (0)