
Commit 83226f4

yeyu-nvidia and claude committed
fix: auto-compute dp_replicate_size from world_size in ParallelismConfig
When dp_shard_size < world_size (e.g., dp_shard_size=4 on 8 GPUs), ParallelismConfig raises "total_size does not match num_processes" because dp_replicate_size defaults to 1. Auto-compute dp_replicate_size = world_size // (dp_shard_size * cp_size) so that intra-node FSDP2 sharding plus inter-node data-parallel replication works without requiring users to set dp_replicate_size manually.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
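As a worked example of the formula, using the numbers from the commit message (world_size=8, dp_shard_size=4, cp_size=1), the auto-computation yields two data-parallel replicas of a 4-way FSDP2 shard:

# Illustrative arithmetic only; values are the example from the commit message.
world_size = 8        # total number of processes (e.g., 8 GPUs)
dp_shard_size = 4     # FSDP2 sharding degree within a replica
cp_size = 1           # no context parallelism in this example
dp_replicate_size = world_size // (dp_shard_size * cp_size)
assert dp_replicate_size * dp_shard_size * cp_size == world_size
print(dp_replicate_size)  # -> 2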
1 parent 289a239 commit 83226f4

1 file changed

Lines changed: 8 additions & 1 deletion


examples/speculative_decoding/main.py

@@ -212,8 +212,15 @@ def train():
             "Either data.data_path or data.offline_data_path must be set in the config."
         )
     if training_args.cp_size > 1 or training_args.dp_shard_size > 1:
+        # Auto-compute dp_replicate_size so that
+        # dp_replicate_size * dp_shard_size * cp_size == world_size.
+        world_size = int(os.environ.get("WORLD_SIZE", torch.cuda.device_count()))
+        parallel_size = training_args.dp_shard_size * training_args.cp_size
+        dp_replicate_size = world_size // parallel_size
         training_args.parallelism_config = ParallelismConfig(
-            cp_size=training_args.cp_size, dp_shard_size=training_args.dp_shard_size
+            cp_size=training_args.cp_size,
+            dp_shard_size=training_args.dp_shard_size,
+            dp_replicate_size=dp_replicate_size,
         )
     if training_args.cp_size > 1:
         patch_ring_attention_for_ttt()
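A minimal standalone sketch of the same computation, with an explicit divisibility check added; the helper name compute_dp_replicate_size and the guard are illustrative additions, not part of the commit:

import os
import torch

def compute_dp_replicate_size(dp_shard_size: int, cp_size: int) -> int:
    # Mirror the commit's logic: take the world size from the launcher's
    # WORLD_SIZE env var, falling back to the local GPU count.
    world_size = int(os.environ.get("WORLD_SIZE", torch.cuda.device_count()))
    parallel_size = dp_shard_size * cp_size
    # Illustrative guard (not in the patch): fail fast when the mesh cannot
    # tile the world, rather than letting ParallelismConfig raise
    # "total_size does not match num_processes" later.
    if world_size % parallel_size != 0:
        raise ValueError(
            f"world_size={world_size} is not divisible by "
            f"dp_shard_size*cp_size={parallel_size}"
        )
    return world_size // parallel_size

# e.g., on 8 GPUs: compute_dp_replicate_size(4, 1) -> 2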
