
Commit d9bb6c4

yeyu-nvidia and claude committed

fix: add divisibility guard and clarify WORLD_SIZE fallback

Address review feedback:
- Add ValueError if world_size is not divisible by dp_shard_size * cp_size
- Comment that torch.cuda.device_count() is per-node, not world_size

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>

1 parent 83226f4 commit d9bb6c4

1 file changed: 8 additions & 0 deletions

examples/speculative_decoding/main.py
@@ -214,8 +214,16 @@ def train():
     if training_args.cp_size > 1 or training_args.dp_shard_size > 1:
         # Auto-compute dp_replicate_size so that
         # dp_replicate_size * dp_shard_size * cp_size == world_size.
+        # Note: torch.cuda.device_count() returns per-node GPU count, not world_size.
+        # WORLD_SIZE (set by torchrun/accelerate) gives the correct multi-node total.
         world_size = int(os.environ.get("WORLD_SIZE", torch.cuda.device_count()))
         parallel_size = training_args.dp_shard_size * training_args.cp_size
+        if world_size % parallel_size != 0:
+            raise ValueError(
+                f"world_size ({world_size}) must be divisible by "
+                f"dp_shard_size ({training_args.dp_shard_size}) * cp_size ({training_args.cp_size}) "
+                f"= {parallel_size}"
+            )
         dp_replicate_size = world_size // parallel_size
         training_args.parallelism_config = ParallelismConfig(
             cp_size=training_args.cp_size,
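In isolation, the guarded computation behaves as sketched below. This is a minimal standalone version with an illustrative helper name (`compute_dp_replicate_size` is not the repo's API), and the env-var fallback is stubbed with a fixed per-node count rather than a real `torch.cuda.device_count()` call:

```python
import os


def compute_dp_replicate_size(world_size: int, dp_shard_size: int, cp_size: int) -> int:
    """Derive dp_replicate_size so that
    dp_replicate_size * dp_shard_size * cp_size == world_size."""
    parallel_size = dp_shard_size * cp_size
    if world_size % parallel_size != 0:
        raise ValueError(
            f"world_size ({world_size}) must be divisible by "
            f"dp_shard_size ({dp_shard_size}) * cp_size ({cp_size}) "
            f"= {parallel_size}"
        )
    return world_size // parallel_size


# In the patched script, WORLD_SIZE comes from torchrun/accelerate; the fallback
# is the per-node GPU count, which is only correct on a single node. "8" here is
# a stand-in for torch.cuda.device_count().
world_size = int(os.environ.get("WORLD_SIZE", "8"))

# e.g. 16 GPUs with dp_shard_size=4, cp_size=2 -> dp_replicate_size of 2
dp_replicate_size = compute_dp_replicate_size(16, 4, 2)
```

A non-divisible combination, say 10 GPUs with `dp_shard_size=4` and `cp_size=2`, now fails fast with a ValueError instead of silently truncating `dp_replicate_size` via floor division.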
