
Commit f6cd445
1 parent 0fa9a2b

Lighten Megatron orchestration imports

4 files changed, +110 −105 lines

src/art/megatron/jobs.py

Lines changed: 3 additions & 10 deletions

@@ -1,10 +1,9 @@
-from typing import Literal
+from typing import Any, Literal
 
 from pydantic import BaseModel
 
-from .. import dev, types
+from .. import types
 from ..preprocessing.pack import DiskPackedTensors
-from .routing_replay import MoeRoutingReplayBundle
 
 DEFAULT_TRAINING_LOG_PATH = "/tmp/megatron_training_log.jsonl"
 DEFAULT_JOBS_DIR = "/tmp/megatron_training_jobs"
@@ -16,18 +15,12 @@ class MegatronTrainingJob(BaseModel):
     optimizer_state_path: str
     disk_packed_tensors: DiskPackedTensors
     config: types.TrainConfig
-    experimental_config: dev.TrainConfig
+    experimental_config: dict[str, Any]
     moe_routing_replay_path: str | None = None
     moe_routing_replay_strict: bool = True
     log_path: str = DEFAULT_TRAINING_LOG_PATH
 
 
-MegatronTrainingJob.model_rebuild(
-    force=True,
-    _types_namespace={"MoeRoutingReplayBundle": MoeRoutingReplayBundle},
-)
-
-
 class MegatronSFTTrainingJob(BaseModel):
     job_type: Literal["sft"] = "sft"
     lora_path: str
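The heart of the jobs.py change: `experimental_config` becomes a plain `dict[str, Any]`, so the job schema no longer imports `dev` or `routing_replay`, and the `model_rebuild(...)` call that resolved the `MoeRoutingReplayBundle` forward reference can go. A minimal sketch of the pattern, assuming `dev.TrainConfig` is a TypedDict (as the `.get(...)` calls later in this diff suggest); all names here are illustrative, not the project's:

```python
from typing import Any, TypedDict, cast

from pydantic import BaseModel


class ExperimentalTrainConfig(TypedDict, total=False):
    # Hypothetical stand-in for dev.TrainConfig.
    moe_routing_replay_path: str
    moe_routing_replay_strict: bool


class Job(BaseModel):
    # A plain dict validates and serializes without importing dev or
    # routing_replay, so loading the job schema stays lightweight.
    experimental_config: dict[str, Any]


job = Job(experimental_config={"moe_routing_replay_strict": False})
# Consumers recover a typed view locally instead of baking it into the schema.
config = cast(ExperimentalTrainConfig, job.experimental_config)
assert config.get("moe_routing_replay_strict", True) is False
```

Keeping the wire format untyped means the Pydantic model can be validated and serialized without loading the heavier training modules; type safety is recovered at each use site with `cast`.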

src/art/megatron/merge.py

Lines changed: 102 additions & 0 deletions

@@ -0,0 +1,102 @@
+import importlib
+import json
+from pathlib import Path
+from typing import Any
+
+import torch
+
+safetensors = importlib.import_module("safetensors")
+safetensors_torch = importlib.import_module("safetensors.torch")
+safe_open = safetensors.safe_open
+save_file = safetensors_torch.save_file
+
+
+def merge_lora_adapter(lora_path: str) -> None:
+    base_dir = Path(lora_path)
+    shard_filenames = sorted(base_dir.glob("adapter_model-*-of-*.safetensors"))
+    if not shard_filenames:
+        return
+
+    shard_files_by_suffix = {
+        path.name.removeprefix("adapter_model-").removesuffix(".safetensors"): path
+        for path in shard_filenames
+    }
+    manifest_filenames = sorted(base_dir.glob("adapter_manifest-*-of-*.json"))
+    manifest_files_by_suffix = {
+        path.name.removeprefix("adapter_manifest-").removesuffix(".json"): path
+        for path in manifest_filenames
+    }
+
+    if set(shard_files_by_suffix) != set(manifest_files_by_suffix):
+        raise RuntimeError(
+            "Shard/manifest coverage mismatch: "
+            f"shards={sorted(shard_files_by_suffix)}, "
+            f"manifests={sorted(manifest_files_by_suffix)}"
+        )
+
+    entries_by_key: dict[str, list[tuple[dict[str, Any], torch.Tensor]]] = {}
+    for suffix in sorted(shard_files_by_suffix):
+        shard_path = shard_files_by_suffix[suffix]
+        manifest_path = manifest_files_by_suffix[suffix]
+        with open(manifest_path, "r", encoding="utf-8") as manifest_file:
+            shard_manifest: dict[str, dict[str, Any]] = json.load(manifest_file)
+        with safe_open(shard_path, framework="pt") as file:
+            shard_tensors = {key: file.get_tensor(key) for key in file.keys()}
+
+        if set(shard_tensors) != set(shard_manifest):
+            raise RuntimeError(
+                f"Tensor/manifest key mismatch for shard suffix={suffix}: "
+                f"tensor_keys={sorted(shard_tensors)}, "
+                f"manifest_keys={sorted(shard_manifest)}"
+            )
+        for key, tensor in shard_tensors.items():
+            entries_by_key.setdefault(key, []).append((shard_manifest[key], tensor))
+
+    adapter_model: dict[str, torch.Tensor] = {}
+    for key, key_entries in entries_by_key.items():
+        first_manifest = key_entries[0][0]
+        sharded = bool(first_manifest["sharded"])
+        shard_world_size = int(first_manifest["shard_world_size"])
+        for manifest_entry, _tensor in key_entries:
+            if bool(manifest_entry["sharded"]) != sharded:
+                raise RuntimeError(f"Inconsistent sharded flag for key={key}")
+            if int(manifest_entry["shard_world_size"]) != shard_world_size:
+                raise RuntimeError(f"Inconsistent shard world size for key={key}")
+
+        if not sharded:
+            if len(key_entries) != 1:
+                raise RuntimeError(
+                    f"Replicated key={key} expected 1 shard, got {len(key_entries)}"
+                )
+            tensor = key_entries[0][1]
+        else:
+            shard_rank_to_tensor: dict[int, torch.Tensor] = {}
+            for manifest_entry, shard_tensor in key_entries:
+                shard_rank = int(manifest_entry["shard_rank"])
+                if shard_rank in shard_rank_to_tensor:
+                    raise RuntimeError(
+                        f"Duplicate shard_rank={shard_rank} for key={key}"
+                    )
+                shard_rank_to_tensor[shard_rank] = shard_tensor
+
+            expected_shard_ranks = set(range(shard_world_size))
+            if set(shard_rank_to_tensor) != expected_shard_ranks:
+                raise RuntimeError(
+                    f"Shard rank coverage mismatch for key={key}: "
+                    f"expected {sorted(expected_shard_ranks)}, got {sorted(shard_rank_to_tensor)}"
+                )
+
+            ordered_shards = [
+                shard_rank_to_tensor[shard_rank]
+                for shard_rank in range(shard_world_size)
+            ]
+            concat_dim = 1 if "lora_A" in key else 0
+            tensor = torch.cat(ordered_shards, dim=concat_dim)
+        adapter_model[key] = tensor
+
+    adapter_model_path = base_dir / "adapter_model.safetensors"
+    save_file(adapter_model, adapter_model_path)
+    for filename in shard_filenames:
+        filename.unlink()
+    for filename in manifest_filenames:
+        filename.unlink()
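The new `merge.py` reassembles per-rank LoRA shard files into one `adapter_model.safetensors`: a JSON manifest per shard records whether each tensor is replicated or sharded, and sharded tensors are concatenated in rank order. The `concat_dim` choice matches PEFT's layout, where `lora_A` is `(r, in_features)` and `lora_B` is `(out_features, r)`, so input-side shards join along dim 1 and output-side shards along dim 0. A toy illustration with synthetic tensors (not ART code):

```python
import torch

# PEFT stores lora_A as (r, in_features) and lora_B as (out_features, r), so
# tensor-parallel shards split lora_A along dim 1 and lora_B along dim 0.
rank, in_features, out_features, world_size = 8, 32, 64, 2

a_shards = [torch.randn(rank, in_features // world_size) for _ in range(world_size)]
b_shards = [torch.randn(out_features // world_size, rank) for _ in range(world_size)]

lora_A = torch.cat(a_shards, dim=1)  # concat_dim = 1 for "lora_A" keys
lora_B = torch.cat(b_shards, dim=0)  # concat_dim = 0 otherwise
assert lora_A.shape == (rank, in_features)
assert lora_B.shape == (out_features, rank)
```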

src/art/megatron/service.py

Lines changed: 3 additions & 3 deletions

@@ -9,7 +9,7 @@
 import shlex
 import shutil
 import subprocess
-from typing import Any, AsyncIterator
+from typing import Any, AsyncIterator, cast
 
 from peft.tuners.lora.config import LoraConfig
 import torch
@@ -32,7 +32,7 @@
     DEFAULT_VLLM_WAKE_LOCK_PATH,
     MegatronTrainingJob,
 )
-from .train import merge_lora_adapter
+from .merge import merge_lora_adapter
 
 safetensors = importlib.import_module("safetensors")
 safe_open = safetensors.safe_open
@@ -283,7 +283,7 @@ async def train(
             optimizer_state_path=self._optimizer_state_path,
             disk_packed_tensors=disk_packed_tensors,
             config=config,
-            experimental_config=_config,
+            experimental_config=cast(dict[str, Any], _config),
             moe_routing_replay_path=_config.get("moe_routing_replay_path"),
             moe_routing_replay_strict=_config.get("moe_routing_replay_strict", True),
             log_path=os.path.join(
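Besides re-pointing the `merge_lora_adapter` import at the new module, `service.py` now casts `_config` when building the job. The cast matters because under PEP 589 a TypedDict value is not assignable to `dict[str, Any]`, even though it is an ordinary dict at runtime. A small sketch of that mismatch, with stand-in names rather than ART's:

```python
from typing import Any, TypedDict, cast


class ExperimentalConfig(TypedDict, total=False):
    # Hypothetical stand-in for dev.TrainConfig.
    moe_routing_replay_path: str
    moe_routing_replay_strict: bool


def submit(experimental_config: dict[str, Any]) -> None:
    """Accepts the untyped payload, as MegatronTrainingJob now does."""


cfg: ExperimentalConfig = {"moe_routing_replay_strict": True}
# Without the cast, strict type checkers reject passing a TypedDict where
# dict[str, Any] is expected; at runtime both are the same plain dict.
submit(cast(dict[str, Any], cfg))
```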

src/art/megatron/train.py

Lines changed: 2 additions & 92 deletions

@@ -42,6 +42,7 @@
     MegatronTrainingJob,
 )
 from art.megatron.lora import apply_lora_adapters
+from art.megatron.merge import merge_lora_adapter
 from art.megatron.offload import (
     OffloadState,
     clear_optimizer_state,
@@ -402,7 +403,7 @@ def run_megatron_rl_job(
         learning_rate=job.config.learning_rate,
         inputs=micro_inputs,
         config=job.config,
-        experimental_config=job.experimental_config,
+        experimental_config=cast(dev.TrainConfig, job.experimental_config),
         ref_logprobs=None,
         step_index=step_index,
         sample_index=micro_indices,
@@ -587,97 +588,6 @@ def _job_cleanup_path(job: MegatronJob) -> str:
     return job.disk_packed_tensors["dir"]
 
 
-def merge_lora_adapter(lora_path: str) -> None:
-    base_dir = Path(lora_path)
-    shard_filenames = sorted(base_dir.glob("adapter_model-*-of-*.safetensors"))
-    if not shard_filenames:
-        return
-
-    shard_files_by_suffix = {
-        path.name.removeprefix("adapter_model-").removesuffix(".safetensors"): path
-        for path in shard_filenames
-    }
-    manifest_filenames = sorted(base_dir.glob("adapter_manifest-*-of-*.json"))
-    manifest_files_by_suffix = {
-        path.name.removeprefix("adapter_manifest-").removesuffix(".json"): path
-        for path in manifest_filenames
-    }
-
-    if set(shard_files_by_suffix) != set(manifest_files_by_suffix):
-        raise RuntimeError(
-            "Shard/manifest coverage mismatch: "
-            f"shards={sorted(shard_files_by_suffix)}, "
-            f"manifests={sorted(manifest_files_by_suffix)}"
-        )
-
-    entries_by_key: dict[str, list[tuple[dict[str, Any], torch.Tensor]]] = {}
-    for suffix in sorted(shard_files_by_suffix):
-        shard_path = shard_files_by_suffix[suffix]
-        manifest_path = manifest_files_by_suffix[suffix]
-        with open(manifest_path, "r", encoding="utf-8") as manifest_file:
-            shard_manifest: dict[str, dict[str, Any]] = json.load(manifest_file)
-        with safe_open(shard_path, framework="pt") as file:
-            shard_tensors = {key: file.get_tensor(key) for key in file.keys()}
-
-        if set(shard_tensors) != set(shard_manifest):
-            raise RuntimeError(
-                f"Tensor/manifest key mismatch for shard suffix={suffix}: "
-                f"tensor_keys={sorted(shard_tensors)}, "
-                f"manifest_keys={sorted(shard_manifest)}"
-            )
-        for key, tensor in shard_tensors.items():
-            entries_by_key.setdefault(key, []).append((shard_manifest[key], tensor))
-
-    adapter_model: dict[str, torch.Tensor] = {}
-    for key, key_entries in entries_by_key.items():
-        first_manifest = key_entries[0][0]
-        sharded = bool(first_manifest["sharded"])
-        shard_world_size = int(first_manifest["shard_world_size"])
-        for manifest_entry, _tensor in key_entries:
-            if bool(manifest_entry["sharded"]) != sharded:
-                raise RuntimeError(f"Inconsistent sharded flag for key={key}")
-            if int(manifest_entry["shard_world_size"]) != shard_world_size:
-                raise RuntimeError(f"Inconsistent shard world size for key={key}")
-
-        if not sharded:
-            if len(key_entries) != 1:
-                raise RuntimeError(
-                    f"Replicated key={key} expected 1 shard, got {len(key_entries)}"
-                )
-            tensor = key_entries[0][1]
-        else:
-            shard_rank_to_tensor: dict[int, torch.Tensor] = {}
-            for manifest_entry, shard_tensor in key_entries:
-                shard_rank = int(manifest_entry["shard_rank"])
-                if shard_rank in shard_rank_to_tensor:
-                    raise RuntimeError(
-                        f"Duplicate shard_rank={shard_rank} for key={key}"
-                    )
-                shard_rank_to_tensor[shard_rank] = shard_tensor
-
-            expected_shard_ranks = set(range(shard_world_size))
-            if set(shard_rank_to_tensor) != expected_shard_ranks:
-                raise RuntimeError(
-                    f"Shard rank coverage mismatch for key={key}: "
-                    f"expected {sorted(expected_shard_ranks)}, got {sorted(shard_rank_to_tensor)}"
-                )
-
-            ordered_shards = [
-                shard_rank_to_tensor[shard_rank]
-                for shard_rank in range(shard_world_size)
-            ]
-            concat_dim = 1 if "lora_A" in key else 0
-            tensor = torch.cat(ordered_shards, dim=concat_dim)
-        adapter_model[key] = tensor
-
-    adapter_model_path = base_dir / "adapter_model.safetensors"
-    save_file(adapter_model, adapter_model_path)
-    for filename in shard_filenames:
-        filename.unlink()
-    for filename in manifest_filenames:
-        filename.unlink()
-
-
 def _load_sft_batch_from_disk(
     batch_dir: str,
 ) -> tuple[dict[str, Any], list[dict[str, torch.Tensor]]]:
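Net effect in `train.py`: one new import replaces the roughly 90-line local definition of `merge_lora_adapter`, and the RL path recovers a typed view of the now-untyped `experimental_config` via `cast(dev.TrainConfig, ...)`. A hypothetical call site for the relocated helper (the path is illustrative):

```python
from art.megatron.merge import merge_lora_adapter

# Safe to call unconditionally: the helper returns early when the directory
# contains no adapter_model-*-of-*.safetensors shards to merge.
merge_lora_adapter("/tmp/checkpoints/my-model/lora")  # hypothetical path
```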
