@@ -1,5 +1,3 @@
-from typing import cast
-
 import torch
 import torch.distributed as dist
 
@@ -10,6 +8,10 @@ def foreach_all_gather(
     params: list[torch.Tensor],
     group: dist.ProcessGroup | None,
 ) -> list[list[torch.Tensor]]:
+    """Perform a fused all-gather on a list of tensors.
+
+    All ranks must contribute tensors with identical numels and shapes.
+    """
     if group is None:
         group = dist.group.WORLD
 
@@ -18,29 +20,23 @@ def foreach_all_gather(
 
     input_tensor_numels = [param.numel() for param in params]
     input_tensor_shapes = [param.shape for param in params]
+    world_size = dist.get_world_size(group)
+    local_tensor_size = sum(input_tensor_numels)
+    global_tensor_size = local_tensor_size * world_size
 
-    flatten_copyin_tensor = torch.empty((sum(input_tensor_numels),), dtype=param0.dtype, device=param0.device)
+    # prepare the flattened input tensor
+    flatten_copyin_tensor = torch.empty((local_tensor_size,), dtype=param0.dtype, device=param0.device)
     splits_copyin_tensor = torch.split(flatten_copyin_tensor, input_tensor_numels)
     torch._foreach_copy_(splits_copyin_tensor, [p.flatten() for p in params])
+    flatten_copyout_tensor = torch.empty((global_tensor_size,), dtype=param0.dtype, device=param0.device)
 
-    input_tensor_numels_tensor = torch.tensor(input_tensor_numels, dtype=torch.int64, device=param0.device)
-    global_input_tensor_numels = [
-        torch.zeros_like(input_tensor_numels_tensor) for _ in range(dist.get_world_size(group))
-    ]
-
-    dist.all_gather(global_input_tensor_numels, input_tensor_numels_tensor, group=group)
-    copyout_size = int(sum(sum(i) for i in global_input_tensor_numels))
-    flatten_copyout_tensor = torch.empty((copyout_size,), dtype=param0.dtype, device=param0.device)
-
+    # all-gather into the global flattened tensor
     dist.all_gather_into_tensor(flatten_copyout_tensor, flatten_copyin_tensor, group=group)
-    copyout_split_size: list[int] = sum([i.tolist() for i in global_input_tensor_numels], [])
+    copyout_split_size: list[int] = input_tensor_numels * world_size
     splits_copyout_tensor = torch.split(flatten_copyout_tensor, copyout_split_size)
+    global_input_tensor_shapes = input_tensor_shapes * world_size
 
-    _global_input_tensor_shapes: list[None] | list[list[tuple]] = [None for _ in range(dist.get_world_size(group))]
-    dist.all_gather_object(_global_input_tensor_shapes, input_tensor_shapes, group=group)
-    _global_input_tensor_shapes = cast(list[list[tuple]], _global_input_tensor_shapes)
-    global_input_tensor_shapes: list[tuple] = sum(_global_input_tensor_shapes, [])
-
+    # gathered_params: [[param1/rank1, param1/rank2, ...], [param2/rank1, param2/rank2, ...], ...]
     gathered_params: list[list[torch.Tensor]] = []
     for i in range(len(params)):
         single_gathered_params: list[torch.Tensor] = []
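
For context, here is a minimal self-contained sketch of the pattern this change implements: pack the inputs into one flat buffer, run a single dist.all_gather_into_tensor, then split and reshape the result locally. Because the change drops the dist.all_gather / dist.all_gather_object metadata exchange, it relies on every rank passing tensors with identical shapes, as the new docstring states. The name fused_all_gather below is illustrative, not the PR's API, and the default WORLD group is assumed.

import torch
import torch.distributed as dist


def fused_all_gather(params: list[torch.Tensor]) -> list[list[torch.Tensor]]:
    group = dist.group.WORLD
    world_size = dist.get_world_size(group)
    numels = [p.numel() for p in params]
    shapes = [p.shape for p in params]
    local_size = sum(numels)

    # Pack every input into one flat buffer so a single collective suffices.
    copyin = torch.empty((local_size,), dtype=params[0].dtype, device=params[0].device)
    torch._foreach_copy_(torch.split(copyin, numels), [p.flatten() for p in params])

    # One fused all-gather instead of one collective per tensor.
    copyout = torch.empty((local_size * world_size,), dtype=copyin.dtype, device=copyin.device)
    dist.all_gather_into_tensor(copyout, copyin, group=group)

    # Ranks are laid out contiguously in copyout, so split sizes and shapes
    # simply repeat world_size times; no metadata exchange is needed.
    splits = torch.split(copyout, numels * world_size)
    gathered: list[list[torch.Tensor]] = []
    for i in range(len(params)):
        gathered.append(
            [splits[r * len(params) + i].view(shapes[i]) for r in range(world_size)]
        )
    return gathered

Launched under torchrun, each rank would call gathered = fused_all_gather([w, b]) and read rank r's copy of parameter i as gathered[i][r].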