Skip to content

Commit db87da7

Browse files
committed
feat: [NemoRL] Introduce WeightSynchronizer ABC with IPC/HTTP/NCCL transports
Signed-off-by: Saurabh Mishra <sauramishra@nvidia.com> Co-authored-by: Saurabh Mishra <sauramishra@nvidia.com> [NemoRL]: fix: Address review feedback on WeightSynchronizer ABC Signed-off-by: Saurabh Mishra <sauramishra@nvidia.com>
1 parent 30afecc commit db87da7

9 files changed

Lines changed: 1051 additions & 0 deletions

File tree

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Constants for generation backend names.
16+
17+
These should be used instead of raw string literals when checking or
18+
comparing backend names in config values.
19+
"""
20+
21+
VLLM_BACKEND = "vllm"
22+
SGLANG_BACKEND = "sglang"
23+
MEGATRON_BACKEND = "megatron"

nemo_rl/weight_sync/__init__.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from nemo_rl.weight_sync.factory import create_weight_synchronizer
16+
from nemo_rl.weight_sync.interfaces import WeightSynchronizer
17+
18+
__all__ = [
19+
"WeightSynchronizer",
20+
"create_weight_synchronizer",
21+
]
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""NCCL collective weight synchronizer for non-colocated deployments.
16+
17+
Handles weight transfer between policy and generation workers running on
18+
separate GPU clusters using NCCL collective communication. The policy
19+
broadcasts its weights, and generation workers receive them via the
20+
established NCCL process group.
21+
22+
Lifecycle per sync:
23+
1. policy.broadcast_weights_for_collective() -- send via NCCL
24+
generation.update_weights_from_collective() -- receive via NCCL
25+
2. Verify transfer success
26+
27+
No offload/restore steps are needed since policy and generation run on
28+
separate GPUs with dedicated memory.
29+
"""
30+
31+
from contextlib import nullcontext
32+
from typing import Any, Optional
33+
34+
import ray
35+
36+
from nemo_rl.utils.timer import Timer
37+
from nemo_rl.weight_sync.interfaces import WeightSynchronizer
38+
39+
40+
class CollectiveWeightSynchronizer(WeightSynchronizer):
41+
"""Weight synchronizer using NCCL collectives for non-colocated deployments.
42+
43+
Policy and generation workers run on separate GPU clusters. Weights are
44+
synchronized via NCCL broadcast over a pre-established process group.
45+
46+
Args:
47+
policy: Policy object implementing ColocatablePolicyInterface.
48+
generation: Generation object implementing GenerationInterface.
49+
train_cluster: RayVirtualCluster for the training workers, used to
50+
obtain the master address/port and world size for collective init.
51+
inference_cluster: RayVirtualCluster for the inference workers.
52+
"""
53+
54+
def __init__(
55+
self,
56+
policy: Any,
57+
generation: Any,
58+
train_cluster: Any,
59+
inference_cluster: Any,
60+
):
61+
self._policy = policy
62+
self._generation = generation
63+
self._train_cluster = train_cluster
64+
self._inference_cluster = inference_cluster
65+
self._stale = True
66+
67+
def sync_weights(
68+
self,
69+
*,
70+
timer: Optional[Timer] = None,
71+
kv_scales: Optional[dict[str, float]] = None,
72+
) -> None:
73+
timer_context = (
74+
timer.time("prepare_for_generation/transfer_and_update_weights")
75+
if timer is not None
76+
else nullcontext()
77+
)
78+
with timer_context:
79+
futures_train = self._policy.broadcast_weights_for_collective(
80+
kv_scales=kv_scales
81+
)
82+
futures_inference = (
83+
self._generation.update_weights_from_collective()
84+
)
85+
86+
ray.get(futures_train)
87+
results = ray.get(futures_inference)
88+
update_success = all(
89+
result for result in results if result is not None
90+
)
91+
92+
if not update_success:
93+
raise RuntimeError(
94+
"Weight transfer failed during NCCL collective sync. "
95+
"This often indicates an issue with the NCCL process group "
96+
"or the generation backend worker."
97+
)
98+
99+
self._stale = False
100+
101+
@property
102+
def is_stale(self) -> bool:
103+
return self._stale
104+
105+
def mark_stale(self) -> None:
106+
self._stale = True
107+
108+
def init_communicator(self) -> None:
109+
# prepare_refit_info is called before init_collective. This matches
110+
# distillation.py ordering. Neither call depends on the other today,
111+
# but we document this as the canonical ordering for future reference.
112+
state_dict_info = self._policy.prepare_refit_info()
113+
self._generation.prepare_refit_info(state_dict_info)
114+
115+
ip, port = self._train_cluster.get_master_address_and_port()
116+
train_world_size = self._train_cluster.world_size()
117+
inference_world_size = self._inference_cluster.world_size()
118+
world_size = train_world_size + inference_world_size
119+
120+
futures_train = self._policy.init_collective(
121+
ip, port, world_size, train_world_size=train_world_size
122+
)
123+
futures_inference = self._generation.init_collective(
124+
ip, port, world_size, train_world_size=train_world_size
125+
)
126+
ray.get(futures_train + futures_inference)
127+
128+
def shutdown(self) -> None:
129+
# The NCCL process group lifecycle is managed by Ray actor teardown.
130+
# Explicit destroy_process_group() is not needed here because the
131+
# workers that own the group are destroyed when the cluster shuts down.
132+
pass

nemo_rl/weight_sync/factory.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Factory for creating WeightSynchronizer instances.
16+
17+
Selects the appropriate weight synchronizer based on the deployment
18+
topology (colocated vs. non-colocated) and the generation backend
19+
(vLLM uses IPC/ZMQ, SGLang uses HTTP, non-colocated uses NCCL).
20+
"""
21+
22+
from typing import Any, Optional
23+
24+
from nemo_rl.models.generation.constants import (
25+
MEGATRON_BACKEND,
26+
SGLANG_BACKEND,
27+
VLLM_BACKEND,
28+
)
29+
from nemo_rl.weight_sync.interfaces import WeightSynchronizer
30+
31+
32+
def create_weight_synchronizer(
33+
policy: Any,
34+
generation: Any,
35+
generation_backend: str,
36+
colocated: bool,
37+
train_cluster: Optional[Any] = None,
38+
inference_cluster: Optional[Any] = None,
39+
refit_buffer_size_gb: Optional[int] = None,
40+
) -> WeightSynchronizer:
41+
"""Create the appropriate WeightSynchronizer for the given deployment.
42+
43+
Args:
44+
policy: Policy object (ColocatablePolicyInterface).
45+
generation: Generation object (GenerationInterface).
46+
generation_backend: Name of the generation backend ("vllm", "sglang", "megatron").
47+
colocated: Whether policy and generation share the same GPUs.
48+
train_cluster: RayVirtualCluster for training workers (required for non-colocated).
49+
inference_cluster: RayVirtualCluster for inference workers (required for non-colocated).
50+
refit_buffer_size_gb: Optional fixed buffer size for IPC weight staging.
51+
52+
Returns:
53+
A WeightSynchronizer instance appropriate for the deployment topology.
54+
55+
Raises:
56+
NotImplementedError: If the requested configuration is not supported.
57+
ValueError: If required arguments are missing.
58+
"""
59+
_SUPPORTED_BACKENDS = {VLLM_BACKEND, SGLANG_BACKEND, MEGATRON_BACKEND}
60+
if generation_backend not in _SUPPORTED_BACKENDS:
61+
raise ValueError(
62+
f"Unknown generation backend {generation_backend!r}. "
63+
f"Supported backends: {sorted(_SUPPORTED_BACKENDS)}"
64+
)
65+
66+
if colocated:
67+
if generation_backend == SGLANG_BACKEND:
68+
from nemo_rl.weight_sync.http_weight_synchronizer import (
69+
HTTPWeightSynchronizer,
70+
)
71+
72+
return HTTPWeightSynchronizer(
73+
policy=policy,
74+
generation=generation,
75+
)
76+
elif generation_backend in (VLLM_BACKEND, MEGATRON_BACKEND):
77+
from nemo_rl.weight_sync.ipc_weight_synchronizer import (
78+
IPCWeightSynchronizer,
79+
)
80+
81+
return IPCWeightSynchronizer(
82+
policy=policy,
83+
generation=generation,
84+
refit_buffer_size_gb=refit_buffer_size_gb,
85+
)
86+
else:
87+
if generation_backend == SGLANG_BACKEND:
88+
raise NotImplementedError(
89+
"SGLang does not support non-colocated inference mode."
90+
)
91+
if train_cluster is None or inference_cluster is None:
92+
raise ValueError(
93+
"train_cluster and inference_cluster are required "
94+
"for non-colocated weight synchronization."
95+
)
96+
97+
from nemo_rl.weight_sync.collective_weight_synchronizer import (
98+
CollectiveWeightSynchronizer,
99+
)
100+
101+
return CollectiveWeightSynchronizer(
102+
policy=policy,
103+
generation=generation,
104+
train_cluster=train_cluster,
105+
inference_cluster=inference_cluster,
106+
)
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""HTTP weight synchronizer for colocated SGLang generation.
16+
17+
Handles weight transfer between a colocated policy and SGLang generation
18+
backend using HTTP streaming. SGLang exposes an HTTP endpoint for weight
19+
updates, so the policy streams weights directly to SGLang servers.
20+
21+
Lifecycle per sync:
22+
1. policy.offload_before_refit() -- free GPU for weight staging
23+
2. generation.prepare_for_generation(tags=["weights"]) -- allocate buffers
24+
3. generation.invalidate_kv_cache() -- clear stale KV cache
25+
4. policy.stream_weights_via_http() -- push weights via HTTP
26+
5. policy.offload_after_refit() -- restore optimizer state
27+
6. generation.prepare_for_generation(tags=["kv_cache"]) -- rebuild KV cache
28+
"""
29+
30+
from contextlib import nullcontext
31+
from typing import Any, Optional
32+
33+
import ray
34+
35+
from nemo_rl.utils.timer import Timer
36+
from nemo_rl.weight_sync.interfaces import WeightSynchronizer
37+
38+
39+
class HTTPWeightSynchronizer(WeightSynchronizer):
40+
"""Weight synchronizer using HTTP for colocated SGLang deployments.
41+
42+
Both the policy and generation workers run on the same GPUs. Weights
43+
are streamed to SGLang servers via their HTTP weight-update API.
44+
45+
Args:
46+
policy: Policy object implementing ColocatablePolicyInterface.
47+
generation: SGLangGeneration instance exposing get_sglang_url_to_gpu_uuids().
48+
"""
49+
50+
def __init__(self, policy: Any, generation: Any):
51+
self._policy = policy
52+
self._generation = generation
53+
self._stale = True
54+
55+
def sync_weights(
56+
self,
57+
*,
58+
timer: Optional[Timer] = None,
59+
kv_scales: Optional[dict[str, float]] = None,
60+
) -> None:
61+
self._policy.offload_before_refit()
62+
self._generation.prepare_for_generation(tags=["weights"])
63+
64+
timer_context = (
65+
timer.time("prepare_for_generation/transfer_and_update_weights")
66+
if timer is not None
67+
else nullcontext()
68+
)
69+
with timer_context:
70+
sglang_url_to_gpu_uuids = (
71+
self._generation.get_sglang_url_to_gpu_uuids()
72+
)
73+
74+
flush_success = self._generation.invalidate_kv_cache()
75+
if not flush_success:
76+
print(
77+
"SGLang KV cache invalidation failed before weight update. "
78+
)
79+
80+
futures_train = self._policy.stream_weights_via_http(
81+
sglang_url_to_gpu_uuids=sglang_url_to_gpu_uuids,
82+
)
83+
ray.get(futures_train)
84+
85+
self._policy.offload_after_refit()
86+
self._generation.prepare_for_generation(tags=["kv_cache"])
87+
self._stale = False
88+
89+
@property
90+
def is_stale(self) -> bool:
91+
return self._stale
92+
93+
def mark_stale(self) -> None:
94+
self._stale = True
95+
96+
def init_communicator(self) -> None:
97+
state_dict_info = self._policy.prepare_refit_info()
98+
self._generation.prepare_refit_info(state_dict_info)
99+
100+
def shutdown(self) -> None:
101+
pass

0 commit comments

Comments
 (0)