Skip to content

Commit c1abe2e

Browse files
committed
Refactor shared training code for Megatron backends
1 parent 7ee516e commit c1abe2e

File tree

8 files changed

+820
-413
lines changed

8 files changed

+820
-413
lines changed

src/art/_backend_training.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
from collections.abc import Iterable
2+
import time
3+
from typing import Literal
4+
5+
from . import dev
6+
from .metrics_taxonomy import (
7+
average_metric_samples,
8+
build_training_summary_metrics,
9+
summarize_trajectory_groups,
10+
)
11+
from .trajectories import TrajectoryGroup
12+
from .types import TrainConfig
13+
14+
15+
def build_rl_train_configs(
    *,
    learning_rate: float,
    advantage_balance: float = 0.0,
    scale_rewards: bool = True,
    importance_sampling_level: Literal[
        "token", "sequence", "average", "geometric_average"
    ] = "token",
    mask_prob_ratio: bool = False,
    ppo: bool = False,
    precalculate_logprobs: bool = False,
    epsilon: float | None = None,
    epsilon_high: float | None = None,
    max_negative_advantage_importance_sampling_weight: float | None = None,
    kimi_k2_tau: float | None = None,
    kl_penalty_coef: float = 0.0,
    allow_training_without_logprobs: bool | None = None,
    plot_tensors: bool | None = None,
    truncated_importance_sampling: float | None = None,
    scale_learning_rate_by_reward_std_dev: bool | None = None,
    logprob_calculation_chunk_size: int | None = None,
    num_trajectories_learning_rate_multiplier_power: float | None = None,
    kl_ref_adapter_path: str | None = None,
) -> tuple[TrainConfig, dev.TrainConfig]:
    """Build the (public, experimental) training-config pair from flat kwargs.

    The required settings always appear in the returned ``dev.TrainConfig``;
    each optional setting is forwarded only when the caller passed a
    non-``None`` value, so downstream defaults remain in effect otherwise.
    """
    train_config = TrainConfig(
        learning_rate=learning_rate,
        kl_penalty_coef=kl_penalty_coef,
    )
    dev_config: dev.TrainConfig = {
        "advantage_balance": advantage_balance,
        "importance_sampling_level": importance_sampling_level,
        "kl_penalty_coef": kl_penalty_coef,
        "mask_prob_ratio": mask_prob_ratio,
        "ppo": ppo,
        "precalculate_logprobs": precalculate_logprobs,
        "scale_rewards": scale_rewards,
    }

    # Optional knobs: include only the ones explicitly provided, in the same
    # order the original per-field checks used.
    optional_settings = {
        "allow_training_without_logprobs": allow_training_without_logprobs,
        "plot_tensors": plot_tensors,
        "truncated_importance_sampling": truncated_importance_sampling,
        "scale_learning_rate_by_reward_std_dev": scale_learning_rate_by_reward_std_dev,
        "logprob_calculation_chunk_size": logprob_calculation_chunk_size,
        "num_trajectories_learning_rate_multiplier_power": num_trajectories_learning_rate_multiplier_power,
        "epsilon": epsilon,
        "epsilon_high": epsilon_high,
        "max_negative_advantage_importance_sampling_weight": max_negative_advantage_importance_sampling_weight,
        "kimi_k2_tau": kimi_k2_tau,
        "kl_ref_adapter_path": kl_ref_adapter_path,
    }
    for key, value in optional_settings.items():
        if value is not None:
            dev_config[key] = value  # type: ignore[literal-required]

    return train_config, dev_config
87+
88+
89+
def aggregate_rl_training_metrics(
    *,
    training_metrics: list[dict[str, float]],
    trajectory_groups: Iterable[TrajectoryGroup],
    trainer_started: float,
) -> dict[str, float]:
    """Average per-step trainer metrics and fold in trajectory-group summary
    metrics, without overwriting any key the trainer already reported.

    ``trainer_started`` is a ``time.monotonic()`` timestamp used to fill in
    ``time/step_trainer_s`` when the trainer did not report it itself.
    """
    groups = list(trajectory_groups)
    metrics = average_metric_samples(training_metrics)
    if "time/step_trainer_s" not in metrics:
        metrics["time/step_trainer_s"] = time.monotonic() - trainer_started
    summary_metrics = build_training_summary_metrics(
        summarize_trajectory_groups(groups),
        include_trainable_groups=True,
    )
    # Trainer-reported values win over summary-derived ones.
    for key, value in summary_metrics.items():
        metrics.setdefault(key, value)
    return metrics

src/art/local/backend.py

Lines changed: 35 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,13 @@
4343
from mp_actors import close_proxy, move_to_child_process
4444

4545
from .. import dev
46+
from .._backend_training import (
47+
aggregate_rl_training_metrics,
48+
build_rl_train_configs,
49+
)
4650
from ..backend import AnyTrainableModel, Backend
4751
from ..metrics_taxonomy import (
4852
TRAIN_GRADIENT_STEPS_KEY,
49-
average_metric_samples,
5053
build_training_summary_metrics,
5154
summarize_trajectory_groups,
5255
)
@@ -598,45 +601,36 @@ async def train( # type: ignore[override]
598601
"""
599602
groups_list = list(trajectory_groups)
600603

601-
# Build config objects from explicit kwargs
602-
config = TrainConfig(
603-
learning_rate=learning_rate, kl_penalty_coef=kl_penalty_coef
604-
)
605-
dev_config: dev.TrainConfig = {
606-
"advantage_balance": advantage_balance,
607-
"allow_training_without_logprobs": allow_training_without_logprobs,
608-
"importance_sampling_level": importance_sampling_level,
609-
"kl_penalty_coef": kl_penalty_coef,
610-
"mask_prob_ratio": mask_prob_ratio,
611-
"plot_tensors": plot_tensors,
612-
"ppo": ppo,
613-
"precalculate_logprobs": precalculate_logprobs,
614-
"scale_learning_rate_by_reward_std_dev": scale_learning_rate_by_reward_std_dev,
615-
"scale_rewards": scale_rewards,
616-
"logprob_calculation_chunk_size": logprob_calculation_chunk_size,
617-
"num_trajectories_learning_rate_multiplier_power": num_trajectories_learning_rate_multiplier_power,
618-
}
619-
# Only include optional fields if they're set
620-
if epsilon is not None:
621-
dev_config["epsilon"] = epsilon
622-
if epsilon_high is not None:
623-
dev_config["epsilon_high"] = epsilon_high
624-
if max_negative_advantage_importance_sampling_weight is not None:
625-
dev_config["max_negative_advantage_importance_sampling_weight"] = (
626-
max_negative_advantage_importance_sampling_weight
627-
)
628-
if kimi_k2_tau is not None:
629-
dev_config["kimi_k2_tau"] = kimi_k2_tau
630-
if truncated_importance_sampling is not None:
631-
dev_config["truncated_importance_sampling"] = truncated_importance_sampling
632-
if kl_ref_adapter_path is not None:
633-
dev_config["kl_ref_adapter_path"] = kl_ref_adapter_path
634-
elif kl_penalty_reference_step is not None:
635-
ref_checkpoint_dir = get_step_checkpoint_dir(
604+
resolved_kl_ref_adapter_path = kl_ref_adapter_path
605+
if (
606+
resolved_kl_ref_adapter_path is None
607+
and kl_penalty_reference_step is not None
608+
):
609+
resolved_kl_ref_adapter_path = get_step_checkpoint_dir(
636610
get_model_dir(model=model, art_path=self._path),
637611
kl_penalty_reference_step,
638612
)
639-
dev_config["kl_ref_adapter_path"] = ref_checkpoint_dir
613+
config, dev_config = build_rl_train_configs(
614+
learning_rate=learning_rate,
615+
advantage_balance=advantage_balance,
616+
scale_rewards=scale_rewards,
617+
importance_sampling_level=importance_sampling_level,
618+
mask_prob_ratio=mask_prob_ratio,
619+
ppo=ppo,
620+
precalculate_logprobs=precalculate_logprobs,
621+
epsilon=epsilon,
622+
epsilon_high=epsilon_high,
623+
max_negative_advantage_importance_sampling_weight=max_negative_advantage_importance_sampling_weight,
624+
kimi_k2_tau=kimi_k2_tau,
625+
kl_penalty_coef=kl_penalty_coef,
626+
allow_training_without_logprobs=allow_training_without_logprobs,
627+
plot_tensors=plot_tensors,
628+
truncated_importance_sampling=truncated_importance_sampling,
629+
scale_learning_rate_by_reward_std_dev=scale_learning_rate_by_reward_std_dev,
630+
logprob_calculation_chunk_size=logprob_calculation_chunk_size,
631+
num_trajectories_learning_rate_multiplier_power=num_trajectories_learning_rate_multiplier_power,
632+
kl_ref_adapter_path=resolved_kl_ref_adapter_path,
633+
)
640634

641635
# Collect metrics from training
642636
training_metrics: list[dict[str, float]] = []
@@ -646,21 +640,10 @@ async def train( # type: ignore[override]
646640
):
647641
training_metrics.append(metrics)
648642

649-
# Aggregate metrics
650-
avg_metrics = average_metric_samples(training_metrics)
651-
summary = summarize_trajectory_groups(groups_list)
652-
avg_metrics.setdefault(
653-
"time/step_trainer_s", time.monotonic() - trainer_started
654-
)
655-
avg_metrics.update(
656-
{
657-
key: value
658-
for key, value in build_training_summary_metrics(
659-
summary,
660-
include_trainable_groups=True,
661-
).items()
662-
if key not in avg_metrics
663-
}
643+
avg_metrics = aggregate_rl_training_metrics(
644+
training_metrics=training_metrics,
645+
trajectory_groups=groups_list,
646+
trainer_started=trainer_started,
664647
)
665648

666649
# Get step and checkpoint path

src/art/megatron/jobs.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from typing import Literal
2+
3+
from pydantic import BaseModel
4+
5+
from .. import dev, types
6+
from ..preprocessing.pack import DiskPackedTensors
7+
8+
# File-system rendezvous points shared between the Megatron service process
# and the trainer process. These are defaults only; job objects carry the
# paths explicitly (``MegatronTrainingJob.log_path``).
DEFAULT_TRAINING_LOG_PATH = "/tmp/megatron_training_log.jsonl"  # trainer appends JSONL progress lines here
DEFAULT_JOBS_DIR = "/tmp/megatron_training_jobs"  # queued job descriptions written as *.json files
DEFAULT_VLLM_WAKE_LOCK_PATH = "/tmp/megatron_vllm_waking"  # presumably a lock file signalling vLLM wake-up — confirm with trainer code
11+
12+
13+
class MegatronTrainingJob(BaseModel):
    """RL training job description exchanged (serialized as JSON) with the
    Megatron trainer process."""

    # LoRA adapter checkpoint directory to train.
    lora_path: str
    # Path where the trainer keeps its optimizer state between steps.
    optimizer_state_path: str
    # Pre-packed training tensors stored on disk.
    disk_packed_tensors: DiskPackedTensors
    # Public training hyperparameters (learning rate, KL penalty coefficient).
    config: types.TrainConfig
    # Experimental / internal training options.
    experimental_config: dev.TrainConfig
    # JSONL file the trainer appends progress lines to.
    log_path: str = DEFAULT_TRAINING_LOG_PATH
20+
21+
22+
class MegatronSFTTrainingJob(BaseModel):
    """Supervised fine-tuning (SFT) job description for the Megatron trainer."""

    # Discriminator so SFT jobs can be told apart from RL jobs on deserialization.
    job_type: Literal["sft"] = "sft"
    # LoRA adapter checkpoint directory to train.
    lora_path: str
    # Path where the trainer keeps its optimizer state between steps.
    optimizer_state_path: str
    # Directory containing the prepared SFT training data.
    sft_data_dir: str
    # Number of batches to train on.
    num_batches: int
    # Learning-rate schedule; presumably one entry per batch — confirm against trainer.
    learning_rates: list[float]
    # Optimizer weight decay (disabled by default).
    weight_decay: float = 0.0
    # Gradient clipping threshold.
    max_grad_norm: float = 1.0
    # JSONL file the trainer appends progress lines to.
    log_path: str = DEFAULT_TRAINING_LOG_PATH

src/art/megatron/runtime_env.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import os
2+
3+
4+
def _set_cache_dir(env_var: str, default_path: str) -> None:
    """Point *env_var* at a cache directory and make sure it exists.

    A pre-existing non-empty value is respected; otherwise *default_path*
    (with ``~`` expanded) is used. The directory is created in either case.
    """
    path = os.environ.get(env_var) or os.path.expanduser(default_path)
    os.environ[env_var] = path
    os.makedirs(path, exist_ok=True)


def configure_megatron_runtime_env() -> None:
    """Set the CUDA/compiler environment variables the Megatron trainer needs."""
    os.environ.update(
        {
            "CUDA_DEVICE_MAX_CONNECTIONS": "1",
            "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
            "TORCH_CUDA_ARCH_LIST": "9.0",
        }
    )
    _set_cache_dir("TORCHINDUCTOR_CACHE_DIR", "~/.cache/torchinductor")
    _set_cache_dir("TRITON_CACHE_DIR", "~/.triton/cache")

src/art/megatron/service.py

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from typing import Any, AsyncIterator
1111

1212
from peft.tuners.lora.config import LoraConfig
13-
from pydantic import BaseModel
1413
from safetensors import safe_open
1514
from safetensors.torch import load_file, save_file
1615
import torch
@@ -26,16 +25,11 @@
2625
from ..utils.get_model_step import get_step_from_dir
2726
from ..utils.output_dirs import get_step_checkpoint_dir
2827
from ..vllm import get_llm, openai_server_task, run_on_workers
29-
30-
31-
class MegatronTrainingJob(BaseModel):
32-
"""Job format for communication with train.py"""
33-
34-
lora_path: str
35-
optimizer_state_path: str
36-
disk_packed_tensors: DiskPackedTensors
37-
config: types.TrainConfig
38-
experimental_config: dev.TrainConfig
28+
from .jobs import (
29+
DEFAULT_JOBS_DIR,
30+
DEFAULT_TRAINING_LOG_PATH,
31+
MegatronTrainingJob,
32+
)
3933

4034

4135
@dataclass
@@ -236,34 +230,35 @@ async def train(
236230

237231
self._optimizer_state_path = self._get_optimizer_state_path()
238232

239-
jobs_dir = "/tmp/megatron_training_jobs"
240-
os.makedirs(jobs_dir, exist_ok=True)
241-
for job_name in os.listdir(jobs_dir):
233+
os.makedirs(DEFAULT_JOBS_DIR, exist_ok=True)
234+
for job_name in os.listdir(DEFAULT_JOBS_DIR):
242235
if job_name.endswith(".json"):
243-
os.remove(os.path.join(jobs_dir, job_name))
236+
os.remove(os.path.join(DEFAULT_JOBS_DIR, job_name))
244237
job = MegatronTrainingJob(
245238
lora_path=lora_path,
246239
optimizer_state_path=self._optimizer_state_path,
247240
disk_packed_tensors=disk_packed_tensors,
248241
config=config,
249242
experimental_config=_config,
250243
)
251-
job_path = os.path.join(jobs_dir, f"{datetime.datetime.now().isoformat()}.json")
244+
job_path = os.path.join(
245+
DEFAULT_JOBS_DIR, f"{datetime.datetime.now().isoformat()}.json"
246+
)
252247
with open(job_path, "w") as f:
253248
f.write(job.model_dump_json())
254249

255250
num_lines = 0
256251
while True:
257252
await asyncio.sleep(0.1)
258253
try:
259-
with open("/tmp/megatron_training_log.jsonl", "a+") as log_file:
254+
with open(DEFAULT_TRAINING_LOG_PATH, "a+") as log_file:
260255
log_file.seek(0)
261256
lines = log_file.readlines()[num_lines:]
262257
for line in lines:
263258
if line := line.strip():
264259
if line == "all done":
265260
self._merge_lora_adapter(lora_path)
266-
os.remove("/tmp/megatron_training_log.jsonl")
261+
os.remove(DEFAULT_TRAINING_LOG_PATH)
267262
break
268263
num_lines += 1
269264
yield json.loads(line)

0 commit comments

Comments
 (0)