Skip to content

Commit fa0cb08

Browse files
committed
Refactor and Optimize
Signed-off-by: Hollow Man <hollowman@opensuse.org>
1 parent f08d20d commit fa0cb08

12 files changed

Lines changed: 1146 additions & 652 deletions

File tree

examples/configs/distillation_math.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,8 @@ policy: &POLICY_BASE
206206
dtype: ${...precision}
207207
transport: "sparse_indices" # dense, sparse_indices, or sparse_bitmask
208208
full_sync_interval: 20
209+
sparse_bucket_size_bytes: 5368709120 # 5 GiB
210+
delta_load_batch_size_bytes: 536870912 # 512 MiB
209211

210212
colocated:
211213
# true: generation shares training GPUs

examples/configs/grpo_math_1B.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,8 @@ policy:
314314
dtype: ${policy.precision}
315315
transport: "sparse_indices" # dense, sparse_indices, or sparse_bitmask
316316
full_sync_interval: 20
317+
sparse_bucket_size_bytes: 5368709120 # 5 GiB
318+
delta_load_batch_size_bytes: 536870912 # 512 MiB
317319
colocated:
318320
# true: generation shares training GPUs
319321
# false: uses dedicated generation resources

nemo_rl/models/generation/vllm/config.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,14 @@ class VllmDeltaCompressionConfig(TypedDict):
6262
# Number of successful refits between full baseline refreshes.
6363
# Recommended default: 20.
6464
full_sync_interval: int
65+
# Maximum sparse-encoded payload bytes to bucket before broadcasting.
66+
# Smaller values improve refit pipelining; larger values reduce broadcast
67+
# call overhead. Recommended default: 5368709120 (5 GiB).
68+
sparse_bucket_size_bytes: int
69+
# Maximum decoded delta tensor bytes to batch before calling vLLM load_weights.
70+
# Smaller values improve overlap with receives; larger values reduce loader
71+
# call overhead. Recommended default: 536870912 (512 MiB).
72+
delta_load_batch_size_bytes: int
6573

6674

6775
class VllmConfig(GenerationConfig):

nemo_rl/models/generation/vllm/vllm_backend.py

Lines changed: 17 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
rebuild_cuda_tensor_from_ipc,
2626
)
2727
from nemo_rl.utils.nsys import wrap_with_nvtx_name
28-
from nemo_rl.utils.packed_tensor import packed_broadcast_consumer
2928
from nemo_rl.utils.weight_transfer import (
3029
additive_weight_load_context,
3130
packed_weight_transfer_consumer,
@@ -58,7 +57,7 @@ def fix_gpt_oss_export_transpose(key: str, weight: torch.Tensor) -> torch.Tensor
5857

5958
class VllmInternalWorkerExtension:
6059
state_dict_info: dict[str, Any] | None = None
61-
use_delta_weight_transfer: bool = False
60+
delta_load_batch_size_bytes: int | None = None
6261

6362
def init_collective(
6463
self,
@@ -109,18 +108,21 @@ def maybe_init_zmq(self):
109108
def prepare_refit_info(
110109
self,
111110
state_dict_info: dict[str, Any],
112-
use_delta_weight_transfer: bool,
111+
delta_load_batch_size_bytes: int | None = None,
113112
) -> None:
114-
"""Prepare state dict metadata for weight refitting and IPC streaming.
113+
"""Prepare state dict metadata for IPC/ZMQ weight refitting.
114+
115+
Collective refit receives tensor metadata from the transfer headers.
115116
116117
Args:
117118
state_dict_info (dict): A dictionary containing the info for refit.
118119
e.g. {tensor_name: (shape, dtype)}
119-
use_delta_weight_transfer (bool): Whether collective refit receives
120-
full weights only or the delta-aware full/delta protocol.
120+
delta_load_batch_size_bytes (int | None): Maximum decoded delta bytes
121+
to batch before calling vLLM load_weights. None means delta
122+
transfer is disabled.
121123
"""
122124
self.state_dict_info = state_dict_info
123-
self.use_delta_weight_transfer = use_delta_weight_transfer
125+
self.delta_load_batch_size_bytes = delta_load_batch_size_bytes
124126

125127
def _maybe_process_fp8_kv_cache(self) -> None:
126128
"""Process weights after loading for FP8 KV cache (static scales)."""
@@ -332,28 +334,15 @@ def update_weights_via_ipc_zmq(self) -> bool:
332334
)
333335
def update_weights_from_collective(self) -> bool:
334336
"""Update the model weights from collective communication."""
335-
state_dict_info = self.state_dict_info
336-
assert state_dict_info is not None, (
337-
"state_dict_info is not prepared. "
338-
"Please call prepare_refit_info when initializing the worker."
339-
)
340-
341337
try:
342-
if not self.use_delta_weight_transfer:
343-
packed_broadcast_consumer(
344-
iterator=iter(state_dict_info.items()),
345-
group=self.model_update_group,
346-
src=0,
347-
post_unpack_func=self._load_weights,
348-
)
349-
else:
350-
packed_weight_transfer_consumer(
351-
group=self.model_update_group,
352-
src=0,
353-
load_full_weights_func=self._load_weights,
354-
load_delta_weights_func=self._load_weight_deltas,
355-
device=self.device,
356-
)
338+
packed_weight_transfer_consumer(
339+
group=self.model_update_group,
340+
src=0,
341+
load_full_weights_func=self._load_weights,
342+
load_delta_weights_func=self._load_weight_deltas,
343+
device=self.device,
344+
delta_load_batch_size_bytes=self.delta_load_batch_size_bytes,
345+
)
357346

358347
# Process weights after loading for FP8 KV cache
359348
self._maybe_process_fp8_kv_cache()

nemo_rl/models/generation/vllm/vllm_worker.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
from nemo_rl.models.huggingface.common import ModelFlag
3636
from nemo_rl.models.policy.utils import is_vllm_v1_engine_enabled
3737
from nemo_rl.utils.nsys import wrap_with_nvtx_name
38-
from nemo_rl.utils.weight_transfer import get_vllm_delta_transfer_config
3938

4039

4140
# Use a base class to share some functions to avoid code duplication.
@@ -136,9 +135,6 @@ def __init__(
136135
the vLLM worker subprocess (e.g. for quantization configs).
137136
"""
138137
self.cfg = config
139-
self.use_delta_weight_transfer = (
140-
get_vllm_delta_transfer_config(self.cfg) is not None
141-
)
142138
self.model_name = self.cfg["model_name"]
143139
self.tensor_parallel_size = self.cfg["vllm_cfg"]["tensor_parallel_size"]
144140
self.pipeline_parallel_size = self.cfg["vllm_cfg"]["pipeline_parallel_size"]
@@ -666,6 +662,17 @@ def _get_raw_spec_counters(self) -> dict[str, float | list[float]]:
666662
metrics[metric.name] = metric.value
667663
return metrics
668664

665+
def _get_delta_load_batch_size_bytes(self) -> int | None:
666+
delta_config = self.cfg.get("delta_compression", None)
667+
if delta_config is None or not delta_config["enabled"]:
668+
return None
669+
delta_load_batch_size_bytes = int(delta_config["delta_load_batch_size_bytes"])
670+
if delta_load_batch_size_bytes < 1:
671+
raise ValueError(
672+
"delta_compression.delta_load_batch_size_bytes must be >= 1"
673+
)
674+
return delta_load_batch_size_bytes
675+
669676

670677
class VllmGenerationWorkerImpl(BaseVllmGenerationWorker):
671678
def _create_engine(self, llm_kwargs: dict[str, Any]) -> None:
@@ -912,7 +919,7 @@ def prepare_refit_info(self, state_dict_info: dict[str, Any]) -> None:
912919
"""Prepare the info for refit."""
913920
self.llm.collective_rpc(
914921
"prepare_refit_info",
915-
args=(state_dict_info, self.use_delta_weight_transfer),
922+
args=(state_dict_info, self._get_delta_load_batch_size_bytes()),
916923
)
917924

918925
@wrap_with_nvtx_name("vllm_genertion_worker/update_weights_via_ipc_zmq")

nemo_rl/models/generation/vllm/vllm_worker_async.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1039,7 +1039,7 @@ async def prepare_refit_info_async(self, state_dict_info: dict[str, Any]) -> Non
10391039
"""Async version of prepare_refit_info."""
10401040
await self.llm.collective_rpc(
10411041
"prepare_refit_info",
1042-
args=(state_dict_info, self.use_delta_weight_transfer),
1042+
args=(state_dict_info, self._get_delta_load_batch_size_bytes()),
10431043
)
10441044

10451045
async def update_weights_via_ipc_zmq_async(

nemo_rl/models/policy/workers/dtensor_policy_worker.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@
9393
from nemo_rl.utils.nsys import wrap_with_nvtx_name
9494
from nemo_rl.utils.weight_transfer import (
9595
create_vllm_delta_transfer_tracker,
96-
dispatch_packed_weight_transfer,
96+
packed_weight_transfer_producer,
9797
)
9898

9999

@@ -1877,7 +1877,7 @@ def _params_iterator():
18771877
yield name, _dtensor_post_iter_func(tensor, self.dtype)
18781878

18791879
params_iterator = _params_iterator()
1880-
dispatch_packed_weight_transfer(
1880+
packed_weight_transfer_producer(
18811881
iterator=params_iterator,
18821882
group=self.model_update_group,
18831883
src=0,

nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@
7474
from nemo_rl.utils.nsys import wrap_with_nvtx_name
7575
from nemo_rl.utils.weight_transfer import (
7676
create_vllm_delta_transfer_tracker,
77-
dispatch_packed_weight_transfer,
77+
packed_weight_transfer_producer,
7878
)
7979

8080

@@ -979,7 +979,7 @@ def broadcast_weights_for_collective(
979979
self.model = self.move_to_cuda(self.model)
980980

981981
params_iterator = dtensor_params_generator(self.model, self.dtype)
982-
dispatch_packed_weight_transfer(
982+
packed_weight_transfer_producer(
983983
iterator=params_iterator,
984984
group=self.model_update_group,
985985
src=0,

nemo_rl/models/policy/workers/megatron_policy_worker.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@
9292
from nemo_rl.utils.nsys import wrap_with_nvtx_name
9393
from nemo_rl.utils.weight_transfer import (
9494
create_vllm_delta_transfer_tracker,
95-
dispatch_packed_weight_transfer,
95+
packed_weight_transfer_producer,
9696
)
9797

9898
TokenizerType = TypeVar("TokenizerType", bound=PreTrainedTokenizerBase)
@@ -1125,7 +1125,7 @@ def broadcast_weights_for_collective(
11251125
) -> None:
11261126
"""Broadcast the weights for collective communication."""
11271127
params_iterator = self._iter_params_with_optional_kv_scales(kv_scales=kv_scales)
1128-
dispatch_packed_weight_transfer(
1128+
packed_weight_transfer_producer(
11291129
iterator=params_iterator,
11301130
group=self.model_update_group,
11311131
src=0,

nemo_rl/utils/torch_dtypes.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,14 @@
3737
**G_CANONICAL_FLOAT_DTYPE_MAP,
3838
"float64": torch.float64,
3939
}
40+
41+
for _float8_dtype_name in (
42+
"float8_e4m3fn",
43+
"float8_e5m2",
44+
"float8_e4m3fnuz",
45+
"float8_e5m2fnuz",
46+
):
47+
_float8_dtype = getattr(torch, _float8_dtype_name, None)
48+
if _float8_dtype is not None:
49+
G_TENSOR_DTYPE_MAP[_float8_dtype_name] = _float8_dtype
50+
del _float8_dtype, _float8_dtype_name

0 commit comments

Comments
 (0)