Skip to content

Commit 8ad8e50

Browse files
Kovbo, claude, and FurtherAI
authored
Share training code between backends (#626)
* Refresh shared training refactor on top of ART main * Rename Megatron merge helper * Deduplicate local and shared training logic * Fix Megatron rope theta compatibility * Remove Megatron rope theta workaround * Align Unsloth SFT weight decay defaults * remove apex from no-build-isolation-package * update install script * Fix Megatron job finalization ordering * Share Megatron worker loop * Default Megatron grad accumulation by DP size * Collapse Megatron shared API into train module * Remove Megatron shared shim * Collapse Unsloth shared API into train module * Lighten Megatron orchestration imports * fix: normalize SFT loss by token count before backward pass The loss was not being divided by global_trainable_tokens before calling backward(), causing gradients to scale with batch size and grad_norm to explode to infinity during Megatron SFT training. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * Revert "fix: normalize SFT loss by token count before backward pass" This reverts commit d08f2ad. * Support Megatron SFT in local backend * refactor: extract create_identity_lora as standalone function Extract the identity LoRA creation logic from MegatronService._create_identity_lora into a module-level create_identity_lora() function so it can be reused by the serverless training backend. The class method now delegates to this function. This avoids duplicating the MoE-aware identity LoRA creation logic (fused expert targets + convert_checkpoint_if_needed A/B swap) across repos. 
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * Fix SFT main_grad fallback in Megatron * Fix ART lint and type issues * Simplify ty-safe optimizer access * test: drop megatron sft batch unit test * refactor: revert direct safetensors import in moe conversion * style: format megatron oracle harness * refactor: use direct safetensors import in routing replay * fix: isolate megatron optimizer states and step counts * Add SFT oracle coverage and shared grad scheduling --------- Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Co-authored-by: FurtherAI <FurtherAI@users.noreply.github.com>
1 parent 75a81e9 commit 8ad8e50

File tree

25 files changed

+2169
-1104
lines changed

25 files changed

+2169
-1104
lines changed

scripts/setup.sh

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -53,18 +53,14 @@ else
5353
echo "Skipping git reset/clean (GIT_RESET_CLEAN is not true). Preserving synced working tree."
5454
fi
5555

56-
# Install astral-uv
57-
if ! command -v uv >/dev/null 2>&1; then
58-
if ! curl -LsSf https://astral.sh/uv/install.sh | sh; then
59-
echo "Failed to install uv." >&2
60-
exit 1
61-
fi
62-
export PATH="$HOME/.local/bin:$HOME/.cargo/bin:$PATH"
56+
# Install astral-uv (standalone version)
57+
# Always prepend standalone install path so it takes precedence over system/conda uv
58+
export PATH="$HOME/.local/bin:$HOME/.cargo/bin:$PATH"
59+
if ! curl -LsSf https://astral.sh/uv/install.sh | sh; then
60+
echo "Failed to install uv." >&2
61+
exit 1
6362
fi
6463

65-
# Update uv
66-
uv self update
67-
6864
# Sync the dependencies
6965
if [ "${INSTALL_EXTRAS:-false}" = "true" ]; then
7066
uv sync --all-extras

src/art/_backend_training.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
from collections.abc import Iterable
2+
import time
3+
from typing import Literal
4+
5+
from . import dev
6+
from .metrics_taxonomy import (
7+
average_metric_samples,
8+
build_training_summary_metrics,
9+
summarize_trajectory_groups,
10+
)
11+
from .trajectories import TrajectoryGroup
12+
from .types import TrainConfig
13+
14+
15+
def build_rl_train_configs(
    *,
    learning_rate: float,
    advantage_balance: float = 0.0,
    scale_rewards: bool = True,
    importance_sampling_level: Literal[
        "token", "sequence", "average", "geometric_average"
    ] = "token",
    mask_prob_ratio: bool = False,
    ppo: bool = False,
    precalculate_logprobs: bool = False,
    epsilon: float | None = None,
    epsilon_high: float | None = None,
    max_negative_advantage_importance_sampling_weight: float | None = None,
    kimi_k2_tau: float | None = None,
    kl_penalty_coef: float = 0.0,
    allow_training_without_logprobs: bool | None = None,
    plot_tensors: bool | None = None,
    truncated_importance_sampling: float | None = None,
    scale_learning_rate_by_reward_std_dev: bool | None = None,
    logprob_calculation_chunk_size: int | None = None,
    num_trajectories_learning_rate_multiplier_power: float | None = None,
    kl_ref_adapter_path: str | None = None,
) -> tuple[TrainConfig, dev.TrainConfig]:
    """Build the (public, dev) train-config pair shared by RL backends.

    The public ``TrainConfig`` always carries ``learning_rate`` and
    ``kl_penalty_coef``. The dev config always carries the required keys;
    every other key is included only when the caller supplied a non-None
    value, so downstream code can tell "unset" apart from an explicit value.
    """
    config = TrainConfig(
        learning_rate=learning_rate,
        kl_penalty_coef=kl_penalty_coef,
    )
    dev_config: dev.TrainConfig = {
        "advantage_balance": advantage_balance,
        "importance_sampling_level": importance_sampling_level,
        "kl_penalty_coef": kl_penalty_coef,
        "mask_prob_ratio": mask_prob_ratio,
        "ppo": ppo,
        "precalculate_logprobs": precalculate_logprobs,
        "scale_rewards": scale_rewards,
    }

    # Optional keys, inserted only when explicitly provided. The tuple keeps
    # the same insertion order as the original per-field checks.
    optional_entries = (
        ("allow_training_without_logprobs", allow_training_without_logprobs),
        ("plot_tensors", plot_tensors),
        ("truncated_importance_sampling", truncated_importance_sampling),
        (
            "scale_learning_rate_by_reward_std_dev",
            scale_learning_rate_by_reward_std_dev,
        ),
        ("logprob_calculation_chunk_size", logprob_calculation_chunk_size),
        (
            "num_trajectories_learning_rate_multiplier_power",
            num_trajectories_learning_rate_multiplier_power,
        ),
        ("epsilon", epsilon),
        ("epsilon_high", epsilon_high),
        (
            "max_negative_advantage_importance_sampling_weight",
            max_negative_advantage_importance_sampling_weight,
        ),
        ("kimi_k2_tau", kimi_k2_tau),
        ("kl_ref_adapter_path", kl_ref_adapter_path),
    )
    for key, value in optional_entries:
        if value is not None:
            dev_config[key] = value  # type: ignore[literal-required]

    return config, dev_config
83+
84+
85+
def aggregate_rl_training_metrics(
    *,
    training_metrics: list[dict[str, float]],
    trajectory_groups: Iterable[TrajectoryGroup],
    trainer_started: float,
) -> dict[str, float]:
    """Merge averaged per-step trainer metrics with group summary metrics.

    Trainer-reported metrics take precedence: summary metrics are added only
    for keys the trainer did not report, and the trainer wall-clock time is
    filled in from ``trainer_started`` only when not already present.
    """
    groups = list(trajectory_groups)
    metrics = average_metric_samples(training_metrics)
    metrics.setdefault("time/step_trainer_s", time.monotonic() - trainer_started)
    summary_metrics = build_training_summary_metrics(
        summarize_trajectory_groups(groups),
        include_trainable_groups=True,
    )
    for key, value in summary_metrics.items():
        if key not in metrics:
            metrics[key] = value
    return metrics

src/art/local/backend.py

Lines changed: 66 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,14 @@
4343
from mp_actors import close_proxy, move_to_child_process
4444

4545
from .. import dev
46+
from .._backend_training import (
47+
aggregate_rl_training_metrics,
48+
build_rl_train_configs,
49+
)
4650
from ..backend import AnyTrainableModel, Backend
4751
from ..costs import build_cost_calculator, get_model_pricing
4852
from ..metrics_taxonomy import (
4953
TRAIN_GRADIENT_STEPS_KEY,
50-
average_metric_samples,
5154
build_training_summary_metrics,
5255
summarize_trajectory_groups,
5356
)
@@ -642,45 +645,36 @@ async def train( # type: ignore[override]
642645
if adam_params is not None:
643646
raise ValueError("LocalBackend requires adam_params=None.")
644647

645-
# Build config objects from explicit kwargs
646-
config = TrainConfig(
647-
learning_rate=learning_rate, kl_penalty_coef=kl_penalty_coef
648-
)
649-
dev_config: dev.TrainConfig = {
650-
"advantage_balance": advantage_balance,
651-
"allow_training_without_logprobs": allow_training_without_logprobs,
652-
"importance_sampling_level": importance_sampling_level,
653-
"kl_penalty_coef": kl_penalty_coef,
654-
"mask_prob_ratio": mask_prob_ratio,
655-
"plot_tensors": plot_tensors,
656-
"ppo": loss_fn == "ppo",
657-
"precalculate_logprobs": precalculate_logprobs,
658-
"scale_learning_rate_by_reward_std_dev": scale_learning_rate_by_reward_std_dev,
659-
"scale_rewards": scale_rewards,
660-
"logprob_calculation_chunk_size": logprob_calculation_chunk_size,
661-
"num_trajectories_learning_rate_multiplier_power": num_trajectories_learning_rate_multiplier_power,
662-
}
663-
# Only include optional fields if they're set
664-
if epsilon is not None:
665-
dev_config["epsilon"] = epsilon
666-
if epsilon_high is not None:
667-
dev_config["epsilon_high"] = epsilon_high
668-
if max_negative_advantage_importance_sampling_weight is not None:
669-
dev_config["max_negative_advantage_importance_sampling_weight"] = (
670-
max_negative_advantage_importance_sampling_weight
671-
)
672-
if kimi_k2_tau is not None:
673-
dev_config["kimi_k2_tau"] = kimi_k2_tau
674-
if truncated_importance_sampling is not None:
675-
dev_config["truncated_importance_sampling"] = truncated_importance_sampling
676-
if kl_ref_adapter_path is not None:
677-
dev_config["kl_ref_adapter_path"] = kl_ref_adapter_path
678-
elif kl_penalty_reference_step is not None:
679-
ref_checkpoint_dir = get_step_checkpoint_dir(
648+
resolved_kl_ref_adapter_path = kl_ref_adapter_path
649+
if (
650+
resolved_kl_ref_adapter_path is None
651+
and kl_penalty_reference_step is not None
652+
):
653+
resolved_kl_ref_adapter_path = get_step_checkpoint_dir(
680654
get_model_dir(model=model, art_path=self._path),
681655
kl_penalty_reference_step,
682656
)
683-
dev_config["kl_ref_adapter_path"] = ref_checkpoint_dir
657+
config, dev_config = build_rl_train_configs(
658+
learning_rate=learning_rate,
659+
advantage_balance=advantage_balance,
660+
scale_rewards=scale_rewards,
661+
importance_sampling_level=importance_sampling_level,
662+
mask_prob_ratio=mask_prob_ratio,
663+
ppo=loss_fn == "ppo",
664+
precalculate_logprobs=precalculate_logprobs,
665+
epsilon=epsilon,
666+
epsilon_high=epsilon_high,
667+
max_negative_advantage_importance_sampling_weight=max_negative_advantage_importance_sampling_weight,
668+
kimi_k2_tau=kimi_k2_tau,
669+
kl_penalty_coef=kl_penalty_coef,
670+
allow_training_without_logprobs=allow_training_without_logprobs,
671+
plot_tensors=plot_tensors,
672+
truncated_importance_sampling=truncated_importance_sampling,
673+
scale_learning_rate_by_reward_std_dev=scale_learning_rate_by_reward_std_dev,
674+
logprob_calculation_chunk_size=logprob_calculation_chunk_size,
675+
num_trajectories_learning_rate_multiplier_power=num_trajectories_learning_rate_multiplier_power,
676+
kl_ref_adapter_path=resolved_kl_ref_adapter_path,
677+
)
684678

685679
# Collect metrics from training
686680
training_metrics: list[dict[str, float]] = []
@@ -690,21 +684,10 @@ async def train( # type: ignore[override]
690684
):
691685
training_metrics.append(metrics)
692686

693-
# Aggregate metrics
694-
avg_metrics = average_metric_samples(training_metrics)
695-
summary = summarize_trajectory_groups(groups_list)
696-
avg_metrics.setdefault(
697-
"time/step_trainer_s", time.monotonic() - trainer_started
698-
)
699-
avg_metrics.update(
700-
{
701-
key: value
702-
for key, value in build_training_summary_metrics(
703-
summary,
704-
include_trainable_groups=True,
705-
).items()
706-
if key not in avg_metrics
707-
}
687+
avg_metrics = aggregate_rl_training_metrics(
688+
training_metrics=training_metrics,
689+
trajectory_groups=groups_list,
690+
trainer_started=trainer_started,
708691
)
709692

710693
# Get step and checkpoint path
@@ -822,20 +805,31 @@ async def _train_model(
822805
packed_tensors, f"{get_model_dir(model=model, art_path=self._path)}/tensors"
823806
)
824807
# Note: scale_learning_rate_by_reward_std_dev is now handled by the frontend (Model.train())
825-
grad_accumulation_sequences = max(1, int(config.grad_accumulation_sequences))
826-
estimated_gradient_steps = math.ceil(
808+
grad_accumulation_sequences = max(
809+
1, int(config.grad_accumulation_sequences or 1)
810+
)
811+
fallback_gradient_steps = math.ceil(
827812
disk_packed_tensors["num_sequences"] / grad_accumulation_sequences
828813
)
829-
pbar = tqdm.tqdm(total=estimated_gradient_steps, desc="train")
814+
pbar = tqdm.tqdm(total=fallback_gradient_steps, desc="train")
815+
reported_gradient_steps: int | None = None
830816
async for result in service.train(
831817
disk_packed_tensors, config, dev_config, verbose
832818
):
833-
num_gradient_steps = int(
834-
result.pop(TRAIN_GRADIENT_STEPS_KEY, estimated_gradient_steps)
835-
)
836-
assert num_gradient_steps == estimated_gradient_steps, (
837-
f"num_gradient_steps {num_gradient_steps} != estimated_gradient_steps {estimated_gradient_steps}"
838-
)
819+
raw_num_gradient_steps = result.pop(TRAIN_GRADIENT_STEPS_KEY, None)
820+
if raw_num_gradient_steps is not None:
821+
num_gradient_steps = int(raw_num_gradient_steps)
822+
if reported_gradient_steps is None:
823+
reported_gradient_steps = num_gradient_steps
824+
if pbar.total != num_gradient_steps:
825+
pbar.total = num_gradient_steps
826+
pbar.refresh()
827+
else:
828+
assert num_gradient_steps == reported_gradient_steps, (
829+
f"num_gradient_steps {num_gradient_steps} != reported_gradient_steps {reported_gradient_steps}"
830+
)
831+
else:
832+
num_gradient_steps = reported_gradient_steps or fallback_gradient_steps
839833
yield {
840834
**base_metrics,
841835
**result,
@@ -882,10 +876,13 @@ async def _train_sft(
882876
)
883877
tokenizer = self._tokenizers[model.base_model]
884878

885-
# Determine batch_size
886-
batch_size = config.batch_size
887-
if batch_size == "auto":
888-
batch_size = 2 # Default to 2 for SFT
879+
from ..utils.sft import resolve_sft_batch_size
880+
881+
batch_size = resolve_sft_batch_size(
882+
batch_size=config.batch_size,
883+
default_batch_size=self._default_sft_batch_size(),
884+
)
885+
service_config = config.model_copy(update={"batch_size": batch_size})
889886

890887
# Auto-detect instruction/response parts from model
891888
from ..utils.model_config import get_instruction_response_parts
@@ -931,7 +928,7 @@ async def _train_sft(
931928
total_trajectories = len(trajectory_list)
932929
batch_count = 0
933930

934-
async for result in service.train_sft(batches, verbose):
931+
async for result in service.train_sft(batches, service_config, verbose):
935932
pbar.update(1)
936933
pbar.set_postfix({"loss": f"{result.get('loss/train', 0):.4f}"})
937934
batch_count += 1
@@ -953,6 +950,9 @@ async def _train_sft(
953950
if verbose:
954951
print("_train_sft complete")
955952

953+
def _default_sft_batch_size(self) -> int:
    # Fallback SFT batch size passed to resolve_sft_batch_size when the
    # configured batch_size is left for the backend to choose.
    # NOTE(review): subclasses appear to override this (e.g. the Megatron
    # backend derives it from GPU count) — confirm before changing.
    return 2
955+
956956
# ------------------------------------------------------------------
957957
# Experimental support for S3
958958
# ------------------------------------------------------------------

src/art/local/service.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,14 @@ def train(
3333
def train_sft(
3434
self,
3535
batches: list[SFTBatch],
36+
config: types.TrainSFTConfig,
3637
verbose: bool = False,
3738
) -> AsyncIterator[dict[str, float]]:
3839
"""Train using SFT on pre-computed batches.
3940
4041
Args:
4142
batches: List of SFTBatch objects to train on.
43+
config: SFT batch/grad-accumulation configuration.
4244
verbose: Whether to print detailed logs.
4345
4446
Yields:

src/art/loss.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from . import dev
99

1010
if TYPE_CHECKING:
11-
from art.unsloth.service import TrainInputs
11+
from art.preprocessing.inputs import TrainInputs
1212

1313

1414
class Loss(BaseModel):

src/art/megatron/backend.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,10 @@ async def _get_service(self, model: TrainableModel) -> ModelService:
3737
process_name="megatron-service",
3838
)
3939
return self._services[model.name]
40+
41+
def _default_sft_batch_size(self) -> int:
42+
import torch
43+
44+
num_gpus = max(int(torch.cuda.device_count()), 1)
45+
tensor_parallel_size = min(2, num_gpus)
46+
return max(num_gpus // tensor_parallel_size, 1)

0 commit comments

Comments
 (0)