From cd15449d9bfc9fd2f6e76ee00df8c77294e687f9 Mon Sep 17 00:00:00 2001 From: Vivek Kalyan Date: Tue, 31 Mar 2026 16:58:43 -0700 Subject: [PATCH 1/9] refactor: move TCP port helper to shared utils --- src/art/unsloth/service.py | 13 +++++-------- src/art/utils/network.py | 8 ++++++++ 2 files changed, 13 insertions(+), 8 deletions(-) create mode 100644 src/art/utils/network.py diff --git a/src/art/unsloth/service.py b/src/art/unsloth/service.py index e7b79958..937217df 100644 --- a/src/art/unsloth/service.py +++ b/src/art/unsloth/service.py @@ -6,7 +6,6 @@ import json import logging import os -import socket import subprocess import sys from typing import Any, AsyncIterator, Literal, cast @@ -25,6 +24,7 @@ from ..preprocessing.tokenize import SFTBatch from ..utils.convert_moe_lora import convert_checkpoint_if_needed from ..utils.get_model_step import get_step_from_dir +from ..utils.network import find_free_tcp_port from ..utils.output_dirs import get_step_checkpoint_dir from ..vllm import get_llm, get_worker, openai_server_task, run_on_workers from .train import ( @@ -83,12 +83,6 @@ def save_checkpoint( return checkpoint_dir -def _find_free_tcp_port() -> int: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: - sock.bind(("127.0.0.1", 0)) - return int(sock.getsockname()[1]) - - def _normalize_merged_checkpoint_name(name: str) -> str: # PEFT wraps adapted modules under `.base_layer`, but vLLM expects the # original checkpoint parameter names during update_weights(). @@ -98,6 +92,9 @@ def _normalize_merged_checkpoint_name(name: str) -> str: return normalized +_find_free_tcp_port = find_free_tcp_port + + # ============================================================================ # Service # ============================================================================ @@ -286,7 +283,7 @@ async def _init_merged_weight_transfer(self) -> None: ) from exc inference_world_size = int(world_size_response.json()["world_size"]) - master_port = _find_free_tcp_port() + master_port = find_free_tcp_port() init_info = { "master_address": "127.0.0.1", "master_port": master_port, diff --git a/src/art/utils/network.py b/src/art/utils/network.py new file mode 100644 index 00000000..7b39d187 --- /dev/null +++ b/src/art/utils/network.py @@ -0,0 +1,8 @@ +import socket +from typing import cast + + +def find_free_tcp_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(("127.0.0.1", 0)) + return cast(int, sock.getsockname()[1]) From 5682cf06d9a87e1ef76928057ace48ede9edba80 Mon Sep 17 00:00:00 2001 From: Vivek Kalyan Date: Tue, 31 Mar 2026 16:59:39 -0700 Subject: [PATCH 2/9] feat: add dedicated merged mode to Megatron backend --- src/art/megatron/jobs.py | 32 +- src/art/megatron/provider.py | 1 + src/art/megatron/service.py | 160 +++++++-- src/art/megatron/train.py | 471 +++++++++++++++++++++++++- tests/unit/test_megatron_dedicated.py | 162 +++++++++ 5 files changed, 797 insertions(+), 29 deletions(-) diff --git a/src/art/megatron/jobs.py b/src/art/megatron/jobs.py index 788fe1f3..bd20d074 100644 --- a/src/art/megatron/jobs.py +++ b/src/art/megatron/jobs.py @@ -21,6 +21,31 @@ class MegatronTrainingJob(BaseModel): log_path: str = DEFAULT_TRAINING_LOG_PATH +class MergedWeightTransferInitInfo(BaseModel): + master_address: str + master_port: int + rank_offset: int + world_size: int + + +class MergedWeightTransferSpec(BaseModel): + init_info: MergedWeightTransferInitInfo + vllm_base_url: str + served_model_name: str + + +class MegatronMergedTrainJob(MegatronTrainingJob): + job_type: Literal["merged"] = "merged" + merged_weight_transfer: MergedWeightTransferSpec + + +class MegatronSyncJob(BaseModel): + job_type: Literal["sync"] = "sync" + lora_path: str + merged_weight_transfer: MergedWeightTransferSpec + log_path: str = DEFAULT_TRAINING_LOG_PATH + + class MegatronSFTTrainingJob(BaseModel): job_type: Literal["sft"] = "sft" lora_path: str @@ -35,4 +60,9 @@ class MegatronSFTTrainingJob(BaseModel): log_path: str = DEFAULT_TRAINING_LOG_PATH -MegatronJob = MegatronTrainingJob | MegatronSFTTrainingJob +MegatronJob = ( + MegatronTrainingJob + | MegatronMergedTrainJob + | MegatronSyncJob + | MegatronSFTTrainingJob +) diff --git a/src/art/megatron/provider.py b/src/art/megatron/provider.py index 8f01c799..eef8679a 100644 --- a/src/art/megatron/provider.py +++ b/src/art/megatron/provider.py @@ -328,6 +328,7 @@ def get_provider( ) ) provider = bridge.to_megatron_provider() + setattr(provider, "art_bridge", bridge) base_layer_spec = provider.transformer_layer_spec def _flex_attention_layer_spec( diff --git a/src/art/megatron/service.py b/src/art/megatron/service.py index 80723fb0..485f290b 100644 --- a/src/art/megatron/service.py +++ b/src/art/megatron/service.py @@ -8,6 +8,7 @@ from pathlib import Path import shlex import shutil +import signal import socket import subprocess import sys @@ -28,12 +29,17 @@ from ..unsloth.service import do_sleep, do_wake_up, gc_and_empty_cuda_cache from ..utils.convert_moe_lora import convert_checkpoint_if_needed from ..utils.get_model_step import get_step_from_dir +from ..utils.network import find_free_tcp_port from ..utils.output_dirs import get_step_checkpoint_dir from ..vllm import get_llm, openai_server_task, run_on_workers from .client import create_megatron_job_paths, stream_megatron_job, write_megatron_job from .jobs import ( + MegatronMergedTrainJob, MegatronSFTTrainingJob, + MegatronSyncJob, MegatronTrainingJob, + MergedWeightTransferInitInfo, + MergedWeightTransferSpec, ) from .lora import LORA_ALPHA, LORA_RANK from .sft_batches import materialize_sft_batches @@ -148,6 +154,10 @@ class MegatronService: _vllm_log_file: Any = field(default=None, repr=False) _vllm_host: str = "127.0.0.1" _vllm_port: int = 0 + _merged_weight_transfer_init_info: MergedWeightTransferInitInfo | None = field( + default=None, + repr=False, + ) @property def is_dedicated(self) -> bool: @@ -247,6 +257,15 @@ def _ensure_lora_adapter_config( return self._default_lora_adapter_config().save_pretrained(lora_path) + def _build_merged_weight_transfer_spec(self, step: int) -> MergedWeightTransferSpec: + init_info = self._merged_weight_transfer_init_info + assert init_info is not None + return MergedWeightTransferSpec( + init_info=init_info, + vllm_base_url=self._vllm_base_url, + served_model_name=f"{self.model_name}@{step}", + ) + def _resolve_active_lora_path(self) -> str: lora_path = get_last_checkpoint_dir(self.output_dir) if lora_path is None: @@ -254,10 +273,43 @@ def _resolve_active_lora_path(self) -> str: self._latest_step = 0 else: self._latest_step = get_step_from_dir(self.output_dir) - self._ensure_identity_lora(lora_path) + if self.rollout_weights_mode == "lora": + self._ensure_identity_lora(lora_path) self._ensure_lora_adapter_config(lora_path) return lora_path + async def _set_served_model_name(self, step: int) -> None: + import httpx + + async with httpx.AsyncClient() as client: + response = await client.post( + f"{self._vllm_base_url}/art/set_served_model_name", + json={"name": f"{self.model_name}@{step}"}, + timeout=30.0, + ) + response.raise_for_status() + self._latest_step = step + + async def _init_merged_weight_transfer(self) -> None: + import httpx + + if self._merged_weight_transfer_init_info is not None: + return + assert len(self.config["trainer_gpu_ids"]) == 1 + async with httpx.AsyncClient() as client: + response = await client.get( + f"{self._vllm_base_url}/get_world_size", + timeout=30.0, + ) + response.raise_for_status() + inference_world_size = int(response.json()["world_size"]) + self._merged_weight_transfer_init_info = MergedWeightTransferInitInfo( + master_address="127.0.0.1", + master_port=find_free_tcp_port(), + rank_offset=1, + world_size=inference_world_size + 1, + ) + async def _start_vllm_subprocess( self, lora_path: str, @@ -285,8 +337,13 @@ async def _start_vllm_subprocess( if config and "engine_args" in config: engine_args.update(dict(config["engine_args"])) engine_args.setdefault("generation_config", "vllm") - engine_args["enable_lora"] = True - engine_args.setdefault("max_loras", 2) + if self.rollout_weights_mode == "merged": + engine_args["weight_transfer_config"] = {"backend": "nccl"} + engine_args.pop("enable_lora", None) + engine_args.pop("max_loras", None) + else: + engine_args["enable_lora"] = True + engine_args.setdefault("max_loras", 2) for key in ("model", "served_model_name", "enable_sleep_mode"): engine_args.pop(key, None) @@ -366,6 +423,25 @@ async def _reload_adapter(self, checkpoint_path: str, step: int) -> None: response.raise_for_status() self._latest_step = step + async def _sync_dedicated_merged_weights( + self, + *, + lora_path: str, + step: int, + ) -> None: + await self._ensure_megatron_running() + await self._init_merged_weight_transfer() + job_path, log_path = self._create_megatron_job_paths() + job = MegatronSyncJob( + lora_path=lora_path, + merged_weight_transfer=self._build_merged_weight_transfer_spec(step), + log_path=log_path, + ) + write_megatron_job(job, job_path=job_path) + async for _ in stream_megatron_job(job, job_path=job_path): + pass + self._latest_step = step + def _stop_vllm_subprocess(self) -> None: if self._vllm_process is not None: self._vllm_process.terminate() @@ -378,12 +454,13 @@ def _stop_vllm_subprocess(self) -> None: if self._vllm_log_file is not None: self._vllm_log_file.close() self._vllm_log_file = None + self._merged_weight_transfer_init_info = None def _stop_megatron_process(self) -> None: if self._megatron_process is None: return if self._megatron_process.returncode is None: - self._megatron_process.terminate() + os.killpg(os.getpgid(self._megatron_process.pid), signal.SIGTERM) self._megatron_process = None async def _add_lora_aliases( @@ -402,8 +479,10 @@ async def _add_lora_aliases( async def register_lora_for_step(self, step: int, checkpoint_dir: str) -> None: if self.is_dedicated: - assert self.rollout_weights_mode == "lora" - await self._reload_adapter(checkpoint_dir, step) + if self.rollout_weights_mode == "merged": + await self._set_served_model_name(step) + else: + await self._reload_adapter(checkpoint_dir, step) return llm = await self.llm await llm.pause_generation() @@ -458,6 +537,7 @@ async def _ensure_megatron_running(self) -> None: command, cwd=str(project_root), env=launch_env, + start_new_session=True, ) def _clear_pending_jobs(self) -> None: @@ -535,9 +615,14 @@ async def start_openai_server( lora_path = self._resolve_active_lora_path() if self.is_dedicated: - assert self.rollout_weights_mode == "lora" port = (config or {}).get("server_args", {}).get("port", 8000) - return await self._start_vllm_subprocess(lora_path, port, config) + location = await self._start_vllm_subprocess(lora_path, port, config) + if self.rollout_weights_mode == "merged": + await self._sync_dedicated_merged_weights( + lora_path=lora_path, + step=self._latest_step, + ) + return location lora_path_for_server = ( lora_path if self._adapter_has_weights(lora_path) else None @@ -575,7 +660,6 @@ async def train( verbose: bool = False, ) -> AsyncIterator[dict[str, float]]: if self.is_dedicated: - assert self.rollout_weights_mode == "lora" await self._ensure_megatron_running() lora_path = self._resolve_active_lora_path() @@ -586,24 +670,56 @@ async def train( "MegatronService subprocess jobs must use moe_routing_replay_path." ) job_path, log_path = self._create_megatron_job_paths() - job = MegatronTrainingJob( - lora_path=lora_path, - optimizer_state_path=self._get_optimizer_state_path("rl"), - disk_packed_tensors=disk_packed_tensors, - config=config, - experimental_config=cast(dict[str, Any], _config), - moe_routing_replay_path=_config.get("moe_routing_replay_path"), - moe_routing_replay_strict=_config.get( - "moe_routing_replay_strict", True - ), - log_path=log_path, - ) + next_step = self._latest_step + 1 + if self.rollout_weights_mode == "merged": + await self._init_merged_weight_transfer() + job = MegatronMergedTrainJob( + lora_path=lora_path, + optimizer_state_path=self._get_optimizer_state_path("rl"), + disk_packed_tensors=disk_packed_tensors, + config=config, + experimental_config=cast(dict[str, Any], _config), + moe_routing_replay_path=_config.get("moe_routing_replay_path"), + moe_routing_replay_strict=_config.get( + "moe_routing_replay_strict", True + ), + merged_weight_transfer=self._build_merged_weight_transfer_spec( + next_step + ), + log_path=log_path, + ) + else: + job = MegatronTrainingJob( + lora_path=lora_path, + optimizer_state_path=self._get_optimizer_state_path("rl"), + disk_packed_tensors=disk_packed_tensors, + config=config, + experimental_config=cast(dict[str, Any], _config), + moe_routing_replay_path=_config.get("moe_routing_replay_path"), + moe_routing_replay_strict=_config.get( + "moe_routing_replay_strict", True + ), + log_path=log_path, + ) write_megatron_job(job, job_path=job_path) async for result in stream_megatron_job(job, job_path=job_path): yield {key: float(value) for key, value in result.items()} - await self._publish_dedicated_training_checkpoint(lora_path=lora_path) + if self.rollout_weights_mode == "merged": + new_checkpoint_dir = get_step_checkpoint_dir(self.output_dir, next_step) + os.makedirs(new_checkpoint_dir, exist_ok=True) + shutil.copy( + f"{lora_path}/adapter_model.safetensors", + f"{new_checkpoint_dir}/adapter_model.safetensors", + ) + self._ensure_lora_adapter_config( + new_checkpoint_dir, + source_path=lora_path, + ) + self._latest_step = next_step + else: + await self._publish_dedicated_training_checkpoint(lora_path=lora_path) return llm, lora_path = await self._prepare_for_training() if _config.get("moe_routing_replay_bundle") is not None: diff --git a/src/art/megatron/train.py b/src/art/megatron/train.py index 2c0cdf10..60dcb2a3 100644 --- a/src/art/megatron/train.py +++ b/src/art/megatron/train.py @@ -12,12 +12,14 @@ - merge_lora_adapter """ +from concurrent.futures import ThreadPoolExecutor import gc import importlib import json import math import os import random +import re import shutil import time from typing import Any, Callable, cast @@ -41,10 +43,21 @@ DEFAULT_JOBS_DIR, DEFAULT_VLLM_WAKE_LOCK_PATH, MegatronJob, + MegatronMergedTrainJob, MegatronSFTTrainingJob, + MegatronSyncJob, MegatronTrainingJob, + MergedWeightTransferInitInfo, + MergedWeightTransferSpec, +) +from art.megatron.lora import ( + LoRA, + MLPExpertsLinearFC1LoRA, + MLPExpertsLinearFC2LoRA, + SelfAttentionLinearProjLoRA, + SelfAttentionLinearQKVLoRA, + apply_lora_adapters, ) -from art.megatron.lora import apply_lora_adapters from art.megatron.merge import load_lora_adapter_state_dict, merge_lora_adapter from art.megatron.model_chunks import ( ModelChunks, @@ -99,6 +112,8 @@ class TrainingRuntime(BaseModel): rank: int world_size: int moe_routing_replay_controller: MoeRoutingReplayController | None = None + merged_weight_transfer_group: Any | None = None + merged_weight_transfer_init_info: MergedWeightTransferInitInfo | None = None @field_validator("model") @classmethod @@ -424,7 +439,7 @@ def run_megatron_worker_loop( def run_megatron_rl_job( runtime: TrainingRuntime, - job: MegatronTrainingJob, + job: MegatronTrainingJob | MegatronMergedTrainJob, ) -> None: packed_tensors = None adapter_model = None @@ -504,6 +519,12 @@ def run_megatron_rl_job( lora_path=job.lora_path, optimizer_state_path=job.optimizer_state_path, ) + if isinstance(job, MegatronMergedTrainJob): + _sync_merged_weights_to_vllm( + runtime, + job.merged_weight_transfer, + pause_generation=True, + ) finally: if packed_tensors is not None: del packed_tensors @@ -519,6 +540,29 @@ def run_megatron_rl_job( torch.cuda.empty_cache() +def run_megatron_sync_job( + runtime: TrainingRuntime, + job: MegatronSyncJob, +) -> None: + adapter_model = None + try: + adapter_model = maybe_load_adapter_into_model( + runtime.model, + f"{job.lora_path}/adapter_model.safetensors", + rank=runtime.rank, + ) + _sync_merged_weights_to_vllm( + runtime, + job.merged_weight_transfer, + pause_generation=False, + ) + finally: + if adapter_model is not None: + del adapter_model + gc.collect() + torch.cuda.empty_cache() + + def _flush_param_grads_to_main_grads(model_chunks: ModelChunks) -> None: """Fallback for direct SFT jobs when DDP post-hooks leave grads in param.grad. @@ -665,10 +709,15 @@ def run_megatron_sft_job( def _load_megatron_job(job_path: str, *, supports_sft: bool) -> MegatronJob: with open(job_path, "rb") as handle: job_data = json.loads(handle.read()) - if job_data.get("job_type") == "sft": + job_type = job_data.get("job_type") + if job_type == "sft": if not supports_sft: raise NotImplementedError("SFT jobs are not supported in this worker loop") return MegatronSFTTrainingJob.model_validate(job_data) + if job_type == "merged": + return MegatronMergedTrainJob.model_validate(job_data) + if job_type == "sync": + return MegatronSyncJob.model_validate(job_data) return MegatronTrainingJob.model_validate(job_data) @@ -676,12 +725,17 @@ def _run_megatron_job(runtime: TrainingRuntime, job: MegatronJob) -> None: if isinstance(job, MegatronSFTTrainingJob): run_megatron_sft_job(runtime, job) return + if isinstance(job, MegatronSyncJob): + run_megatron_sync_job(runtime, job) + return run_megatron_rl_job(runtime, job) -def _job_cleanup_path(job: MegatronJob) -> str: +def _job_cleanup_path(job: MegatronJob) -> str | None: if isinstance(job, MegatronSFTTrainingJob): return job.sft_data_dir + if isinstance(job, MegatronSyncJob): + return None return job.disk_packed_tensors["dir"] @@ -759,7 +813,7 @@ def finalize_megatron_job( *, job_path: str | None, log_path: str, - cleanup_path: str, + cleanup_path: str | None, ) -> None: torch.distributed.barrier() # type: ignore[possibly-missing-attribute] if runtime.rank != 0: @@ -767,7 +821,7 @@ def finalize_megatron_job( if job_path is not None and os.path.exists(job_path): os.remove(job_path) - if os.path.exists(cleanup_path): + if cleanup_path is not None and os.path.exists(cleanup_path): shutil.rmtree(cleanup_path) with open(log_path, "a+", encoding="utf-8") as log_file: log_file.write("all done\n") @@ -812,6 +866,20 @@ def iter_modules(model_chunks: ModelChunks) -> Any: yield module +def iter_named_modules(model_chunks: list[MegatronModule]) -> Any: + for chunk in model_chunks: + for module_name, module in chunk.named_modules(): + yield module_name, module + + +def _is_language_transformer_layer_name(module_name: str) -> bool: + while module_name.startswith("module."): + module_name = module_name.removeprefix("module.") + return module_name.startswith( + ("decoder.layers.", "language_model.decoder.layers.") + ) + + def load_adapter_into_model( model_chunks: ModelChunks, adapter_model: dict[str, torch.Tensor], @@ -1284,6 +1352,397 @@ def run_training_step( ) +def _is_art_adapter_param_name(name: str) -> bool: + return any( + segment in name + for segment in ( + ".lora.", + ".q_proj_lora.", + ".k_proj_lora.", + ".v_proj_lora.", + ".gate_lora.", + ".up_lora.", + ) + ) + + +def _unwrap_art_wrapper_name(name: str) -> str: + while name.startswith("module."): + name = name[len("module.") :] + for wrapped, unwrapped in ( + (".linear_proj.linear_proj.", ".linear_proj."), + (".linear_qkv.linear_qkv.", ".linear_qkv."), + (".linear_fc1.linear_fc1.", ".linear_fc1."), + (".linear_fc2.linear_fc2.", ".linear_fc2."), + ): + name = name.replace(wrapped, unwrapped) + return name + + +def _mapping_hf_weights_exist(mapping: Any, hf_keys: set[str]) -> bool: + if getattr(mapping, "allow_hf_name_mismatch", False): + return True + hf_param = mapping.hf_param + if isinstance(hf_param, str): + return hf_param in hf_keys + assert isinstance(hf_param, dict) + return all(param in hf_keys for param in hf_param.values()) + + +def _lora_delta(lora: LoRA, expert_idx: int | None = None) -> torch.Tensor: + if lora.A_T.ndim == 3: + assert expert_idx is not None + a_t = lora.A_T[expert_idx] + b_t = lora.B_T[expert_idx] + else: + a_t = lora.A_T + b_t = lora.B_T + return (b_t.T @ a_t.T) * lora.scale + + +def _expert_index_from_hf_name(hf_name: str) -> int: + match = re.search(r"\.experts\.(\d+)\.", hf_name) + assert match is not None + return int(match.group(1)) + + +def _hf_name_has_indexed_expert(hf_name: str) -> bool: + return re.search(r"\.experts\.(\d+)\.", hf_name) is not None + + +def _stack_moe_fc1_deltas(handler: MLPExpertsLinearFC1LoRA) -> torch.Tensor: + return torch.stack( + [ + torch.cat( + [ + _lora_delta(handler.gate_lora, expert_idx), + _lora_delta(handler.up_lora, expert_idx), + ], + dim=0, + ) + for expert_idx in range(handler.gate_lora.num_local_experts) + ], + dim=0, + ) + + +def _stack_moe_fc2_deltas(handler: MLPExpertsLinearFC2LoRA) -> torch.Tensor: + return torch.stack( + [ + _lora_delta(handler.lora, expert_idx) + for expert_idx in range(handler.lora.num_local_experts) + ], + dim=0, + ) + + +def _merge_delta_into_weight( + hf_name: str, + base_weight: torch.Tensor, + delta: torch.Tensor, +) -> torch.Tensor: + delta = delta.to(device=base_weight.device, dtype=base_weight.dtype) + if tuple(base_weight.shape) == tuple(delta.shape): + return base_weight + delta + transposed = delta.transpose(-1, -2) + assert tuple(base_weight.shape) == tuple(transposed.shape), ( + f"{hf_name}: cannot merge delta {tuple(delta.shape)} into {tuple(base_weight.shape)}" + ) + return base_weight + transposed + + +def _build_art_merge_handlers( + model_chunks: list[MegatronModule], +) -> tuple[dict[str, Any], dict[str, Any]]: + exact_handlers: dict[str, Any] = {} + prefix_handlers: dict[str, Any] = {} + for module_name, module in iter_named_modules(model_chunks): + if not isinstance(module, TransformerLayer): + continue + if not _is_language_transformer_layer_name(module_name): + continue + prefix = f"language_model.decoder.layers.{module.layer_number - 1}" + linear_proj = getattr(module.self_attention, "linear_proj", None) + if isinstance(linear_proj, SelfAttentionLinearProjLoRA): + exact_handlers[f"{prefix}.self_attention.linear_proj.weight"] = linear_proj + linear_qkv = getattr(module.self_attention, "linear_qkv", None) + if isinstance(linear_qkv, SelfAttentionLinearQKVLoRA): + exact_handlers[f"{prefix}.self_attention.linear_qkv.weight"] = linear_qkv + experts = getattr(module.mlp, "experts", None) + if experts is None: + continue + if isinstance(experts.linear_fc1, MLPExpertsLinearFC1LoRA): + prefix_handlers[f"{prefix}.mlp.experts.linear_fc1.weight"] = ( + experts.linear_fc1 + ) + if isinstance(experts.linear_fc2, MLPExpertsLinearFC2LoRA): + prefix_handlers[f"{prefix}.mlp.experts.linear_fc2.weight"] = ( + experts.linear_fc2 + ) + return exact_handlers, prefix_handlers + + +def _merge_art_lora_into_hf_weights( + global_param_name: str, + converted_weights_dict: dict[str, torch.Tensor], + *, + exact_handlers: dict[str, Any], + prefix_handlers: dict[str, Any], +) -> dict[str, torch.Tensor]: + handler = exact_handlers.get(global_param_name) + if handler is None: + for prefix, prefix_handler in prefix_handlers.items(): + if global_param_name.startswith(prefix): + handler = prefix_handler + break + if handler is None: + return converted_weights_dict + if isinstance(handler, SelfAttentionLinearProjLoRA): + hf_name, base_weight = next(iter(converted_weights_dict.items())) + converted_weights_dict[hf_name] = _merge_delta_into_weight( + hf_name, + base_weight, + _lora_delta(handler.lora), + ) + return converted_weights_dict + if isinstance(handler, SelfAttentionLinearQKVLoRA): + deltas = { + "q_proj": _lora_delta(handler.q_proj_lora), + "k_proj": _lora_delta(handler.k_proj_lora), + "v_proj": _lora_delta(handler.v_proj_lora), + } + for hf_name, base_weight in list(converted_weights_dict.items()): + for projection, delta in deltas.items(): + if projection in hf_name: + converted_weights_dict[hf_name] = _merge_delta_into_weight( + hf_name, + base_weight, + delta, + ) + break + return converted_weights_dict + if isinstance(handler, MLPExpertsLinearFC1LoRA): + for hf_name, base_weight in list(converted_weights_dict.items()): + delta = ( + torch.cat( + [ + _lora_delta( + handler.gate_lora, _expert_index_from_hf_name(hf_name) + ), + _lora_delta( + handler.up_lora, _expert_index_from_hf_name(hf_name) + ), + ], + dim=0, + ) + if _hf_name_has_indexed_expert(hf_name) + else _stack_moe_fc1_deltas(handler) + ) + converted_weights_dict[hf_name] = _merge_delta_into_weight( + hf_name, + base_weight, + delta, + ) + return converted_weights_dict + assert isinstance(handler, MLPExpertsLinearFC2LoRA) + for hf_name, base_weight in list(converted_weights_dict.items()): + delta = ( + _lora_delta(handler.lora, _expert_index_from_hf_name(hf_name)) + if _hf_name_has_indexed_expert(hf_name) + else _stack_moe_fc2_deltas(handler) + ) + converted_weights_dict[hf_name] = _merge_delta_into_weight( + hf_name, + base_weight, + delta, + ) + return converted_weights_dict + + +def _build_art_conversion_tasks(runtime: TrainingRuntime) -> list[Any]: + from itertools import chain + + from megatron.bridge.models.conversion.model_bridge import ( + WeightConversionTask, + _megatron_local_name_to_global, + ) + from megatron.bridge.models.conversion.utils import ( + get_module_and_param_from_name, + persistent_buffers, + ) + + bridge = getattr(runtime.provider, "art_bridge", None) + assert bridge is not None + mapping_registry = bridge._model_bridge.mapping_registry() + hf_source = bridge.hf_pretrained.state.source + hf_keys = set(hf_source.get_all_keys()) + model_config = runtime.model[0].config + tasks: list[Any] = [] + for vp_stage, model in enumerate(runtime.model): + for local_name, _ in chain(model.named_parameters(), persistent_buffers(model)): + if "_extra_state" in local_name or _is_art_adapter_param_name(local_name): + continue + global_name = _megatron_local_name_to_global( + runtime.model, + model_config, + _unwrap_art_wrapper_name(local_name), + vp_stage, + ) + mapping = mapping_registry.megatron_to_hf_lookup(global_name) + if mapping is None or not _mapping_hf_weights_exist(mapping, hf_keys): + continue + local_module, local_weights = get_module_and_param_from_name( + runtime.model, + local_name, + vp_stage, + ) + if local_module is not None and not hasattr(local_module, "config"): + setattr(local_module, "config", model_config) + tasks.append( + WeightConversionTask( + pp_rank=0, + vp_stage=vp_stage, + param_name=local_name, + global_param_name=global_name, + megatron_module=local_module, + param_weight=local_weights, + mapping=mapping, + ) + ) + return tasks + + +def _iter_merged_vllm_weights(runtime: TrainingRuntime) -> Any: + # vLLM expects HF checkpoint names, but Megatron only has live trainer weights. + # Convert through Bridge here, then merge ART's LoRA deltas into those tensors. + bridge = getattr(runtime.provider, "art_bridge", None) + assert bridge is not None + model_bridge = bridge._model_bridge + hf_state_dict = bridge.hf_pretrained.state + exact_handlers, prefix_handlers = _build_art_merge_handlers(runtime.model) + for task in _build_art_conversion_tasks(runtime): + converted_weights_dict = task.mapping.megatron_to_hf( + task.param_weight, + task.megatron_module, + ) + converted_weights_dict = model_bridge.maybe_modify_converted_hf_weight( + task, + converted_weights_dict, + hf_state_dict, + ) + converted_weights_dict = _merge_art_lora_into_hf_weights( + task.global_param_name, + converted_weights_dict, + exact_handlers=exact_handlers, + prefix_handlers=prefix_handlers, + ) + for hf_name, tensor in converted_weights_dict.items(): + yield hf_name, tensor + + +def _ensure_merged_weight_transfer_group( + runtime: TrainingRuntime, + spec: MergedWeightTransferSpec, +) -> None: + assert runtime.rank == 0 + assert runtime.world_size == 1 + if runtime.merged_weight_transfer_init_info == spec.init_info: + assert runtime.merged_weight_transfer_group is not None + return + import httpx + from vllm.distributed.weight_transfer.nccl_engine import NCCLWeightTransferEngine + + def _remote_init() -> None: + response = httpx.post( + f"{spec.vllm_base_url}/init_weight_transfer_engine", + json={"init_info": spec.init_info.model_dump()}, + timeout=300.0, + ) + response.raise_for_status() + + with ThreadPoolExecutor(max_workers=1) as executor: + remote_future = executor.submit(_remote_init) + time.sleep(1.0) + runtime.merged_weight_transfer_group = NCCLWeightTransferEngine.trainer_init( + { + "master_address": spec.init_info.master_address, + "master_port": spec.init_info.master_port, + "world_size": spec.init_info.world_size, + } + ) + remote_future.result() + runtime.merged_weight_transfer_init_info = spec.init_info + + +def _sync_merged_weights_to_vllm( + runtime: TrainingRuntime, + spec: MergedWeightTransferSpec, + *, + pause_generation: bool, +) -> None: + assert runtime.rank == 0 + assert runtime.world_size == 1 + + import httpx + from vllm.distributed.weight_transfer.nccl_engine import NCCLWeightTransferEngine + + _ensure_merged_weight_transfer_group(runtime, spec) + + def _send_weights() -> None: + NCCLWeightTransferEngine.trainer_send_weights( + _iter_merged_vllm_weights(runtime), + {"group": runtime.merged_weight_transfer_group}, + ) + + with httpx.Client() as client: + if pause_generation: + response = client.post( + f"{spec.vllm_base_url}/pause", + params={"mode": "wait"}, + timeout=300.0, + ) + response.raise_for_status() + try: + torch.cuda.synchronize() + names: list[str] = [] + dtype_names: list[str] = [] + shapes: list[list[int]] = [] + for name, tensor in _iter_merged_vllm_weights(runtime): + names.append(name) + dtype_names.append(str(tensor.dtype).removeprefix("torch.")) + shapes.append(list(tensor.shape)) + with ThreadPoolExecutor(max_workers=1) as executor: + send_future = executor.submit(_send_weights) + response = client.post( + f"{spec.vllm_base_url}/update_weights", + json={ + "update_info": { + "names": names, + "dtype_names": dtype_names, + "shapes": shapes, + "is_checkpoint_format": True, + } + }, + timeout=600.0, + ) + response.raise_for_status() + send_future.result() + response = client.post( + f"{spec.vllm_base_url}/art/set_served_model_name", + json={"name": spec.served_model_name}, + timeout=30.0, + ) + response.raise_for_status() + torch.cuda.synchronize() + finally: + if pause_generation: + response = client.post( + f"{spec.vllm_base_url}/resume", + timeout=30.0, + ) + response.raise_for_status() + + def _run_service_loop(runtime: TrainingRuntime) -> None: offload_state = OffloadState() wake_lock_path = os.environ.get( diff --git a/tests/unit/test_megatron_dedicated.py b/tests/unit/test_megatron_dedicated.py index b74b64ed..9782bcb0 100644 --- a/tests/unit/test_megatron_dedicated.py +++ b/tests/unit/test_megatron_dedicated.py @@ -13,6 +13,10 @@ from art import TrainableModel, types from art.dev.model import InternalModelConfig from art.megatron.backend import MegatronBackend +from art.megatron.jobs import ( + MegatronMergedTrainJob, + MergedWeightTransferInitInfo, +) from art.megatron.service import MegatronService @@ -105,10 +109,12 @@ async def fake_create_subprocess_shell( command: str, cwd: str, env: dict[str, str], + start_new_session: bool, ) -> Any: seen["command"] = command seen["cwd"] = cwd seen["env"] = env + seen["start_new_session"] = start_new_session return pytypes.SimpleNamespace(returncode=None) monkeypatch.setattr( @@ -123,6 +129,7 @@ async def fake_create_subprocess_shell( assert "--nproc_per_node 2" in seen["command"] assert seen["env"]["CUDA_VISIBLE_DEVICES"] == "0,1" assert seen["env"]["MODEL_IDENTIFIER"] == "Qwen/Qwen3-30B-A3B-Instruct-2507" + assert seen["start_new_session"] is True @pytest.mark.asyncio @@ -200,3 +207,158 @@ async def test_megatron_service_register_lora_for_step_dedicated_reloads_adapter await service.register_lora_for_step(3, "/tmp/checkpoints/3") assert seen == [("/tmp/checkpoints/3", 3)] + + +@pytest.mark.asyncio +async def test_megatron_service_start_openai_server_merged_syncs_step_zero( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + checkpoint_dir = tmp_path / "checkpoints" / "0000" + checkpoint_dir.mkdir(parents=True) + service = MegatronService( + model_name="megatron-merged", + base_model="Qwen/Qwen3-30B-A3B-Instruct-2507", + config=InternalModelConfig( + trainer_gpu_ids=[0], + inference_gpu_ids=[1], + rollout_weights_mode="merged", + ), + output_dir=str(tmp_path), + ) + calls: list[tuple[str, int]] = [] + + monkeypatch.setattr( + "art.megatron.service.get_last_checkpoint_dir", + lambda _output_dir: str(checkpoint_dir), + ) + monkeypatch.setattr(service, "_ensure_lora_adapter_config", lambda _path: None) + monkeypatch.setattr( + service, + "_start_vllm_subprocess", + lambda lora_path, port, config: asyncio.sleep(0, result=("127.0.0.1", port)), + ) + monkeypatch.setattr( + service, + "_sync_dedicated_merged_weights", + lambda *, lora_path, step: calls.append((lora_path, step)) or asyncio.sleep(0), + ) + + location = await service.start_openai_server({"server_args": {"port": 8123}}) + + assert location == ("127.0.0.1", 8123) + assert calls == [(str(checkpoint_dir), 0)] + + +@pytest.mark.asyncio +async def test_megatron_service_register_lora_for_step_merged_sets_served_name( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + service = MegatronService( + model_name="megatron-merged", + base_model="Qwen/Qwen3-30B-A3B-Instruct-2507", + config=InternalModelConfig( + trainer_gpu_ids=[0], + inference_gpu_ids=[1], + rollout_weights_mode="merged", + ), + output_dir=str(tmp_path), + ) + calls: list[int] = [] + + monkeypatch.setattr( + service, + "_set_served_model_name", + lambda step: calls.append(step) or asyncio.sleep(0), + ) + + await service.register_lora_for_step(3, "/tmp/checkpoints/3") + + assert calls == [3] + + +@pytest.mark.asyncio +async def test_megatron_service_train_merged_writes_merged_job_and_does_not_reload_adapter( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + checkpoint_dir = tmp_path / "checkpoints" / "0000" + checkpoint_dir.mkdir(parents=True) + adapter_path = checkpoint_dir / "adapter_model.safetensors" + adapter_path.write_bytes(b"adapter") + service = MegatronService( + model_name="megatron-merged", + base_model="Qwen/Qwen3-30B-A3B-Instruct-2507", + config=InternalModelConfig( + trainer_gpu_ids=[0], + inference_gpu_ids=[1], + rollout_weights_mode="merged", + ), + output_dir=str(tmp_path), + ) + events: list[Any] = [] + + monkeypatch.setattr( + service, + "_ensure_megatron_running", + lambda: events.append("ensure") or asyncio.sleep(0), + ) + monkeypatch.setattr( + service, "_resolve_active_lora_path", lambda: str(checkpoint_dir) + ) + monkeypatch.setattr( + service, + "_init_merged_weight_transfer", + lambda: events.append("init") or asyncio.sleep(0), + ) + monkeypatch.setattr( + "art.megatron.service.create_megatron_job_paths", + lambda *, jobs_dir, training_log_dir: ("/tmp/job.json", "/tmp/log.jsonl"), + ) + monkeypatch.setattr( + "art.megatron.service.write_megatron_job", + lambda job, *, job_path: events.append(job), + ) + + async def fake_stream_megatron_job(job: Any, *, job_path: str): + events.append(("stream", job_path, job.lora_path)) + yield {"loss": 1.0} + + monkeypatch.setattr( + "art.megatron.service.stream_megatron_job", + fake_stream_megatron_job, + ) + monkeypatch.setattr( + service, + "_ensure_lora_adapter_config", + lambda _path, source_path=None: None, + ) + monkeypatch.setattr( + service, + "_reload_adapter", + lambda checkpoint_dir, step: (_ for _ in ()).throw( + AssertionError("merged mode should not hot-reload a LoRA adapter") + ), + ) + service._merged_weight_transfer_init_info = MergedWeightTransferInitInfo( + master_address="127.0.0.1", + master_port=1234, + rank_offset=1, + world_size=2, + ) + + results = [] + async for result in service.train( + {"dir": "/tmp/tensors", "num_sequences": 1, "sequence_length": 16}, + types.TrainConfig(learning_rate=5e-5), + {}, + ): + results.append(result) + + assert results == [{"loss": 1.0}] + assert events[0:2] == ["ensure", "init"] + job = events[2] + assert isinstance(job, MegatronMergedTrainJob) + assert job.merged_weight_transfer.served_model_name == "megatron-merged@1" + assert events[3] == ("stream", "/tmp/job.json", str(checkpoint_dir)) From e8bece4a5c2032060d082cf8c789e4881c615de5 Mon Sep 17 00:00:00 2001 From: Vivek Kalyan Date: Fri, 10 Apr 2026 16:11:20 -0700 Subject: [PATCH 3/9] fix: restore merged sync adapter loading --- src/art/megatron/train.py | 19 +++++++++++++++++++ tests/unit/test_megatron_dedicated.py | 4 ++-- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/src/art/megatron/train.py b/src/art/megatron/train.py index 60dcb2a3..d84bc4de 100644 --- a/src/art/megatron/train.py +++ b/src/art/megatron/train.py @@ -895,6 +895,25 @@ def load_adapter_into_model( optimizer.reload_model_params() +def maybe_load_adapter_into_model( + model_chunks: ModelChunks, + adapter_model_path: str, + optimizer: Any | None = None, + *, + rank: int, +) -> dict[str, torch.Tensor]: + if not os.path.exists(adapter_model_path): + print0(rank, "No adapter model found at", adapter_model_path) + return {} + print0(rank, "Loading adapter model from", adapter_model_path) + with safe_open(adapter_model_path, framework="pt") as adapter_file: + adapter_model = { + key: adapter_file.get_tensor(key) for key in adapter_file.keys() + } + load_adapter_into_model(model_chunks, adapter_model, optimizer) + return adapter_model + + def collect_sharded_lora_state( model_chunks: ModelChunks, adapter_model: dict[str, torch.Tensor], diff --git a/tests/unit/test_megatron_dedicated.py b/tests/unit/test_megatron_dedicated.py index 9782bcb0..5c7803fb 100644 --- a/tests/unit/test_megatron_dedicated.py +++ b/tests/unit/test_megatron_dedicated.py @@ -124,8 +124,8 @@ async def fake_create_subprocess_shell( await service._ensure_megatron_running() - assert shlex.quote(sys.executable) in seen["command"] - assert "torch.distributed.run" in seen["command"] + assert "uv run --project" in seen["command"] + assert "torchrun" in seen["command"] assert "--nproc_per_node 2" in seen["command"] assert seen["env"]["CUDA_VISIBLE_DEVICES"] == "0,1" assert seen["env"]["MODEL_IDENTIFIER"] == "Qwen/Qwen3-30B-A3B-Instruct-2507" From 533817865615d56ee9fc093ee9c5a377e701aa41 Mon Sep 17 00:00:00 2001 From: Vivek Kalyan Date: Mon, 13 Apr 2026 12:01:44 -0700 Subject: [PATCH 4/9] fix: bootstrap shared LoRA at step zero --- src/art/megatron/service.py | 2 +- tests/unit/test_megatron_dedicated.py | 47 +++++++++++++++++++++++++++ 2 files changed, 48 insertions(+), 1 deletion(-) diff --git a/src/art/megatron/service.py b/src/art/megatron/service.py index 485f290b..e51fc8c7 100644 --- a/src/art/megatron/service.py +++ b/src/art/megatron/service.py @@ -273,7 +273,7 @@ def _resolve_active_lora_path(self) -> str: self._latest_step = 0 else: self._latest_step = get_step_from_dir(self.output_dir) - if self.rollout_weights_mode == "lora": + if self.is_dedicated or self.rollout_weights_mode == "lora": self._ensure_identity_lora(lora_path) self._ensure_lora_adapter_config(lora_path) return lora_path diff --git a/tests/unit/test_megatron_dedicated.py b/tests/unit/test_megatron_dedicated.py index 5c7803fb..3387d252 100644 --- a/tests/unit/test_megatron_dedicated.py +++ b/tests/unit/test_megatron_dedicated.py @@ -227,11 +227,17 @@ async def test_megatron_service_start_openai_server_merged_syncs_step_zero( output_dir=str(tmp_path), ) calls: list[tuple[str, int]] = [] + ensured_identity_paths: list[str] = [] monkeypatch.setattr( "art.megatron.service.get_last_checkpoint_dir", lambda _output_dir: str(checkpoint_dir), ) + monkeypatch.setattr( + service, + "_ensure_identity_lora", + lambda path: ensured_identity_paths.append(path), + ) monkeypatch.setattr(service, "_ensure_lora_adapter_config", lambda _path: None) monkeypatch.setattr( service, @@ -247,9 +253,50 @@ async def test_megatron_service_start_openai_server_merged_syncs_step_zero( location = await service.start_openai_server({"server_args": {"port": 8123}}) assert location == ("127.0.0.1", 8123) + assert ensured_identity_paths == [str(checkpoint_dir)] assert calls == [(str(checkpoint_dir), 0)] +@pytest.mark.asyncio +async def test_megatron_service_start_openai_server_shared_lora_bootstraps_step_zero( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + checkpoint_dir = tmp_path / "checkpoints" / "0000" + checkpoint_dir.mkdir(parents=True) + service = MegatronService( + model_name="megatron-shared", + base_model="Qwen/Qwen3-30B-A3B-Instruct-2507", + config=InternalModelConfig(), + output_dir=str(tmp_path), + ) + ensured_identity_paths: list[str] = [] + + monkeypatch.setattr( + "art.megatron.service.get_last_checkpoint_dir", + lambda _output_dir: str(checkpoint_dir), + ) + monkeypatch.setattr( + service, + "_ensure_identity_lora", + lambda path: ensured_identity_paths.append(path), + ) + monkeypatch.setattr(service, "_ensure_lora_adapter_config", lambda _path: None) + monkeypatch.setattr( + "art.megatron.service.dev.get_openai_server_config", + lambda **_kwargs: {"server_args": {"port": 8123}, "engine_args": {}}, + ) + monkeypatch.setattr( + "art.megatron.service.openai_server_task", + lambda **_kwargs: asyncio.sleep(0), + ) + + location = await service.start_openai_server({"server_args": {"port": 8123}}) + + assert location == ("0.0.0.0", 8123) + assert ensured_identity_paths == [str(checkpoint_dir)] + + @pytest.mark.asyncio async def test_megatron_service_register_lora_for_step_merged_sets_served_name( tmp_path: Path, From a33458baacb6e7bde5eea559687125c58d50342d Mon Sep 17 00:00:00 2001 From: Vivek Kalyan Date: Mon, 13 Apr 2026 16:26:29 -0700 Subject: [PATCH 5/9] test: Cover dedicated Megatron merged startup reset --- tests/unit/test_megatron_dedicated.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/unit/test_megatron_dedicated.py b/tests/unit/test_megatron_dedicated.py index 3387d252..0a75a29f 100644 --- a/tests/unit/test_megatron_dedicated.py +++ b/tests/unit/test_megatron_dedicated.py @@ -226,7 +226,7 @@ async def test_megatron_service_start_openai_server_merged_syncs_step_zero( ), output_dir=str(tmp_path), ) - calls: list[tuple[str, int]] = [] + calls: list[object] = [] ensured_identity_paths: list[str] = [] monkeypatch.setattr( @@ -244,6 +244,7 @@ async def test_megatron_service_start_openai_server_merged_syncs_step_zero( "_start_vllm_subprocess", lambda lora_path, port, config: asyncio.sleep(0, result=("127.0.0.1", port)), ) + monkeypatch.setattr(service, "_clear_pending_jobs", lambda: calls.append("clear")) monkeypatch.setattr( service, "_sync_dedicated_merged_weights", @@ -254,7 +255,7 @@ async def test_megatron_service_start_openai_server_merged_syncs_step_zero( assert location == ("127.0.0.1", 8123) assert ensured_identity_paths == [str(checkpoint_dir)] - assert calls == [(str(checkpoint_dir), 0)] + assert calls == ["clear", (str(checkpoint_dir), 0)] @pytest.mark.asyncio From ab4e7500876d72cc3ba58a6f3d68eee5069011df Mon Sep 17 00:00:00 2001 From: Vivek Kalyan Date: Mon, 13 Apr 2026 16:26:40 -0700 Subject: [PATCH 6/9] fix: Clear stale dedicated merged jobs on startup --- src/art/megatron/service.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/art/megatron/service.py b/src/art/megatron/service.py index e51fc8c7..058fcabd 100644 --- a/src/art/megatron/service.py +++ b/src/art/megatron/service.py @@ -618,6 +618,7 @@ async def start_openai_server( port = (config or {}).get("server_args", {}).get("port", 8000) location = await self._start_vllm_subprocess(lora_path, port, config) if self.rollout_weights_mode == "merged": + self._clear_pending_jobs() await self._sync_dedicated_merged_weights( lora_path=lora_path, step=self._latest_step, From 6eddd9675ccfd4467c78b6342cf79023708d9638 Mon Sep 17 00:00:00 2001 From: Vivek Kalyan Date: Mon, 13 Apr 2026 16:26:58 -0700 Subject: [PATCH 7/9] test: Cover compiled Megatron wrapper names --- tests/unit/test_megatron_dedicated.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/unit/test_megatron_dedicated.py b/tests/unit/test_megatron_dedicated.py index 0a75a29f..cc0a808b 100644 --- a/tests/unit/test_megatron_dedicated.py +++ b/tests/unit/test_megatron_dedicated.py @@ -18,6 +18,7 @@ MergedWeightTransferInitInfo, ) from art.megatron.service import MegatronService +from art.megatron.train import _unwrap_art_wrapper_name @pytest.mark.asyncio @@ -132,6 +133,21 @@ async def fake_create_subprocess_shell( assert seen["start_new_session"] is True +def test_unwrap_art_wrapper_name_strips_compiled_wrapper_segments() -> None: + assert ( + _unwrap_art_wrapper_name( + "module.module.decoder.layers.0._orig_mod.self_attention.linear_proj.linear_proj.weight" + ) + == "decoder.layers.0.self_attention.linear_proj.weight" + ) + assert ( + _unwrap_art_wrapper_name( + "module.module.decoder.layers.0._orig_mod.mlp.experts.linear_fc1.linear_fc1.weight7" + ) + == "decoder.layers.0.mlp.experts.linear_fc1.weight7" + ) + + @pytest.mark.asyncio async def test_megatron_service_start_openai_server_dedicated_starts_subprocess( tmp_path: Path, From f97f79cf7a7b6e68022e6407be6c79a789a8ef8c Mon Sep 17 00:00:00 2001 From: Vivek Kalyan Date: Mon, 13 Apr 2026 16:27:14 -0700 Subject: [PATCH 8/9] fix: Restore Megatron dedicated merged sync --- src/art/megatron/train.py | 59 +++++++++++++++++++++++---------------- 1 file changed, 35 insertions(+), 24 deletions(-) diff --git a/src/art/megatron/train.py b/src/art/megatron/train.py index d84bc4de..409e4b76 100644 --- a/src/art/megatron/train.py +++ b/src/art/megatron/train.py @@ -1388,6 +1388,12 @@ def _is_art_adapter_param_name(name: str) -> bool: def _unwrap_art_wrapper_name(name: str) -> str: while name.startswith("module."): name = name[len("module.") :] + while name.startswith("_orig_mod."): + name = name[len("_orig_mod.") :] + while "._orig_mod." in name: + name = name.replace("._orig_mod.", ".") + if name.endswith("._orig_mod"): + name = name[: -len("._orig_mod")] for wrapped, unwrapped in ( (".linear_proj.linear_proj.", ".linear_proj."), (".linear_qkv.linear_qkv.", ".linear_qkv."), @@ -1480,24 +1486,35 @@ def _build_art_merge_handlers( continue if not _is_language_transformer_layer_name(module_name): continue - prefix = f"language_model.decoder.layers.{module.layer_number - 1}" + prefixes = ( + f"decoder.layers.{module.layer_number - 1}", + f"language_model.decoder.layers.{module.layer_number - 1}", + ) linear_proj = getattr(module.self_attention, "linear_proj", None) if isinstance(linear_proj, SelfAttentionLinearProjLoRA): - exact_handlers[f"{prefix}.self_attention.linear_proj.weight"] = linear_proj + for prefix in prefixes: + exact_handlers[f"{prefix}.self_attention.linear_proj.weight"] = ( + linear_proj + ) linear_qkv = getattr(module.self_attention, "linear_qkv", None) if isinstance(linear_qkv, SelfAttentionLinearQKVLoRA): - exact_handlers[f"{prefix}.self_attention.linear_qkv.weight"] = linear_qkv + for prefix in prefixes: + exact_handlers[f"{prefix}.self_attention.linear_qkv.weight"] = ( + linear_qkv + ) experts = getattr(module.mlp, "experts", None) if experts is None: continue if isinstance(experts.linear_fc1, MLPExpertsLinearFC1LoRA): - prefix_handlers[f"{prefix}.mlp.experts.linear_fc1.weight"] = ( - experts.linear_fc1 - ) + for prefix in prefixes: + prefix_handlers[f"{prefix}.mlp.experts.linear_fc1.weight"] = ( + experts.linear_fc1 + ) if isinstance(experts.linear_fc2, MLPExpertsLinearFC2LoRA): - prefix_handlers[f"{prefix}.mlp.experts.linear_fc2.weight"] = ( - experts.linear_fc2 - ) + for prefix in prefixes: + prefix_handlers[f"{prefix}.mlp.experts.linear_fc2.weight"] = ( + experts.linear_fc2 + ) return exact_handlers, prefix_handlers @@ -1542,21 +1559,15 @@ def _merge_art_lora_into_hf_weights( return converted_weights_dict if isinstance(handler, MLPExpertsLinearFC1LoRA): for hf_name, base_weight in list(converted_weights_dict.items()): - delta = ( - torch.cat( - [ - _lora_delta( - handler.gate_lora, _expert_index_from_hf_name(hf_name) - ), - _lora_delta( - handler.up_lora, _expert_index_from_hf_name(hf_name) - ), - ], - dim=0, - ) - if _hf_name_has_indexed_expert(hf_name) - else _stack_moe_fc1_deltas(handler) - ) + if _hf_name_has_indexed_expert(hf_name): + expert_idx = _expert_index_from_hf_name(hf_name) + if ".gate_proj." in hf_name: + delta = _lora_delta(handler.gate_lora, expert_idx) + else: + assert ".up_proj." in hf_name, hf_name + delta = _lora_delta(handler.up_lora, expert_idx) + else: + delta = _stack_moe_fc1_deltas(handler) converted_weights_dict[hf_name] = _merge_delta_into_weight( hf_name, base_weight, From f55698a42747128385cc199d25c97722b083bf10 Mon Sep 17 00:00:00 2001 From: Vivek Kalyan Date: Fri, 17 Apr 2026 19:46:32 -0700 Subject: [PATCH 9/9] fix: Restore Megatron bridge typing --- src/art/megatron/train.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/src/art/megatron/train.py b/src/art/megatron/train.py index 409e4b76..f5ab2d8a 100644 --- a/src/art/megatron/train.py +++ b/src/art/megatron/train.py @@ -875,9 +875,7 @@ def iter_named_modules(model_chunks: list[MegatronModule]) -> Any: def _is_language_transformer_layer_name(module_name: str) -> bool: while module_name.startswith("module."): module_name = module_name.removeprefix("module.") - return module_name.startswith( - ("decoder.layers.", "language_model.decoder.layers.") - ) + return module_name.startswith(("decoder.layers.", "language_model.decoder.layers.")) def load_adapter_into_model( @@ -1606,14 +1604,15 @@ def _build_art_conversion_tasks(runtime: TrainingRuntime) -> list[Any]: mapping_registry = bridge._model_bridge.mapping_registry() hf_source = bridge.hf_pretrained.state.source hf_keys = set(hf_source.get_all_keys()) - model_config = runtime.model[0].config + megatron_chunks = as_megatron_api_chunks(runtime.model) + model_config = megatron_chunks[0].config tasks: list[Any] = [] - for vp_stage, model in enumerate(runtime.model): + for vp_stage, model in enumerate(megatron_chunks): for local_name, _ in chain(model.named_parameters(), persistent_buffers(model)): if "_extra_state" in local_name or _is_art_adapter_param_name(local_name): continue global_name = _megatron_local_name_to_global( - runtime.model, + megatron_chunks, model_config, _unwrap_art_wrapper_name(local_name), vp_stage, @@ -1621,11 +1620,13 @@ def _build_art_conversion_tasks(runtime: TrainingRuntime) -> list[Any]: mapping = mapping_registry.megatron_to_hf_lookup(global_name) if mapping is None or not _mapping_hf_weights_exist(mapping, hf_keys): continue - local_module, local_weights = get_module_and_param_from_name( - runtime.model, + module_and_param = get_module_and_param_from_name( + megatron_chunks, local_name, vp_stage, ) + local_module = module_and_param[0] + local_weights = module_and_param[1] if local_module is not None and not hasattr(local_module, "config"): setattr(local_module, "config", model_config) tasks.append( @@ -1649,7 +1650,8 @@ def _iter_merged_vllm_weights(runtime: TrainingRuntime) -> Any: assert bridge is not None model_bridge = bridge._model_bridge hf_state_dict = bridge.hf_pretrained.state - exact_handlers, prefix_handlers = _build_art_merge_handlers(runtime.model) + megatron_chunks = as_megatron_api_chunks(runtime.model) + exact_handlers, prefix_handlers = _build_art_merge_handlers(megatron_chunks) for task in _build_art_conversion_tasks(runtime): converted_weights_dict = task.mapping.megatron_to_hf( task.param_weight,