
[WIP] Feat/tensor colocated weight sync #1164

Closed
HT-Yuan wants to merge 8 commits into inclusionAI:main from HT-Yuan:feat/tensor-colocated-weight-sync

Conversation

HT-Yuan (Contributor) commented Apr 11, 2026

Description

Add backend-aware dispatching for colocated tensor weight synchronization, enabling vLLM's native IPCWeightTransferEngine as an alternative to the existing SGLang FlattenedTensorBucket + MultiprocessingSerializer path.

Previously, the tensor weight update path in FSDPEngine was hardcoded to SGLang's serialization format. This PR introduces a tensor_target_backend parameter that flows from rl_trainer.py → train_controller.py → fsdp_engine.py, allowing the engine to dispatch to the correct transport mechanism based on the rollout backend.

Key changes

  1. vllm_remote.py — VLLMBackend gains send_tensor_weight_update(), which delegates to vLLM's IPCWeightTransferEngine.trainer_send_weights(); RemotevLLMEngine gains update_weights_from_tensor().
  2. fsdp_engine.py — _flush_colocated_tensor_bucket() refactored to dispatch based on supports_direct_tensor_weight_update; SGLang logic extracted into _flush_sglang_tensor_bucket(); added a _make_tensor_backend() factory (see the sketch after this list).
  3. remote_inf_engine.py — RemoteInfBackendProtocol gains a build_tensor_weight_update_requests() method declaration.
  4. engine_api.py / train_controller.py / megatron_engine.py / archon_engine.py — connect_engine() signature extended with tensor_target_backend: str | None for interface alignment.
  5. rl_trainer.py — passes self.rollout_alloc.backend as tensor_target_backend.
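
The dispatch in change 2 can be pictured as below. This is a minimal sketch: the method and attribute names are those listed above, but the body is illustrative rather than the committed implementation.

```python
# Sketch only: method/attribute names come from the key changes above;
# the body is assumed, not copied from this PR.
def _flush_colocated_tensor_bucket(self, bucket, meta):
    backend = self._make_tensor_backend()  # resolved from tensor_target_backend
    if backend.supports_direct_tensor_weight_update:
        # vLLM path: hand the tensors to IPCWeightTransferEngine via the backend.
        backend.send_tensor_weight_update(bucket, meta)
    else:
        # SGLang path: FlattenedTensorBucket + MultiprocessingSerializer,
        # extracted into its own helper by this PR.
        self._flush_sglang_tensor_bucket(bucket, meta)
```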

Related Issue

Fixes #(issue)

Type of Change

  • 🐛 Bug fix
  • ✨ New feature
  • 💥 Breaking change
  • 📝 Documentation update
  • ♻️ Refactoring
  • ⚡ Performance improvement
  • ✅ Test coverage improvement

Checklist

  • I have read the Contributing Guide
  • Pre-commit hooks pass (pre-commit run --all-files)
  • Relevant tests pass; new tests added for new functionality
  • Documentation updated (if applicable; built with ./docs/build_all.sh)
  • Branch is up to date with main
  • Self-reviewed via /review-pr command
  • This PR was created by a coding agent via /create-pr
  • This PR is a breaking change

Breaking Change Details (if applicable):

N/A — The tensor_target_backend parameter is optional with a default of None (falls back to "sglang"), so existing callers are unaffected.
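
A minimal sketch of that fallback; only the tensor_target_backend parameter is taken from this PR, and the rest of the signature is elided for illustration:

```python
def connect_engine(self, *args, tensor_target_backend: str | None = None, **kwargs):
    # None preserves the pre-PR behavior: the SGLang serialization path.
    backend_name = tensor_target_backend or "sglang"
    ...
```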

Additional Context

Architecture

HT-Yuan marked this pull request as draft April 11, 2026 11:04
gemini-code-assist (Bot) left a comment

Code Review

This pull request introduces a new 'tensor' weight update mode for colocated training and inference, utilizing CUDA IPC for efficient transfers. It implements a two-phase update process in the FSDP engine—staging parameters to CPU pinned memory before transferring them to the inference engine—and adds backend support for both SGLang and vLLM. Review feedback highlights opportunities to reduce code duplication in parameter selection and request building, improve network efficiency by reusing HTTP sessions across buckets, and enhance performance by removing expensive and unnecessary GPU cache clearing calls during the update loop.

Comment on lines +1403 to +1410
```python
if self.config.use_lora:
    param_iterator = (
        (name, param)
        for name, param in self._get_model_name_parameters(meta)
        if param.requires_grad
    )
else:
    param_iterator = self._get_model_name_parameters(meta)
```

Severity: medium

The logic for selecting parameters to update (filtering by requires_grad when use_lora is enabled) is duplicated from _update_weights_from_distributed. Consider refactoring this into a helper method (e.g., _get_trainable_parameters) to ensure consistency and reduce code duplication.
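
One possible shape for that helper. The name _get_trainable_parameters is the reviewer's example; the body below is a sketch, not the repository's code.

```python
def _get_trainable_parameters(self, meta):
    """Yield the (name, param) pairs that should be synced to inference."""
    params = self._get_model_name_parameters(meta)
    if not self.config.use_lora:
        yield from params
        return
    # With LoRA enabled, only adapter weights require grad and need syncing.
    for name, param in params:
        if param.requires_grad:
            yield name, param
```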

Comment on lines +1451 to +1475
```python
with tms_context:
    if current_platform.device_type == "cuda" and torch.cuda.is_available():
        current_platform.set_device(int(os.environ.get("LOCAL_RANK", 0)))

    bucket: list[tuple[str, torch.Tensor]] = []
    bucket_bytes = 0

    for name, cpu_tensor in staged:
        tensor_bytes = cpu_tensor.numel() * cpu_tensor.element_size()

        if bucket_bytes + tensor_bytes > weight_chunked_mem_size and bucket:
            self._flush_colocated_tensor_bucket(bucket, meta)
            bucket = []
            bucket_bytes = 0

        gpu_tensor = cpu_tensor.to(
            current_platform.current_device(), non_blocking=False
        )
        bucket.append((name, gpu_tensor))
        bucket_bytes += tensor_bytes

    if bucket:
        self._flush_colocated_tensor_bucket(bucket, meta)
finally:  # closes a try: block that begins above this excerpt
    staged.clear()
```

Severity: medium

In _apply_colocated_tensor_weights, consider creating a single aiohttp.ClientSession and passing it down to the flush methods. Currently, a new session (and connection pool) is created for every bucket in _send_tensor_to_servers, which is inefficient when processing many buckets during a weight update.
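
A sketch of that suggestion, assuming the flush path grows an optional session parameter; _iter_buckets and the async signature are hypothetical.

```python
import aiohttp

from areal.infra.utils.http import get_default_connector


async def _apply_colocated_tensor_weights(self, staged, meta):
    # One session (and one connection pool) shared across every bucket flush,
    # instead of a new session per bucket inside _send_tensor_to_servers.
    async with aiohttp.ClientSession(
        timeout=aiohttp.ClientTimeout(total=600),
        connector=get_default_connector(),
    ) as session:
        for bucket in self._iter_buckets(staged):  # hypothetical bucketing helper
            await self._flush_colocated_tensor_bucket(bucket, meta, session=session)
```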

Comment on lines +1511 to +1513
```python
if current_platform.device_type == "cuda" and torch.cuda.is_available():
    torch.cuda.ipc_collect()
    torch.cuda.empty_cache()
```

Severity: medium

torch.cuda.empty_cache() is an expensive operation that synchronizes the GPU and can significantly degrade performance, especially when called repeatedly in a loop (as it is here via _apply_colocated_tensor_weights). Since torch.cuda.ipc_collect() is already called to release IPC handles, consider removing empty_cache() or moving it outside the loop to avoid unnecessary overhead.

Suggested change:

```diff
 if current_platform.device_type == "cuda" and torch.cuda.is_available():
     torch.cuda.ipc_collect()
-    torch.cuda.empty_cache()
```

Comment thread on areal/engine/fsdp_engine.py (outdated)
Comment on lines +1569 to +1611
```python
def _send_tensor_to_servers(
    self,
    serialized_named_tensors: list[str],
    addresses: list[str],
    weight_version: str | None = None,
) -> None:
    """Send serialized tensor data to SGLang servers via HTTP."""
    import asyncio

    import aiohttp
    import uvloop

    from areal.infra.utils.http import arequest_with_retry, get_default_connector

    payload: dict[str, Any] = {
        "serialized_named_tensors": serialized_named_tensors,
        "load_format": "flattened_bucket",
        "flush_cache": False,
    }
    if weight_version is not None:
        payload["weight_version"] = weight_version

    async def _fn():
        async with aiohttp.ClientSession(
            timeout=aiohttp.ClientTimeout(total=600),
            read_bufsize=1024 * 1024 * 10,
            connector=get_default_connector(),
        ) as session:
            jobs = [
                arequest_with_retry(
                    session=session,
                    addr=addr,
                    endpoint="/update_weights_from_tensor",
                    payload=payload,
                    method="POST",
                    max_retries=1,
                    timeout=600,
                )
                for addr in addresses
            ]
            await asyncio.gather(*jobs)

    uvloop.run(_fn())
```

Severity: medium

The _send_tensor_to_servers method and the payload building logic in _flush_sglang_tensor_bucket (lines 1545-1550) duplicate logic already present in areal.infra.remote_inf_engine._update_weights_from_tensor and the backend's build_tensor_weight_update_requests. Consider refactoring to use the shared helper from remote_inf_engine to improve maintainability and ensure consistency across different weight update paths.
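
A rough sketch of that consolidation; build_tensor_weight_update_requests comes from this PR, while _serialize_bucket, self.backend, and self.remote_engine are assumed glue.

```python
def _flush_sglang_tensor_bucket(self, bucket, meta):
    # Serialize once, then reuse the shared request-building and HTTP logic
    # from remote_inf_engine instead of duplicating it here.
    serialized = self._serialize_bucket(bucket)  # hypothetical helper
    requests = self.backend.build_tensor_weight_update_requests(serialized, meta)
    self.remote_engine.update_weights_from_tensor(requests)
```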

HT-Yuan (Contributor, Author) commented Apr 23, 2026

@garrett4wade
As mentioned earlier, regarding #1214 and #1157, this PR should probably be closed. At present, CUDA IPC communication cannot be performed without enabling TMS. In my opinion, it makes little sense to temporarily store training weights on the CPU as Slime does. What do you think?

garrett4wade (Collaborator)
@HT-Yuan I think the high-level decision is to adopt the CUDA IPC and P2P primitives in awex, together with the frontend developed in #1214. We can add SGLang extensions within AReaL to allow customized IPC communication endpoints. The CPU serialization approach should be abandoned.

github-actions (Bot) commented May 8, 2026

This pull request has been automatically marked as stale because it has not had recent activity within the last 14 days.

Please add a comment or push new commits to keep it active.

Thank you for your contribution!

github-actions (Bot) added the stale label May 8, 2026
HT-Yuan closed this May 8, 2026
