[WIP] Feat/tensor colocated weight sync #1164
HT-Yuan wants to merge 8 commits into inclusionAI:main
Conversation
Code Review
This pull request introduces a new 'tensor' weight update mode for colocated training and inference, utilizing CUDA IPC for efficient transfers. It implements a two-phase update process in the FSDP engine—staging parameters to CPU pinned memory before transferring them to the inference engine—and adds backend support for both SGLang and vLLM. Review feedback highlights opportunities to reduce code duplication in parameter selection and request building, improve network efficiency by reusing HTTP sessions across buckets, and enhance performance by removing expensive and unnecessary GPU cache clearing calls during the update loop.
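The backend dispatch this summary describes can be sketched in a few lines. The sketch below maps a `tensor_target_backend` string to a sender callable, defaulting to SGLang when none is given; the sender functions are placeholders, not AReaL's actual transport implementations, and all names here are hypothetical.

```python
# Hedged sketch of backend-aware dispatch in the spirit of this PR:
# choose a tensor-transport path from the target backend name.
def make_tensor_backend(tensor_target_backend=None):
    def send_via_sglang(bucket):
        return f"sglang:{len(bucket)}"   # FlattenedTensorBucket path (placeholder)

    def send_via_vllm(bucket):
        return f"vllm:{len(bucket)}"     # IPCWeightTransferEngine path (placeholder)

    backends = {"sglang": send_via_sglang, "vllm": send_via_vllm}
    name = tensor_target_backend or "sglang"  # None falls back to "sglang"
    try:
        return backends[name]
    except KeyError:
        raise ValueError(f"unsupported tensor backend: {name}")

send = make_tensor_backend("vllm")
print(send([("w", b"...")]))  # → vllm:1
```

Keeping the dispatch in one factory means callers that never pass a backend keep the old SGLang behavior unchanged.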
```python
if self.config.use_lora:
    param_iterator = (
        (name, param)
        for name, param in self._get_model_name_parameters(meta)
        if param.requires_grad
    )
else:
    param_iterator = self._get_model_name_parameters(meta)
```
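The selection logic above can be isolated into a small helper. This is a minimal sketch using plain stand-in objects instead of torch parameters; `select_parameters` is a hypothetical name, not part of the PR.

```python
# Sketch of LoRA-aware parameter selection: with LoRA enabled, only
# trainable (requires_grad) parameters are yielded for syncing.
from types import SimpleNamespace

def select_parameters(named_params, use_lora):
    """Yield (name, param) pairs; under LoRA, keep only trainable ones."""
    if use_lora:
        return ((n, p) for n, p in named_params if p.requires_grad)
    return iter(named_params)

params = [
    ("base.weight", SimpleNamespace(requires_grad=False)),
    ("lora_A.weight", SimpleNamespace(requires_grad=True)),
]
print([n for n, _ in select_parameters(params, use_lora=True)])  # → ['lora_A.weight']
```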
```python
    with tms_context:
        if current_platform.device_type == "cuda" and torch.cuda.is_available():
            current_platform.set_device(int(os.environ.get("LOCAL_RANK", 0)))

        bucket: list[tuple[str, torch.Tensor]] = []
        bucket_bytes = 0

        for name, cpu_tensor in staged:
            tensor_bytes = cpu_tensor.numel() * cpu_tensor.element_size()

            if bucket_bytes + tensor_bytes > weight_chunked_mem_size and bucket:
                self._flush_colocated_tensor_bucket(bucket, meta)
                bucket = []
                bucket_bytes = 0

            gpu_tensor = cpu_tensor.to(
                current_platform.current_device(), non_blocking=False
            )
            bucket.append((name, gpu_tensor))
            bucket_bytes += tensor_bytes

        if bucket:
            self._flush_colocated_tensor_bucket(bucket, meta)
finally:
    staged.clear()
```
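The core of the loop above is greedy, size-capped bucketing: accumulate tensors until adding the next one would exceed the cap, then flush. A minimal sketch with plain `(name, num_bytes)` pairs standing in for CPU tensors (the GPU copy and flush call are elided):

```python
# Greedy size-capped bucketing, as in the update loop above: flush a
# bucket once the next tensor would push it past cap_bytes.
def bucketize(staged, cap_bytes):
    buckets, bucket, bucket_bytes = [], [], 0
    for name, nbytes in staged:
        if bucket_bytes + nbytes > cap_bytes and bucket:
            buckets.append(bucket)           # flush the full bucket
            bucket, bucket_bytes = [], 0
        bucket.append(name)                  # GPU transfer elided here
        bucket_bytes += nbytes
    if bucket:                               # flush the remainder
        buckets.append(bucket)
    return buckets

print(bucketize([("a", 4), ("b", 4), ("c", 4)], cap_bytes=6))
# → [['a'], ['b'], ['c']]
```

Note that a single tensor larger than the cap still gets its own bucket, since the flush-before-append check only fires when the bucket is non-empty.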
In _apply_colocated_tensor_weights, consider creating a single aiohttp.ClientSession and passing it down to the flush methods. Currently, a new session (and connection pool) is created for every bucket in _send_tensor_to_servers, which is inefficient when processing many buckets during a weight update.
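The suggested refactor looks roughly like this. The `Session` class below is a pure-Python stand-in for `aiohttp.ClientSession`, and all names are hypothetical; the point is that one session (and its connection pool) is created once in the outer loop and passed into every bucket flush.

```python
# Sketch of the review suggestion: one shared session for the whole
# weight update instead of a fresh session per bucket.
import asyncio

class Session:
    """Stand-in for aiohttp.ClientSession; a real one owns a connection pool."""
    def __init__(self):
        self.requests = 0
    async def post(self, addr, payload):
        self.requests += 1        # real code would POST to addr here

async def flush_bucket(session, bucket, addresses):
    # The flush reuses the caller's session rather than opening its own.
    await asyncio.gather(*(session.post(a, bucket) for a in addresses))

async def apply_weights(buckets, addresses):
    session = Session()           # created once per weight update
    for bucket in buckets:
        await flush_bucket(session, bucket, addresses)
    return session.requests

n = asyncio.run(apply_weights([["w1"], ["w2"]], ["srv0", "srv1"]))
print(n)  # → 4 (2 buckets × 2 servers, one shared session)
```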
```python
if current_platform.device_type == "cuda" and torch.cuda.is_available():
    torch.cuda.ipc_collect()
    torch.cuda.empty_cache()
```
torch.cuda.empty_cache() is an expensive operation that synchronizes the GPU and can significantly degrade performance, especially when called repeatedly in a loop (as it is here via _apply_colocated_tensor_weights). Since torch.cuda.ipc_collect() is already called to release IPC handles, consider removing empty_cache() or moving it outside the loop to avoid unnecessary overhead.
Suggested change:

```diff
-if current_platform.device_type == "cuda" and torch.cuda.is_available():
-    torch.cuda.ipc_collect()
-    torch.cuda.empty_cache()
+if current_platform.device_type == "cuda" and torch.cuda.is_available():
+    torch.cuda.ipc_collect()
```
```python
def _send_tensor_to_servers(
    self,
    serialized_named_tensors: list[str],
    addresses: list[str],
    weight_version: str | None = None,
) -> None:
    """Send serialized tensor data to SGLang servers via HTTP."""
    import asyncio

    import aiohttp
    import uvloop

    from areal.infra.utils.http import arequest_with_retry, get_default_connector

    payload: dict[str, Any] = {
        "serialized_named_tensors": serialized_named_tensors,
        "load_format": "flattened_bucket",
        "flush_cache": False,
    }
    if weight_version is not None:
        payload["weight_version"] = weight_version

    async def _fn():
        async with aiohttp.ClientSession(
            timeout=aiohttp.ClientTimeout(total=600),
            read_bufsize=1024 * 1024 * 10,
            connector=get_default_connector(),
        ) as session:
            jobs = [
                arequest_with_retry(
                    session=session,
                    addr=addr,
                    endpoint="/update_weights_from_tensor",
                    payload=payload,
                    method="POST",
                    max_retries=1,
                    timeout=600,
                )
                for addr in addresses
            ]
            await asyncio.gather(*jobs)

    uvloop.run(_fn())
```
The _send_tensor_to_servers method and the payload building logic in _flush_sglang_tensor_bucket (lines 1545-1550) duplicate logic already present in areal.infra.remote_inf_engine._update_weights_from_tensor and the backend's build_tensor_weight_update_requests. Consider refactoring to use the shared helper from remote_inf_engine to improve maintainability and ensure consistency across different weight update paths.
@garrett4wade

This pull request has been automatically marked as stale because it has not had recent activity within the last 14 days. Please add a comment or push new commits to keep it active. Thank you for your contribution!
Description
Add backend-aware dispatching for colocated tensor weight synchronization, enabling vLLM's native `IPCWeightTransferEngine` as an alternative to the existing SGLang `FlattenedTensorBucket` + `MultiprocessingSerializer` path.

Previously, the tensor weight update path in `FSDPEngine` was hardcoded to SGLang's serialization format. This PR introduces a `tensor_target_backend` parameter that flows from `rl_trainer.py` → `train_controller.py` → `fsdp_engine.py`, allowing the engine to dispatch to the correct transport mechanism based on the rollout backend.

Key changes

- `vllm_remote.py` — `VLLMBackend` gains `send_tensor_weight_update()`, which delegates to vLLM's `IPCWeightTransferEngine.trainer_send_weights()`; `RemotevLLMEngine` gains `update_weights_from_tensor()`.
- `fsdp_engine.py` — `_flush_colocated_tensor_bucket()` refactored to dispatch based on `supports_direct_tensor_weight_update`; SGLang logic extracted to `_flush_sglang_tensor_bucket()`; added `_make_tensor_backend()` factory.
- `remote_inf_engine.py` — `RemoteInfBackendProtocol` gains a `build_tensor_weight_update_requests()` method declaration.
- `engine_api.py` / `train_controller.py` / `megatron_engine.py` / `archon_engine.py` — `connect_engine()` signature extended with `tensor_target_backend: str | None` for interface alignment.
- `rl_trainer.py` — passes `self.rollout_alloc.backend` as `tensor_target_backend`.

Related Issue
Fixes #(issue)
Type of Change

Checklist

- `pre-commit run --all-files`
- `./docs/build_all.sh`
- Target branch: `main`
- `/review-pr` command
- `/create-pr`

Breaking Change Details (if applicable):
N/A — The `tensor_target_backend` parameter is optional with a default of `None` (falls back to `"sglang"`), so existing callers are unaffected.

Additional Context
Architecture