-
Notifications
You must be signed in to change notification settings - Fork 33.2k
Expand file tree
/
Copy pathtensor_parallel.py
More file actions
1611 lines (1310 loc) · 66.7 KB
/
tensor_parallel.py
File metadata and controls
1611 lines (1310 loc) · 66.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import math
import operator
import os
import re
from functools import reduce
from ..distributed import DistributedConfig
from ..utils import is_torch_greater_or_equal, logging
from ..utils.generic import GeneralInterface
from ..utils.import_utils import is_torch_available
# torch is optional: only import it (and compute torch-derived state) when `is_torch_available()` reports it.
if is_torch_available():
    import torch
    import torch.distributed as dist
    from torch import nn

    # Cache this result as it's a C FFI call which can be pretty time-consuming
    _torch_distributed_available = torch.distributed.is_available()

logger = logging.get_logger(__name__)
def to_local(t):
    """Return the local shard of ``t`` when it is a ``DTensor``; return ``t`` unchanged otherwise.

    Custom kernels (CUTLASS, CuteDSL, Triton) take raw tensor pointers and know nothing about ``DTensor``, so
    weights wrapped by FSDP2 / EP must be unwrapped before being handed to such a kernel. ``DTensor.to_local()``
    is autograd-aware on the training path: backward rewraps the gradient as a DTensor matching each
    parameter's placements.
    """
    # Without torch there is no DTensor type to test against (and `torch` may not even be bound).
    if not is_torch_available():
        return t
    if not isinstance(t, torch.distributed.tensor.DTensor):
        return t
    return t.to_local()
def initialize_tensor_parallelism(
    tp_plan: str | dict[str, str] | None, tp_size: int | None = None, device_mesh=None, device_map=None
):
    r"""
    Sets up the device mesh and initializes the backend for tensor parallelism.

    This function is called when the model is loaded and the TP plan is set to 'auto'.

    Args:
        tp_plan: The tensor parallel plan. It must be set whenever `tp_size` is passed, and it cannot be
            combined with `device_map`.
        tp_size: Size of the TP mesh. When `device_mesh` is None it defaults to the torch.distributed world
            size. When a `device_mesh` is given it is overwritten by that mesh's size.
        device_mesh: Optional existing device mesh. If it has more than one dimension, its "tp" sub-mesh is used.
        device_map: Must be None when `tp_plan` is set. It is replaced by the device computed here.

    Returns:
        A `(device_map, device_mesh, tp_size)` tuple. `device_map` has one of three forms:
        - a `torch.device` for the local accelerator;
        - the string "cpu" on CPU;
        - `torch.device("<device_type>:<LOCAL_RANK>")` when a mesh was provided.

    Raises:
        ValueError: on an invalid argument combination, or an n-d `device_mesh` without a "tp" dimension.
        OSError: if torch < 2.5, or if torch.distributed could not be initialized from the environment.
        RuntimeError: on MPS devices.
        KeyError: if `LOCAL_RANK` is needed (see the notes below) but missing from the environment.
    """
    if tp_size is not None and tp_plan is None:
        raise ValueError("tp_plan has to be set when tp_size is passed.")
    if tp_plan is not None and device_map is not None:
        raise ValueError("`tp_plan` and `device_map` are mutually exclusive. Choose either one for parallelization.")
    if device_mesh is None:
        if not is_torch_greater_or_equal("2.5"):
            raise OSError("Tensor parallel is only supported for `torch>=2.5`.")
        # Detect the accelerator on the machine. If no accelerator is available, it returns CPU.
        device_type = torch._C._get_accelerator().type
        if device_type == "mps":
            raise RuntimeError("Tensor parallelism is not supported on MPS devices.")
        # Device module matching the accelerator type (looked up as an attribute of `torch`).
        current_device = getattr(torch, device_type)
        if not torch.distributed.is_initialized():
            # Bootstrap the default process group from the RANK / LOCAL_RANK / WORLD_SIZE environment variables.
            try:
                rank = int(os.environ["RANK"])
                local_rank = int(os.environ["LOCAL_RANK"])
                world_size = int(os.environ["WORLD_SIZE"])
                # Device types missing from this map yield `backend=None`.
                backend_map = {"cuda": "nccl", "cpu": "gloo", "xpu": "xccl", "hpu": "hccl", "neuron": "neuron"}
                backend = backend_map.get(device_type)
                torch.distributed.init_process_group(backend=backend, rank=rank, world_size=world_size)
                current_device = getattr(torch, device_type)
                if device_type != "cpu":
                    current_device.set_device(local_rank)
            except Exception as e:
                # Any failure above (missing/invalid env vars, init errors) is surfaced as an OSError.
                raise OSError(
                    "We tried to initialize torch.distributed for you, but it failed. Make "
                    "sure you init torch distributed in your script to use `tp_plan`."
                ) from e
        if device_type != "cpu":
            # This runs whether or not the process group was created above, so LOCAL_RANK must be set here.
            # A missing LOCAL_RANK raises KeyError, not the OSError above.
            current_device.set_device(int(os.environ["LOCAL_RANK"]))
            index = current_device.current_device()
            tp_device = torch.device(device_type, index)
            device_map = tp_device
        else:
            tp_device = torch.device(device_type)
            # `device_type` is the non-empty string "cpu" here, so the `{}` fallback is never taken.
            device_map = device_type or {}
        tp_size = tp_size if tp_size is not None else torch.distributed.get_world_size()
        # Build a 1-D mesh of `tp_size` ranks on the detected device type.
        device_mesh = torch.distributed.init_device_mesh(tp_device.type, (tp_size,))
    else:
        if device_mesh.ndim > 1:
            if "tp" not in device_mesh.mesh_dim_names:
                raise ValueError(
                    "When using `tp_plan` and n-d `device_mesh`, it must contain a 'tp' dimension. "
                    "Please provide a valid `device_mesh`."
                )
            # Keep only the tensor-parallel sub-mesh.
            device_mesh = device_mesh["tp"]
        tp_size = device_mesh.size()
        # Requires LOCAL_RANK in the environment (KeyError otherwise).
        device_map = torch.device(f"{device_mesh.device_type}:{int(os.environ['LOCAL_RANK'])}")
    return device_map, device_mesh, tp_size
def replace_layer_number_by_wildcard(name: str) -> str:
    """
    Replace the numbers in the `name` by wildcards, only if they are in-between dots (`.`) or if they are between
    a dot (`.`) and the end of the string.

    This matches how modules are named/numbered when using a nn.ModuleList or nn.Sequential, but will NOT match
    numbers in a parameter name itself, e.g. if the param is named `"w1"` or `"w2"`.

    Runs of consecutive indices (e.g. an `nn.Sequential` nested in an `nn.ModuleList`) are all replaced:
    `"blocks.0.1.weight"` -> `"blocks.*.*.weight"`.
    """
    # The trailing `.` (or end of string) is checked with a lookahead rather than consumed. Consuming it would
    # stop the next index in a run such as `.0.1.` from matching, because it would no longer start with a dot.
    return re.sub(r"\.\d+(?=\.|$)", ".*", name)
def _get_parameter_tp_plan(parameter_name: str, tp_plan: dict[str, str], is_weight=True) -> str | None:
    """
    Look up the TP style of a parameter in `tp_plan`.

    Layer indices in `parameter_name` are first turned into `*` wildcards (via `replace_layer_number_by_wildcard`),
    so plan keys may be generic (e.g. "*.weight", "layers.*.mlp.down_proj") or fully specific.

    With `is_weight=True`, a miss on the full name falls back to the parent module's name (everything before the
    last `.`). This lets one module entry cover both its `.weight` and `.bias`. Pass `is_weight=False` for
    module-level lookups (e.g. `post_init` calls), where parent names must not match.

    Returns:
        The matching plan entry, or None.
    """
    generic_name = replace_layer_number_by_wildcard(parameter_name)
    if generic_name in tp_plan:
        return tp_plan[generic_name]
    # Only weights may inherit their parent module's entry.
    if not is_weight or "." not in generic_name:
        return None
    parent_name = generic_name.rsplit(".", 1)[0]
    return tp_plan[parent_name] if parent_name in tp_plan else None
# =============================================================================
# Tensor Sharding Utilities
# =============================================================================
if is_torch_available():
    # Maps the safetensors dtype strings returned by `get_dtype()` to torch dtypes.
    # Both float8 variants are listed, matching the FP8 handling in `get_packed_weights`.
    str_to_dtype = {
        "BOOL": torch.bool,
        "U8": torch.uint8,
        "I8": torch.int8,
        "I16": torch.int16,
        "F16": torch.float16,
        "BF16": torch.bfloat16,
        "I32": torch.int32,
        "F32": torch.float32,
        "F64": torch.float64,
        "I64": torch.int64,
        "F8_E4M3": torch.float8_e4m3fn,
        "F8_E5M2": torch.float8_e5m2,
    }
def _blocks_to_block_sizes(total_size: int, blocks: int | list[int]) -> list[int]:
"""
Convert block count or proportions to block sizes.
This function accepts
- The number of blocks (int), in which case the block size is
total_size//blocks; or
- A list of block sizes (list[int]).
In the second case, if sum(blocks) < total_size, the ratios between
the block sizes will be preserved. For instance, if blocks is
[2, 1, 1] and total_size is 1024, the returned block sizes are
[512, 256, 256].
"""
if isinstance(blocks, list):
total_blocks = sum(blocks)
assert total_size % total_blocks == 0, f"Cannot split {total_size} in proportional blocks: {blocks}"
part_size = total_size // total_blocks
return [part_size * block for block in blocks]
else:
assert total_size % blocks == 0, f"Prepacked is not divisible by {blocks}"
single_size = total_size // blocks
return [single_size] * blocks
def get_packed_weights(param, empty_param, device_mesh, rank, dim):
    """
    When weights are packed (gate_up_proj), we need to make sure each shard gets its correct share.

    So if you have: gate_proj ( 16, 5120, 8190)
    and up_proj ( 16, 5120, 8190)
    packed as gate_up_proj ( 16, 5120, 2 * 8190)
    then sharding along the last dimension across TP_size (Tensor Parallelism size) must interleave the values
    from the gate and up projections correctly.

    Let's take TP_size = 4 for an example:

    Packed tensor `gate_up_proj`
    ---------------------------------------------------------------
    [ G0 G1 G2 G3 | G4 G5 G6 G7 | ... | U0 U1 U2 U3 | U4 U5 U6 U7 | ... ]
      ↑─────────────↑   ↑─────────────↑        ↑─────────────↑  ↑─────────────↑
        Gate Slice 0      Gate Slice 1            Up Slice 0       Up Slice 1

    Explanation:
    - The first half of the tensor (left of the center) holds the gate_proj values.
    - The second half (right of the center) holds the up_proj values.
    - For TP=4, we divide each half into 4 slices. In this example, we show two slices for brevity.
    - Each shard receives one slice from the gate part and the corresponding slice from the up part.

    For instance:
    • Shard 0 gets: [ Gate Slice 0, Up Slice 0 ] = [ G0, G1, G2, G3, U0, U1, U2, U3 ]
    • Shard 1 gets: [ Gate Slice 1, Up Slice 1 ] = [ G4, G5, G6, G7, U4, U5, U6, U7 ]
    • … and so on.

    This ensures that each shard receives an equal portion of both gate and up projections, maintaining
    consistency across tensor parallelism.

    Args:
        param: The packed checkpoint weight. It must expose `get_dtype()` (e.g. a safetensors slice) and support
            `[...]` and list indexing.
        empty_param (torch.Tensor): Tensor whose shape describes the full packed parameter.
        device_mesh: Device mesh; `device_mesh.size()` is the TP world size.
        rank (int): Rank whose shard is extracted.
        dim (int): Dimension holding the packed projections. Negative values count from the end of `empty_param`.
            Only the first, second or last dimension is supported.

    Returns:
        torch.Tensor: This rank's shard. FP8 (`F8_E4M3` / `F8_E5M2`) inputs come back upcast to float16. Other
        dtypes are cast to the `str_to_dtype` entry for `param.get_dtype()`.

    Raises:
        ValueError: if `dim` is not the first, second or last dimension of `empty_param`.
    """
    slice_ = param
    ndim = empty_param.ndim
    # Resolve negative dims against the parameter's real rank. Testing `dim == -2` directly would select the
    # second axis even for a 2D weight, where -2 actually means axis 0.
    packed_dim = dim + ndim if dim < 0 else dim
    if not 0 <= packed_dim < ndim or packed_dim not in (0, 1, ndim - 1):
        raise ValueError(
            f"Unsupported dim {dim} for a {ndim}-D parameter: only the first, second or last dimension is supported"
        )

    total_size = empty_param.shape[packed_dim]
    world_size = device_mesh.size()
    block_sizes = _blocks_to_block_sizes(total_size=total_size, blocks=2)

    # Collect this rank's index range inside every packed block (gate first, then up).
    tensors_slices = []
    block_offset = 0
    for block_size in block_sizes:
        shard_block_size = block_size // world_size
        start = block_offset + rank * shard_block_size
        tensors_slices.extend(range(start, start + shard_block_size))
        block_offset += block_size

    slice_dtype = slice_.get_dtype()
    # Handle FP8 dtypes by converting to float16 before slicing.
    # Without upcasting, the slicing causes : RuntimeError: "index_cpu" not implemented for 'Float8_e4m3fn'
    casted = False
    if slice_dtype in ("F8_E4M3", "F8_E5M2"):
        slice_ = slice_[...].to(torch.float16)
        casted = True

    if packed_dim == 0:
        tensor = slice_[tensors_slices, ...]
    elif packed_dim == ndim - 1:
        tensor = slice_[..., tensors_slices]
    else:  # packed_dim == 1 on a tensor with at least 3 dimensions
        tensor = slice_[:, tensors_slices, ...]

    # FP8 inputs are returned as their float16 upcast; everything else goes back to its checkpoint dtype.
    if casted:
        return tensor
    return tensor.to(str_to_dtype[slice_dtype])
def repack_weights(
    packed_parameter: torch.Tensor,
    sharded_dim: int,  # The dimension index in the global tensor that was sharded
    world_size: int,
    num_blocks: int = 2,
) -> torch.Tensor:
    """
    Reorder a tensor reconstructed from sharded packed weights into its canonical packed format.

    Suppose a weight was packed (e.g. gate_proj and up_proj) and then sharded, so each rank holds its own
    `[G_r, U_r]` slice. `DTensor.full_tensor()` then produces an interleaved layout like `[G0, U0, G1, U1, ...]`
    along the sharded dimension. This function reorders it to `[G0, G1, ..., U0, U1, ...]`. It is the inverse
    operation of `get_packed_weights`.

    Args:
        packed_parameter: The tensor reconstructed from the DTensor (e.g. via `.full_tensor().contiguous()`).
        sharded_dim: The dimension of `packed_parameter` that was originally sharded. Negative values count from
            the end.
        world_size: The tensor parallel world size.
        num_blocks: The number of projections packed together (e.g. 2 for gate_up_proj). Only 2 is supported.

    Returns:
        The reordered tensor, with the same shape as `packed_parameter`.

    Raises:
        ValueError: if `num_blocks != 2`, or if the size of `sharded_dim` is not divisible by
            `num_blocks * world_size`.
    """
    if num_blocks != 2:
        raise ValueError(
            "Num blocks different from 2 is not supported yet. This is most likely a bug in your implementation as we only pack gate and up projections together."
        )

    actual_sharded_dim = sharded_dim if sharded_dim >= 0 else sharded_dim + packed_parameter.ndim
    total_size_on_sharded_dim = packed_parameter.shape[actual_sharded_dim]
    # Fail with a clear message instead of an opaque `view` shape error below.
    if world_size <= 0 or total_size_on_sharded_dim % (num_blocks * world_size) != 0:
        raise ValueError(
            f"Cannot repack dimension {sharded_dim} of size {total_size_on_sharded_dim} into {num_blocks} blocks "
            f"sharded across {world_size} ranks: the size must be divisible by {num_blocks} * {world_size}."
        )
    shard_chunk_size = total_size_on_sharded_dim // (num_blocks * world_size)

    prefix_shape = packed_parameter.shape[:actual_sharded_dim]
    suffix_shape = packed_parameter.shape[actual_sharded_dim + 1 :]

    # Split the sharded dim as (world_size, num_blocks, shard_chunk_size). This is the layout produced by
    # concatenating every rank's [G_r, U_r] shard.
    tensor_view = packed_parameter.view(
        *prefix_shape,
        world_size,
        num_blocks,
        shard_chunk_size,
        *suffix_shape,
    )
    # Swap the world_size and num_blocks axes to get (num_blocks, world_size, shard_chunk_size): all G chunks
    # first, then all U chunks.
    axis_world_size = len(prefix_shape)
    tensor_permuted = tensor_view.transpose(axis_world_size, axis_world_size + 1)

    # Merge the split axes back. `reshape` (not `view`) is needed because the transpose made the tensor
    # non-contiguous.
    return tensor_permuted.reshape_as(packed_parameter)
def get_tensor_shard(param, empty_param, device_mesh, rank, dim, tensor_idx: int | None = None):
"""
Generalized tensor sharding across a multi-dimensional device mesh.
Extract only the fraction of the parameter owned by the given `rank` when the parameter would have gone sharding at provided `dim`.
Extraction follows the pytorch `Shard` placement so that sharding and materializing back to full tensor follows `Shard` semantics.
`Shard` follows torch.chunk style sharding of the tensor. We demonstrate some cases below on how sharding happens including some edge cases
such as some ranks having an empty tensor as shard. Below implementation is robut to all these cases.
Case (1)
empty_param (16, 5120, 8190)
dim 0
device_mesh.size() 4
rank 0 gets (4, 5120, 8190) (0 ... 4, 5120, 8190)
rank 1 gets (4, 5120, 8190) (4 ... 8, 5120, 8190)
rank 2 gets (4, 5120, 8190) (8 ... 12, 5120, 8190)
rank 3 gets (4, 5120, 8190) (12 ... 16, 5120, 8190)
Case (2)
empty_param (16, 5120, 8190)
dim 0
device_mesh.size() 14
rank 0 gets (2, 5120, 8190) (0 ... 2, 5120, 8190)
rank 1 gets (2, 5120, 8190) (2 ... 4, 5120, 8190)
rank 2 gets (2, 5120, 8190) (4 ... 6, 5120, 8190)
rank 3 gets (2, 5120, 8190) (6 ... 8, 5120, 8190)
rank 4 gets (2, 5120, 8190) (8 ... 10, 5120, 8190)
rank 5 gets (2, 5120, 8190) (10 ... 12, 5120, 8190)
rank 6 gets (2, 5120, 8190) (12 ... 14, 5120, 8190)
rank 7 gets (2, 5120, 8190) (14 ... 16, 5120, 8190)
rank 8 gets (0, 5120, 8190)
rank 9 gets (0, 5120, 8190)
rank 10 gets (0, 5120, 8190)
rank 11 gets (0, 5120, 8190)
rank 12 gets (0, 5120, 8190)
rank 13 gets (0, 5120, 8190)
Case (3)
empty_param (16, 5120, 8190)
dim 0
device_mesh.size() 3
rank 0 gets (6, 5120, 8190) (0 ... 6, 5120, 8190)
rank 1 gets (6, 5120, 8190) (6 ... 12, 5120, 8190)
rank 2 gets (4, 5120, 8190) (12 ... 16, 5120, 8190)
In case (2), empty shards are returned with appropriate dimension to allow for operations to work smoothly.
Args:
param (torch.Tensor): The tensor to shard.
empty_param (torch.Tensor): A tensor used for shape reference.
device_mesh (torch.Tensor): Shape [d_0, ..., d_n] representing the mesh.
rank (int): Global rank of the current process/device.
dim (int): Dimension along which to shard the tensor.
"""
param_dim = empty_param.ndim
mesh_shape = device_mesh.shape
world_size = reduce(operator.mul, mesh_shape)
# Get param shape: works for both torch.Tensor and safetensors TensorInfo
param_shape = list(param.shape) if isinstance(param, torch.Tensor) else param.get_shape()
if dim < 0:
dim = param_dim + dim
if empty_param.dim() == 3 and dim == 1 and len(param_shape) == 2:
dim = 0
elif empty_param.dim() == 3 and dim == 2 and len(param_shape) == 2:
dim = 1
shard_size = math.ceil(param_shape[dim] / world_size)
start = rank * shard_size
end = min(start + shard_size, param_shape[dim])
if dim >= param_dim:
raise ValueError(f"dim {dim} is out of bounds for tensor of dimension {param_dim}")
if rank >= world_size:
raise ValueError(f"Rank {rank} is out of bounds for mesh size {world_size}")
# we have the full tensor not 1 part of it.
# in that case, we just assume that the weight was properly saved
# and thus because we TP if the layer is colwise it should not use this. Layer should be packed_colwise
# to inform that it needs to read form a packed tensor. It will also take care of the module list thingy.
# here we take care of potential chunking / layer split / layer chunking.
# The only "hard" case is? if we collect q,k,v -> merge it into qkv. In that case
# actually we still shard dim=0 does not change
# so only case is if the dim of the empty param is 3 and the shard dim is 0 -> we put the
# tensor on a certain device (with the input tensor_index)
if tensor_idx is not None and empty_param.dim() == 3 and dim == 0 and len(param_shape) == 2:
# special case we don't "shard" just send this entire tensor to the correct rank.
if start <= tensor_idx < end:
# this tensor does need to be materialized on this device:
return param[:]
else:
return torch.empty([], dtype=torch.int64, device=rank)
slice_indices = [slice(None)] * len(param_shape)
if start < param_shape[dim]:
slice_indices[dim] = slice(start, end)
param = param[tuple(slice_indices)]
if isinstance(param, list): # TODO handle the modulelist case!
param = [p[:] for p in param]
return param
param_shape[dim] = 0
return torch.empty(tuple(param_shape), dtype=torch.int64) # empty allocates memory....
def _split_along_last_dim(x, world_size):
"""Split tensor along last dimension into world_size chunks."""
return torch.chunk(x, world_size, dim=-1)
# =============================================================================
# Distributed Communication Primitives
# =============================================================================
#
# Naming convention:
# - Functions describe their FORWARD behavior
# - Backward behavior is the "conjugate" operation for gradient flow
#
# Available operations:
# ┌────────────────────┬─────────────────────┬─────────────────────┐
# │ Function │ Forward │ Backward │
# ├────────────────────┼─────────────────────┼─────────────────────┤
# │ all_reduce │ all-reduce (sum) │ identity │
# │ all_reduce_backward│ identity │ all-reduce (sum) │
# │ all_gather │ all-gather │ split (local chunk) │
# │ split │ split (local chunk) │ all-gather │
# │ reduce_scatter │ reduce-scatter │ all-gather │
# └────────────────────┴─────────────────────┴─────────────────────┘
# ===================
class _AllReduceBackward(torch.autograd.Function):
"""Identity forward, all-reduce backward. Used before colwise layers (f in Megatron)."""
@staticmethod
def forward(ctx, x, device_mesh):
ctx.device_mesh = device_mesh
return x
@staticmethod
def backward(ctx, grad_output):
device_mesh = ctx.device_mesh
if device_mesh.size() == 1:
return grad_output, None
grad_output = grad_output.contiguous()
dist.all_reduce(grad_output, op=dist.ReduceOp.SUM, group=device_mesh.get_group())
return grad_output, None
class _AllReduceForward(torch.autograd.Function):
"""All-reduce forward, identity backward. Used after rowwise layers (g in Megatron)."""
@staticmethod
def forward(ctx, x, device_mesh):
if device_mesh.size() == 1:
return x
dist.all_reduce(x, op=dist.ReduceOp.SUM, group=device_mesh.get_group())
return x
@staticmethod
def backward(ctx, grad_output):
return grad_output, None
class _AllGather(torch.autograd.Function):
"""All-gather forward, split backward. Gathers sharded outputs."""
@staticmethod
def forward(ctx, x, device_mesh):
ctx.device_mesh = device_mesh
world_size = device_mesh.size()
if world_size == 1:
return x
last_dim = x.dim() - 1
rank = device_mesh.get_local_rank()
group = device_mesh.get_group()
x = x.contiguous()
tensor_list = [torch.empty_like(x) for _ in range(world_size)]
tensor_list[rank] = x
dist.all_gather(tensor_list, x, group=group)
return torch.cat(tensor_list, dim=last_dim).contiguous()
@staticmethod
def backward(ctx, grad_output):
device_mesh = ctx.device_mesh
world_size = device_mesh.size()
if world_size == 1:
return grad_output, None
rank = device_mesh.get_local_rank()
chunks = _split_along_last_dim(grad_output, world_size)
return chunks[rank].contiguous(), None
class _Split(torch.autograd.Function):
"""Split forward, all-gather backward. Scatters replicated input."""
@staticmethod
def forward(ctx, x, device_mesh):
ctx.device_mesh = device_mesh
world_size = device_mesh.size()
if world_size == 1:
return x
rank = device_mesh.get_local_rank()
chunks = _split_along_last_dim(x, world_size)
return chunks[rank].contiguous()
@staticmethod
def backward(ctx, grad_output):
device_mesh = ctx.device_mesh
world_size = device_mesh.size()
if world_size == 1:
return grad_output, None
last_dim = grad_output.dim() - 1
rank = device_mesh.get_local_rank()
group = device_mesh.get_group()
grad_output = grad_output.contiguous()
tensor_list = [torch.empty_like(grad_output) for _ in range(world_size)]
tensor_list[rank] = grad_output
dist.all_gather(tensor_list, grad_output, group=group)
return torch.cat(tensor_list, dim=last_dim).contiguous(), None
class _ReduceScatter(torch.autograd.Function):
"""Reduce-scatter forward, all-gather backward. For sequence parallel."""
@staticmethod
def forward(ctx, x, device_mesh):
ctx.device_mesh = device_mesh
world_size = device_mesh.size()
if world_size == 1:
return x
last_dim = x.dim() - 1
group = device_mesh.get_group()
input_chunks = list(x.chunk(world_size, dim=last_dim))
output_shape = list(x.shape)
output_shape[last_dim] //= world_size
output = torch.empty(output_shape, dtype=x.dtype, device=x.device)
dist.reduce_scatter(output, input_chunks, op=dist.ReduceOp.SUM, group=group)
return output
@staticmethod
def backward(ctx, grad_output):
device_mesh = ctx.device_mesh
world_size = device_mesh.size()
if world_size == 1:
return grad_output, None
last_dim = grad_output.dim() - 1
rank = device_mesh.get_local_rank()
group = device_mesh.get_group()
grad_output = grad_output.contiguous()
tensor_list = [torch.empty_like(grad_output) for _ in range(world_size)]
tensor_list[rank] = grad_output
dist.all_gather(tensor_list, grad_output, group=group)
return torch.cat(tensor_list, dim=last_dim).contiguous(), None
# =============================================================================
# Convenience wrappers
# =============================================================================
def all_reduce_backward(x, device_mesh):
    """Forward: identity. Backward: sum-all-reduce of the gradient. Use before colwise layers."""
    apply_fn = _AllReduceBackward.apply
    return apply_fn(x, device_mesh)
def all_reduce_forward(x, device_mesh):
    """Forward: sum-all-reduce. Backward: identity. Use after rowwise layers."""
    apply_fn = _AllReduceForward.apply
    return apply_fn(x, device_mesh)
def all_gather(x, device_mesh):
    """Forward: all-gather along the last dim. Backward: keep the local chunk of the gradient."""
    apply_fn = _AllGather.apply
    return apply_fn(x, device_mesh)
def split(x, device_mesh):
    """Forward: keep the local chunk of the last dim. Backward: all-gather the gradient."""
    apply_fn = _Split.apply
    return apply_fn(x, device_mesh)
def reduce_scatter(x, device_mesh):
    """Forward: sum-reduce-scatter along the last dim. Backward: all-gather the gradient."""
    apply_fn = _ReduceScatter.apply
    return apply_fn(x, device_mesh)
def distribute_module(
    module: nn.Module,
    device_mesh=None,
    input_fn=None,
    output_fn=None,
) -> nn.Module:
    """
    Attach TP input/output hooks to `module` and return it.

    Adapted from torch's `distribute_module`, without the partitioning communications and without buffer
    registration, which is similarly not efficient here.

    Args:
        module: The module to hook; it is modified in place.
        device_mesh: Forwarded as the last argument of both hook functions.
        input_fn: Optional `input_fn(mod, inputs, device_mesh)`, registered as a forward pre-hook.
        output_fn: Optional `output_fn(mod, outputs, device_mesh)`, registered as a forward hook.

    Returns:
        The same `module` object.
    """
    if input_fn is not None:

        def _pre_hook(mod, inputs):
            return input_fn(mod, inputs, device_mesh)

        module.register_forward_pre_hook(_pre_hook)

    if output_fn is not None:

        def _post_hook(mod, inputs, outputs):
            return output_fn(mod, outputs, device_mesh)

        module.register_forward_hook(_post_hook)

    return module
class TensorParallelLayer:
"""General tensor parallel layer for transformers"""
device_mesh = None
rank = None
empty_param = None
def __init__(self, device_mesh=None, rank=None, empty_param=None):
self.rank = rank
self.device_mesh = device_mesh
self.empty_param = empty_param
def _prepare_input_fn(self, mod, inputs, device_mesh):
raise NotImplementedError
def _prepare_output_fn(self, mod, outputs, device_mesh):
raise NotImplementedError
def shard_tensor(
self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
) -> torch.Tensor:
raise NotImplementedError
def prepare_module_tp(self, module: nn.Module, device_mesh, **kwargs) -> nn.Module:
distribute_module(
module,
device_mesh,
self._prepare_input_fn,
self._prepare_output_fn,
)
def get_expected_sharded_shape(self, full_shape: tuple[int, ...] | torch.Size) -> tuple[int, ...]:
"""
Compute the expected shape after TP sharding for a given full shape.
Args:
full_shape: The full (unsharded) parameter shape
Returns:
The expected sharded shape for this rank
"""
# Default: no sharding, return full shape
return tuple(full_shape)
def update_module_attributes(self, module: nn.Module):
"""
Update module attributes (e.g. in_features, out_features) to reflect sharded dimensions.
Args:
module: The module to update
Returns:
None, update the module in-place
"""
pass
class ColwiseParallel(TensorParallelLayer):
    """
    Column-wise parallel: weight is sharded on dim -2 (output features).

    Forward: input replicated -> output sharded on last dim.
    If gather_output=True, output is all-gathered to produce full tensor.
    """

    def __init__(self, gather_output: bool = False, **kwargs):
        super().__init__(**kwargs)
        self.gather_output = gather_output

    def _prepare_input_fn(self, mod, inputs, device_mesh):
        # `inputs` is the forward pre-hook's tuple of positional args. Only the first one (the activation) goes
        # through `all_reduce_backward`. Any further positional arguments are passed through unchanged; returning
        # only the first would silently drop them from the module call.
        if not inputs:
            return inputs
        return (all_reduce_backward(inputs[0], device_mesh), *inputs[1:])

    def _prepare_output_fn(self, mod, outputs, device_mesh):
        if self.gather_output:
            return all_gather(outputs, device_mesh)
        return outputs

    def shard_tensor(
        self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
    ) -> torch.Tensor:
        # If only 1 dim, shard this one (usually it's a `bias`); otherwise shard the output-feature dim -2.
        ndim = param.dim() if isinstance(param, torch.Tensor) else len(param.get_shape())
        shard_dim = -1 if ndim == 1 else -2
        parameter = get_tensor_shard(param, self.empty_param, self.device_mesh, self.rank, shard_dim)
        return parameter.to(device=device, dtype=dtype)

    def get_expected_sharded_shape(self, full_shape: tuple[int, ...] | torch.Size) -> tuple[int, ...]:
        world_size = self.device_mesh.size()
        shape = list(full_shape)
        # Colwise shards dim -2, but 1D tensors (bias) shard on dim -1
        dim = -1 if len(shape) == 1 else -2
        dim = len(shape) + dim if dim < 0 else dim
        # Same ceil-based chunking as `get_tensor_shard`.
        shard_size = math.ceil(shape[dim] / world_size)
        start = self.rank * shard_size
        end = min(start + shard_size, shape[dim])
        shape[dim] = end - start
        return tuple(shape)

    def update_module_attributes(self, module: nn.Module):
        # If we gather the output, the output dimension of the module is not sharded, so no need to update
        # out_features. Otherwise, update out_features to reflect the sharded dimension.
        if not self.gather_output and hasattr(module, "out_features"):
            module.out_features = self.get_expected_sharded_shape((module.out_features,))[0]
class ReplicatedWithGradAllReduce(TensorParallelLayer):
    """
    Replicated (unsharded) parameter whose gradient is sum-all-reduced across the TP mesh.

    This is meant for parameters such as q_norm / k_norm that sit between colwise and rowwise layers. Every rank
    holds the full parameter, but in TP mode its gradient only accumulates contributions from the rank's local
    heads. A backward hook therefore all-reduces the parameter gradient.
    """

    def _prepare_input_fn(self, mod, inputs, device_mesh):
        return inputs

    def _prepare_output_fn(self, mod, outputs, device_mesh):
        return outputs

    def shard_tensor(self, param, tensor_idx=None, device=None, dtype=None):
        # No sharding: materialize the whole parameter.
        full_param = param[...]
        return full_param.to(device=device, dtype=dtype)

    def prepare_module_tp(self, module, device_mesh, **kwargs):
        # A module-level backward hook is used instead of `param.register_hook`, because parameters are
        # replaced during weight loading after this method runs. Module hooks survive parameter replacement.
        # NOTE(review): `register_full_backward_hook` fires once gradients w.r.t. the module inputs are ready.
        # Confirm that `param.grad` already holds this step's gradient at that point, and that gradient
        # accumulation does not all-reduce an already-reduced gradient.
        def _all_reduce_param_grads(mod, grad_input, grad_output, mesh=device_mesh):
            for p in mod.parameters():
                if p.grad is None:
                    continue
                all_reduce_forward(p.grad, mesh)

        module.register_full_backward_hook(_all_reduce_param_grads)
class AllReduceParallel(TensorParallelLayer):
    """
    Marker style: any parameters stay replicated, and the module's forward output is sum-all-reduced over the TP mesh.

    Intended for a no-op `nn.Identity` placed at a sync point after colwise-sharded compute that ends in a
    head-axis (or similar) reduction. Each rank then holds only a partial sum, which must be combined before the
    next dependent op (e.g. the lightning indexer's score sum before its top-k).
    """

    def _prepare_input_fn(self, mod, inputs, device_mesh):
        return inputs

    def _prepare_output_fn(self, mod, outputs, device_mesh):
        # Combine the per-rank partial sums.
        return all_reduce_forward(outputs, device_mesh)

    def shard_tensor(self, param, tensor_idx=None, device=None, dtype=None):
        full_param = param[...]
        return full_param.to(device=device, dtype=dtype)

    def prepare_module_tp(self, module, device_mesh, **kwargs):
        # Output hook only; inputs are left untouched.
        distribute_module(module, device_mesh, input_fn=None, output_fn=self._prepare_output_fn)
class MlaKvAProjParallel(TensorParallelLayer):
    """
    TP handling for `kv_a_proj_with_mqa` in MLA attention (DeepSeek-V2 style models such as
    deepseek_v2, longcat_flash, glm_moe_dsa, glm4_moe_lite).

    The projection output has size `[kv_lora_rank + qk_rope_head_dim]` in its last dim and is
    split into two parts (naming varies across models, but the split is what matters).
    Example below (from modeling_longcat_flash.py):

                kv_a_proj_with_mqa
                        |
                      split
                     /     \\
                k_pass      k_rot         <-- "bypasses kv_b_proj"
                  |           |               (goes straight to attention,
            kv_a_layernorm    |               never touches kv_b_proj)
                  |           |
              kv_b_proj       |
              (colwise)       |
                  |           |
                k_pass      k_rot
                     \\     /
                       cat
                        |
                    key_states

    k_pass flows into kv_b_proj (colwise), which already takes care of its backward gradient
    reduction. k_rot goes straight to attention and never touches kv_b_proj. Without the extra
    step below, each rank would only hold a partial gradient for it. This layer therefore wraps
    the trailing `qk_rope_head_dim` slice of the output with `all_reduce_backward`, so its
    gradient is all-reduced across ranks in the backward pass.
    """

    def _prepare_output_fn(self, mod, output, device_mesh):
        # `prepare_module_tp` attaches the model config; it may be missing or None if the
        # module was prepared without one.
        config = getattr(mod, "config", None)
        if config is None:
            raise AttributeError(
                f"No model config is attached to {type(mod).__name__}. "
                "MlaKvAProjParallel needs the model config (passed as `config` to `prepare_module_tp`) "
                "to read `qk_rope_head_dim`."
            )
        if not hasattr(config, "qk_rope_head_dim"):
            raise AttributeError(
                f"Config for {type(mod).__name__} does not have `qk_rope_head_dim`. "
                "MlaKvAProjParallel requires `qk_rope_head_dim` to be defined in the model config. "
                "Please add it to the model's config or update the TP plan mapping."
            )
        rope_dim = config.qk_rope_head_dim
        # The rope part is the trailing `rope_dim` features; everything before it is the pass part.
        pass_output, rope_output = output.split([output.shape[-1] - rope_dim, rope_dim], dim=-1)
        # Identity in forward; all-reduces the rope slice's gradient in backward.
        rope_output = all_reduce_backward(rope_output, device_mesh)
        return torch.cat([pass_output, rope_output], dim=-1)

    def shard_tensor(self, param, tensor_idx=None, device=None, dtype=None):
        # The projection weight is replicated: every rank loads the full tensor.
        return param[...].to(device=device, dtype=dtype)

    def prepare_module_tp(self, module, device_mesh, config=None, **kwargs):
        # Stash the model config on the module so the output hook can read `qk_rope_head_dim`.
        module.config = config
        distribute_module(module, device_mesh, output_fn=self._prepare_output_fn)
class RowwiseParallel(TensorParallelLayer):
    """
    Row-wise parallel: weight is sharded on dim -1 (input features).

    Forward: input (optionally split) -> local matmul producing a partial output ->
    all-reduce to replicate the output -> bias added once to the reduced result.

    Args:
        split_input: If True, splits the replicated input before the matmul. Use when the input
            comes from a non-parallelizable operation (chunk/slice).
            Default False (expects pre-sharded input from a colwise layer).
    """

    def __init__(self, split_input: bool = False, **kwargs):
        super().__init__(**kwargs)
        self.split_input = split_input

    def _prepare_input_fn(self, mod, inputs, device_mesh):
        # On the first call, move the bias off the module so the local matmul runs bias-free.
        # `_prepare_output_fn` adds it once, after the all-reduce, instead of it being part of
        # every rank's partial output.
        # NOTE(review): if `mod` is an nn.Module and `bias` an nn.Parameter, assigning it to
        # `mod._bias` registers it under the new name, so parameter/state_dict keys change
        # after the first forward; confirm this is intended.
        if hasattr(mod, "bias") and mod.bias is not None:
            mod._bias = mod.bias
            mod.bias = None
        input_tensor = inputs[0] if inputs else inputs
        if self.split_input:
            # Input is replicated; split it to match the sharded weight.
            return split(input_tensor, device_mesh)
        return input_tensor

    def _prepare_output_fn(self, mod, outputs, device_mesh):
        outputs = all_reduce_forward(outputs, device_mesh)
        if hasattr(mod, "_bias") and mod._bias is not None:
            outputs = outputs + mod._bias
        return outputs

    def shard_tensor(
        self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
    ) -> torch.Tensor:
        # 1-D tensors (usually a `bias`) are replicated, not sharded.
        dim = param.dim() if isinstance(param, torch.Tensor) else len(param.get_shape())
        if dim == 1:
            parameter = param[...]
        else:
            parameter = get_tensor_shard(param, self.empty_param, self.device_mesh, self.rank, -1)
        return parameter.to(device=device, dtype=dtype)

    def get_expected_sharded_shape(self, full_shape: tuple[int, ...] | torch.Size) -> tuple[int, ...]:
        """
        Return the local shape on this rank of a parameter whose full shape is `full_shape`.

        1-D tensors (biases) are not sharded in rowwise and are returned unchanged. Otherwise the
        last dim is split into chunks of `ceil(size / world_size)`; trailing ranks may receive a
        smaller chunk or an empty one.
        """
        if len(full_shape) == 1:
            return tuple(full_shape)
        world_size = self.device_mesh.size()
        shape = list(full_shape)
        dim = len(shape) - 1  # shard the last dim
        shard_size = math.ceil(shape[dim] / world_size)
        start = self.rank * shard_size
        end = min(start + shard_size, shape[dim])
        # Ceil-division can leave the last ranks starting past the end (e.g. size 5 on 4 ranks:
        # rank 3 starts at 6). Clamp so those ranks get an empty shard instead of a negative size.
        shape[dim] = max(0, end - start)
        return tuple(shape)

    def update_module_attributes(self, module: nn.Module):
        if hasattr(module, "in_features"):
            # Pass a 2-D shape so get_expected_sharded_shape takes the sharded path;
            # a 1-D shape would be treated as a bias and returned unsharded.
            shape = (1, module.in_features)
            module.in_features = self.get_expected_sharded_shape(shape)[1]
class PackedColwiseParallel(ColwiseParallel):
    """Column-wise parallel for packed (fused) weights such as `gate_up_proj`."""

    def shard_tensor(
        self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
    ) -> torch.Tensor:
        """
        Return this rank's shard of `param`, cast to `device`/`dtype`.

        - 1-D tensors (usually a bias) are sharded along their only dim.
        - Tensors with fewer dims than the expected packed parameter are unpacked pieces
          (e.g. a `gate_proj` later concatenated into `gate_up_proj`), so they get a plain
          shard along dim -2. The concatenation happens afterwards.
        - Otherwise the tensor is already packed and gets packed-aware sharding along dim -2.
        """
        ndim = param.dim() if isinstance(param, torch.Tensor) else len(param.get_shape())
        reference, mesh, rank = self.empty_param, self.device_mesh, self.rank

        if ndim == 1:
            shard = get_tensor_shard(param, reference, mesh, rank, -1)
        elif ndim < len(self.get_expected_sharded_shape(reference.shape)):
            shard = get_tensor_shard(param, reference, mesh, rank, -2)
        else:
            shard = get_packed_weights(param, reference, mesh, rank, -2)
        return shard.to(device=device, dtype=dtype)
class PackedRowwiseParallel(RowwiseParallel):
    """Row-wise parallel for packed (fused) weights such as `gate_up_proj`."""

    def shard_tensor(
        self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
    ) -> torch.Tensor:
        """
        Return this rank's shard of `param`, cast to `device`/`dtype`.

        - 1-D tensors (usually a bias) are replicated, not sharded.
        - Otherwise the incoming last-dim size is compared with the packed parameter's. A smaller
          size means the tensor is one unpacked piece (as produced with MergeModulelist +
          Concatenate for fused weights like `gate_up_proj`), which gets a plain shard along
          dim -1. An equal or larger size means it is already packed and gets packed-aware
          sharding along dim -1.
        """
        if isinstance(param, torch.Tensor):
            ndim = param.dim()
        else:
            ndim = len(param.get_shape())

        if ndim == 1:
            return param[...].to(device=device, dtype=dtype)

        incoming_shape = param.shape if isinstance(param, torch.Tensor) else param.get_shape()
        packed_size = 0
        if self.empty_param.dim() >= 1:
            packed_size = self.empty_param.shape[-1]
        incoming_size = incoming_shape[-1] if len(incoming_shape) >= 1 else 0

        shard_fn = get_tensor_shard if incoming_size < packed_size else get_packed_weights
        shard = shard_fn(param, self.empty_param, self.device_mesh, self.rank, -1)
        return shard.to(device=device, dtype=dtype)
class EmbeddingParallel(TensorParallelLayer):
"""EmbeddingParallel: shards embedding table, handles masked lookups for vocab parallelism."""
def __init__(self, *, embedding_dim_sharding: int = 0, **kwargs):
super().__init__(**kwargs)
self.embedding_dim_sharding = embedding_dim_sharding
def _prepare_input_fn(self, mod, inputs, device_mesh):
input_tensor = inputs[0] if inputs else inputs
# For vocab-parallel (dim 0), we need to handle masking and offsetting
if self.embedding_dim_sharding == 0:
rank = device_mesh.get_local_rank()
# Get vocab range for this rank
# Use weight.shape[0] to get the actual local (sharded) size, not num_embeddings
# which may not be updated after sharding
per_partition_size = mod.weight.shape[0]
vocab_start_index = rank * per_partition_size
vocab_end_index = vocab_start_index + per_partition_size
# Build mask for out-of-vocabulary tokens
input_mask = (input_tensor < vocab_start_index) | (input_tensor >= vocab_end_index)
mod._input_mask = input_mask
# Offset input to local indices and mask invalid ones
masked_input = input_tensor.clone() - vocab_start_index
masked_input[input_mask] = 0 # Set to valid local index
return masked_input
return input_tensor
def _prepare_output_fn(self, mod, outputs, device_mesh):