Skip to content

Commit 5b1f8c8

Browse files
authored
Enhance tensor model parallel rank retrieval
Add check for model parallel rank in mpu.
1 parent f7c5d75 commit 5b1f8c8

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

deepspeed/utils/bwc.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,13 @@ def bwc_tensor_model_parallel_rank(mpu=None):
3737
elif hasattr(mpu, 'get_slice_parallel_rank'):
3838
# Some DeepSpeed + pipeline parallelism versions
3939
return mpu.get_slice_parallel_rank()
40-
else:
40+
elif hasattr(mpu, 'get_model_parallel_rank'):
4141
# Deprecated Megatron and DeepSpeed convention
4242
return mpu.get_model_parallel_rank()
43+
else:
44+
# mpu does not provide any known tensor/model-parallel rank API.
45+
# Treat as "no tensor model parallelism".
46+
return 0
4347

4448

4549
def bwc_tensor_model_parallel_world_size(mpu=None):

0 commit comments

Comments
 (0)