feat: add delta-compressed collective refit#2444
Conversation
|
Auto-sync is disabled for ready for review pull requests in this repository. Workflows must be run manually. Contributors can view more details about this message here. |
There was a problem hiding this comment.
Pull request overview
Adds an optional delta-compressed weight transfer protocol for non-colocated vLLM collective refit, enabling the trainer source rank to send full weights or additive deltas (dense / sparse_indices / sparse_bitmask) and apply deltas additively through existing vLLM weight loaders.
Changes:
- Introduces a delta-aware packed weight transfer protocol (
full/delta/done) with sparse delta encodings and a trainer-sideDeltaCompressionTrackerbaseline. - Integrates the new transfer path into DTensor v1/v2 and Megatron policy workers via a shared
dispatch_packed_weight_transfer(...)helper. - Updates vLLM collective refit to optionally consume the new full/delta protocol and adds unit tests + example configs.
Reviewed changes
Copilot reviewed 14 out of 14 changed files in this pull request and generated 1 comment.
Show a summary per file
| File | Description |
|---|---|
tests/unit/utils/test_weight_transfer.py |
Adds unit coverage for delta tracker behavior, sparse transports, additive load context, and producer/consumer roundtrips. |
nemo_rl/utils/weight_transfer.py |
Implements delta tracking, sparse encodings, packed full/delta broadcast protocol, and additive load context. |
nemo_rl/utils/weight_transfer_types.py |
Defines shared literal types/constants for delta compression and transfer kinds. |
nemo_rl/utils/torch_dtypes.py |
Centralizes dtype string→torch.dtype mappings (canonical + aliases). |
nemo_rl/models/policy/workers/megatron_policy_worker.py |
Switches collective weight broadcast to the delta-aware dispatcher when enabled. |
nemo_rl/models/policy/workers/dtensor_policy_worker.py |
Switches collective weight broadcast to the delta-aware dispatcher when enabled. |
nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py |
Switches collective weight broadcast to the delta-aware dispatcher when enabled. |
nemo_rl/models/generation/vllm/vllm_worker.py |
Determines whether to use delta transfer and forwards that flag to the vLLM worker extension. |
nemo_rl/models/generation/vllm/vllm_worker_async.py |
Forwards the delta-transfer enablement flag in the async prepare_refit_info path. |
nemo_rl/models/generation/vllm/vllm_backend.py |
Adds delta-aware collective consumer path and additive-delta loading through existing loaders. |
nemo_rl/models/generation/vllm/config.py |
Extends vLLM generation config typing with delta_compression settings. |
nemo_rl/models/automodel/setup.py |
Reuses canonical dtype mapping from torch_dtypes instead of duplicating it. |
examples/configs/grpo_math_1B.yaml |
Documents/introduces the new delta_compression config block (disabled by default). |
examples/configs/distillation_math.yaml |
Documents/introduces the new delta_compression config block (disabled by default). |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
|
Awesome @HollowMan6 I found delta weight transfer has its own weight transfer function, which seems a duplicated one compared with the full weight transfer. It is out of the scope of this PR, but is there any block to have delt and full weight transfer shared the same communication function while have their independent protocol to pack, unpack the model weights? |
|
Thank you @ZhiyuLi-Nvidia for pointing this out, I just did some refactoring according to your suggestion, and it looks fine. |
fa0cb08 to
a3b3c70
Compare
4022b81 to
dd37c84
Compare
2bab874 to
6071cd5
Compare
9c50ebf to
43719b7
Compare
0b16704 to
3218f2a
Compare
b944c74 to
19f77c9
Compare
cbbc2b6 to
6ab0904
Compare
6ab0904 to
8794e01
Compare
|
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually. Contributors can view more details about this message here. |
| advantages = advantages.clamp(min=clip_low) | ||
| if clip_high is not None: | ||
| advantages = advantages.clamp(max=clip_high) | ||
| return advantages |
There was a problem hiding this comment.
LGTM — this is a well-structured, well-tested PR. The module decomposition is clean, test coverage is thorough across the new utilities, and the lazy initialization fallback in refit_policy_generation ensures the sync/async paths both work correctly even when the early-init optimization in _start_initial_policy_generation_weight_sync doesn't fire.
|
/ok to test 6c18684 |
|
Thank you for adding vllm http backend. Just curious how's that compared against zmq in perf / resilience? |
Signed-off-by: Hollow Man <hollowman@opensuse.org>
|
/ok to test 9a8d02b |
Signed-off-by: Hollow Man <hollowman@opensuse.org>
Signed-off-by: Hollow Man <hollowman@opensuse.org>
What does this PR do ?
Adds optional delta-compressed weight transfer for non-colocated vLLM collective refit.
This introduces a delta-aware packed weight transfer protocol that can send either full weights or additive deltas, with support for
dense,sparse_indices, andsparse_bitmaskdelta encodings. The trainer source rank keeps a pinned CPU baseline of the last successfully synced HF-format weights, computes deltas against that baseline, and periodically sends full syncs based onfull_sync_interval.The feature is disabled by default and only applies to non-colocated vLLM refit. Colocated CUDA IPC, vLLM FP8 weights, and ModelOpt quantized vLLM paths are rejected.
Note that the receiver does not keep a baseline. It receives actual additive deltas and applies them directly to the current vLLM model weights. The pinned CPU baseline exists only on the trainer source side so the next current - baseline can be computed.
flowchart LR subgraph Trainer["Policy trainer ranks"] I["state_dict iterator<br/>DTensor.full_tensor / HF export"] --> C["_next_chunk"] C --> R{"group.rank == src?"} R -- "no" --> NR["consume iterator for collective correctness<br/>receive broadcast metadata/payload only"] R -- "yes" --> T{"DeltaCompressionTracker?"} T -- "no" --> FULL["FULL update<br/>dense tensors"] T -- "yes" --> P["prepare_chunk"] P --> S{"full sync due?"} S -- "yes" --> FULL S -- "no" --> D["delta = current - pinned CPU baseline"] D --> ENC{"transport"} ENC -- "dense" --> DENSE["dense delta payload"] ENC -- "sparse_indices / sparse_bitmask" --> SPARSE["sparse encode<br/>indices or bitmask + values + metadata"] SPARSE --> BUCKET["sparse bucket consolidation"] FULL --> PACK["pack_named_tensors<br/>uint8 payload + header"] DENSE --> PACK BUCKET --> PACK PACK --> NCCL["NCCL broadcast<br/>shared full/delta protocol"] P --> BASE["async baseline snapshot<br/>pinned CPU via D2H stream"] end subgraph VLLM["Non-colocated vLLM generation workers"] NCCL --> RECV["packed_weight_transfer_consumer"] RECV --> KIND{"header.kind"} KIND -- "FULL" --> LOADFULL["vLLM load_weights<br/>overwrite model weights"] KIND -- "DELTA dense" --> LOADDELTA["batch decoded deltas"] KIND -- "DELTA sparse" --> DECODE["decode sparse to dense delta tensors"] DECODE --> LOADDELTA LOADDELTA --> ADDCTX["additive_weight_load_context"] ADDCTX --> APPLY["vLLM load_weights<br/>copy/fill/setitem becomes add_"] endsequenceDiagram participant Src as Trainer src rank participant Peer as Trainer non-src ranks participant Comm as NCCL group participant V as vLLM worker participant Base as Pinned CPU baseline Src->>Src: gather/export next chunk Peer->>Peer: gather/export same chunk for trainer-side collectives Src->>Base: wait old baseline event, read baseline Src->>Src: prepare full or delta chunk Src->>Base: async snapshot current weights for next refit Src->>Src: sparse encode and bucket if useful Src->>Comm: broadcast header + packed uint8 payload Peer->>Comm: participate in broadcast receive V->>Comm: receive header + payload alt FULL update V->>V: unpack dense tensors V->>V: load_weights overwrite else DELTA update V->>V: unpack or sparse-decode delta tensors V->>V: batch delta loads V->>V: load_weights under additive context end Src->>Src: on_sync_succeeded increments committed sync countIssues
N/A
Usage
Enable under the vLLM generation config:
Before your PR is "Ready for review"
Pre checks:
Additional Information
DeltaCompressionTrackerand delta-aware packed transfer utilities.fullordeltachunks.dispatch_packed_weight_transfer(...)helper.