Commit 2fef374

yeyu-nvidia and claude authored
fix: auto-compute dp_replicate_size from world_size (#1302)

## Summary

- When `dp_shard_size < world_size` (e.g., `dp_shard_size=4` on 8 GPUs across 2 nodes), `ParallelismConfig` raises `total_size (4) does not match num_processes (8)` because `dp_replicate_size` defaults to 1.
- Auto-compute `dp_replicate_size = world_size // (dp_shard_size * cp_size)` so that intra-node FSDP2 sharding plus inter-node data-parallel replication works without manual configuration.
- This lets `dp_shard_size` be set to the per-node GPU count (better NVLink utilization) while replicas across nodes are created automatically.

## Test plan

- [ ] Verify single-node training (dp_shard_size == world_size, dp_replicate_size == 1) is unchanged
- [ ] Verify multi-node with dp_shard_size < world_size creates correct replica groups
- [ ] Verify existing EAGLE3/DFlash configs still work

🤖 Generated with [Claude Code](https://claude.com/claude-code)

## Summary by CodeRabbit

* **Refactor**
  * Enhanced parallelism configuration initialization in the speculative decoding example to better handle distributed training scenarios.

Signed-off-by: Ye Yu <yeyu@nvidia.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 355c6b7 commit 2fef374

1 file changed: examples/speculative_decoding/main.py

Lines changed: 16 additions & 1 deletion
```diff
@@ -212,8 +212,23 @@ def train():
             "Either data.data_path or data.offline_data_path must be set in the config."
         )
     if training_args.cp_size > 1 or training_args.dp_shard_size > 1:
+        # Auto-compute dp_replicate_size so that
+        # dp_replicate_size * dp_shard_size * cp_size == world_size.
+        # Note: torch.cuda.device_count() returns per-node GPU count, not world_size.
+        # WORLD_SIZE (set by torchrun/accelerate) gives the correct multi-node total.
+        world_size = int(os.environ.get("WORLD_SIZE", torch.cuda.device_count()))
+        parallel_size = training_args.dp_shard_size * training_args.cp_size
+        if world_size % parallel_size != 0:
+            raise ValueError(
+                f"world_size ({world_size}) must be divisible by "
+                f"dp_shard_size ({training_args.dp_shard_size}) * cp_size ({training_args.cp_size}) "
+                f"= {parallel_size}"
+            )
+        dp_replicate_size = world_size // parallel_size
         training_args.parallelism_config = ParallelismConfig(
-            cp_size=training_args.cp_size, dp_shard_size=training_args.dp_shard_size
+            cp_size=training_args.cp_size,
+            dp_shard_size=training_args.dp_shard_size,
+            dp_replicate_size=dp_replicate_size,
         )
     if training_args.cp_size > 1:
         patch_ring_attention_for_ttt()
```
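The arithmetic behind the change can be illustrated in isolation. The sketch below is a minimal, self-contained rendering of the logic the commit inlines in `examples/speculative_decoding/main.py`; the helper name `compute_dp_replicate_size` is hypothetical, introduced here only for illustration.

```python
# Hypothetical standalone helper mirroring the commit's inline computation:
# the product dp_replicate_size * dp_shard_size * cp_size must equal world_size.
def compute_dp_replicate_size(world_size: int, dp_shard_size: int, cp_size: int) -> int:
    parallel_size = dp_shard_size * cp_size
    if world_size % parallel_size != 0:
        raise ValueError(
            f"world_size ({world_size}) must be divisible by "
            f"dp_shard_size ({dp_shard_size}) * cp_size ({cp_size}) = {parallel_size}"
        )
    return world_size // parallel_size

# Motivating scenario from the summary: 8 GPUs across 2 nodes, FSDP2 sharding
# within each node (dp_shard_size=4), replication across nodes.
print(compute_dp_replicate_size(world_size=8, dp_shard_size=4, cp_size=1))  # → 2

# Single-node case: full sharding, no replication, behavior unchanged.
print(compute_dp_replicate_size(world_size=8, dp_shard_size=8, cp_size=1))  # → 1
```

With `dp_shard_size=4` on 8 GPUs, the previous default of `dp_replicate_size=1` gave a total mesh size of 4, triggering the `total_size (4) does not match num_processes (8)` error; the computed value of 2 makes the product come out to the world size.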
