|
25 | 25 | rebuild_cuda_tensor_from_ipc, |
26 | 26 | ) |
27 | 27 | from nemo_rl.utils.nsys import wrap_with_nvtx_name |
28 | | -from nemo_rl.utils.packed_tensor import packed_broadcast_consumer |
29 | 28 | from nemo_rl.utils.weight_transfer import ( |
30 | 29 | additive_weight_load_context, |
31 | 30 | packed_weight_transfer_consumer, |
@@ -58,7 +57,7 @@ def fix_gpt_oss_export_transpose(key: str, weight: torch.Tensor) -> torch.Tensor |
58 | 57 |
|
59 | 58 | class VllmInternalWorkerExtension: |
60 | 59 | state_dict_info: dict[str, Any] | None = None |
61 | | - use_delta_weight_transfer: bool = False |
| 60 | + delta_load_batch_size_bytes: int | None = None |
62 | 61 |
|
63 | 62 | def init_collective( |
64 | 63 | self, |
@@ -109,18 +108,21 @@ def maybe_init_zmq(self): |
109 | 108 | def prepare_refit_info( |
110 | 109 | self, |
111 | 110 | state_dict_info: dict[str, Any], |
112 | | - use_delta_weight_transfer: bool, |
| 111 | + delta_load_batch_size_bytes: int | None = None, |
113 | 112 | ) -> None: |
114 | | - """Prepare state dict metadata for weight refitting and IPC streaming. |
| 113 | + """Prepare state dict metadata for IPC/ZMQ weight refitting. |
| 114 | +
|
| 115 | + Collective refit receives tensor metadata from the transfer headers. |
115 | 116 |
|
116 | 117 | Args: |
117 | 118 | state_dict_info (dict): A dictionary containing the info for refit. |
118 | 119 | e.g. {tensor_name: (shape, dtype)} |
119 | | - use_delta_weight_transfer (bool): Whether collective refit receives |
120 | | - full weights only or the delta-aware full/delta protocol. |
| 120 | + delta_load_batch_size_bytes (int | None): Maximum decoded delta bytes |
| 121 | + to batch before calling vLLM load_weights. None means delta |
| 122 | + transfer is disabled. |
121 | 123 | """ |
122 | 124 | self.state_dict_info = state_dict_info |
123 | | - self.use_delta_weight_transfer = use_delta_weight_transfer |
| 125 | + self.delta_load_batch_size_bytes = delta_load_batch_size_bytes |
124 | 126 |
|
125 | 127 | def _maybe_process_fp8_kv_cache(self) -> None: |
126 | 128 | """Process weights after loading for FP8 KV cache (static scales).""" |
@@ -332,28 +334,15 @@ def update_weights_via_ipc_zmq(self) -> bool: |
332 | 334 | ) |
333 | 335 | def update_weights_from_collective(self) -> bool: |
334 | 336 | """Update the model weights from collective communication.""" |
335 | | - state_dict_info = self.state_dict_info |
336 | | - assert state_dict_info is not None, ( |
337 | | - "state_dict_info is not prepared. " |
338 | | - "Please call prepare_refit_info when initializing the worker." |
339 | | - ) |
340 | | - |
341 | 337 | try: |
342 | | - if not self.use_delta_weight_transfer: |
343 | | - packed_broadcast_consumer( |
344 | | - iterator=iter(state_dict_info.items()), |
345 | | - group=self.model_update_group, |
346 | | - src=0, |
347 | | - post_unpack_func=self._load_weights, |
348 | | - ) |
349 | | - else: |
350 | | - packed_weight_transfer_consumer( |
351 | | - group=self.model_update_group, |
352 | | - src=0, |
353 | | - load_full_weights_func=self._load_weights, |
354 | | - load_delta_weights_func=self._load_weight_deltas, |
355 | | - device=self.device, |
356 | | - ) |
| 338 | + packed_weight_transfer_consumer( |
| 339 | + group=self.model_update_group, |
| 340 | + src=0, |
| 341 | + load_full_weights_func=self._load_weights, |
| 342 | + load_delta_weights_func=self._load_weight_deltas, |
| 343 | + device=self.device, |
| 344 | + delta_load_batch_size_bytes=self.delta_load_batch_size_bytes, |
| 345 | + ) |
357 | 346 |
|
358 | 347 | # Process weights after loading for FP8 KV cache |
359 | 348 | self._maybe_process_fp8_kv_cache() |
|
0 commit comments