diff --git a/megatron/training/utils.py b/megatron/training/utils.py index 9e32098dd29..0e05c11e52a 100644 --- a/megatron/training/utils.py +++ b/megatron/training/utils.py @@ -563,7 +563,7 @@ def _broadcast(item): def _broadcast_cu_seqlens(cu_seqlens): dev = torch.cuda.current_device() n = 0 if cu_seqlens is None else int(cu_seqlens.numel()) - n_tensor = torch.tensor(n, dtype=torch.int64, device=dev) + n_tensor = torch.empty(1, dtype=torch.int64, device=dev).fill_(n) _broadcast(n_tensor) if n == 0: