3131import transformer_engine .common .recipe
3232import transformer_engine .pytorch
3333import transformers
34+ from torch .distributed .tensor .parallel import ColwiseParallel , parallelize_module
35+ from torch .distributed .tensor .placement_types import Replicate
3436from transformer_engine .pytorch .attention import InferenceParams
3537from transformer_engine .pytorch .attention .inference import PagedKVCacheManager
3638from transformer_engine .pytorch .attention .rope import RotaryPositionEmbedding
@@ -58,6 +60,9 @@ class NVLlamaConfig(LlamaConfig):
5860 # "thd" = Total tokens (packed/unpadded), Head, Dimension (sequence packing format)
5961 attn_input_format : str = "thd"
6062 self_attn_mask_type : str = "padding_causal"
63+ tensor_parallel : bool = False
64+ sequence_parallel : bool = False
65+ tp_size : int = 1
6166
6267 def __init__ (
6368 self ,
@@ -148,20 +153,26 @@ def __init__(
148153 config : LlamaConfig ,
149154 fp8_recipe : transformer_engine .common .recipe .Recipe | None = None ,
150155 fp4_recipe : transformer_engine .common .recipe .Recipe | None = None ,
156+ nvte_tp_mesh : torch .distributed .DeviceMesh | None = None ,
157+ nvte_weight_mesh : torch .distributed .DeviceMesh | None = None ,
151158 ):
152159 """Initialize the NVLlama model.
153160
154161 Args:
155162 config: The configuration of the model.
156163 fp8_recipe: The FP8 recipe for the model.
157164 fp4_recipe: The FP4 recipe for the model.
165+ nvte_tp_mesh: TP DeviceMesh for the model.
166+ nvte_weight_mesh: Weight-sharding DeviceMesh for the model.
158167 """
159168 super ().__init__ (config )
160169 self .config = config
161170 self .padding_idx = config .pad_token_id
162171 self .vocab_size = config .vocab_size
163172 self ._fp8_recipe : transformer_engine .common .recipe .Recipe | None = fp8_recipe
164173 self ._fp4_recipe : transformer_engine .common .recipe .Recipe | None = fp4_recipe
174+ self .tp_mesh = nvte_tp_mesh
175+ self .weight_mesh = nvte_weight_mesh
165176
166177 if self .config .layer_precision is None :
167178 if fp8_recipe is not None and fp4_recipe is not None :
@@ -180,6 +191,27 @@ def __init__(
180191
181192 self .embed_tokens = nn .Embedding (config .vocab_size , config .hidden_size , self .padding_idx , dtype = config .dtype )
182193
194+ # Tensor-parallelize torch.nn.Embedding. Combines DTensor-based TP with TE-based TP.
195+ if config .tensor_parallel :
196+ assert self .tp_mesh is not None , "[NVLlamaModel] Tensor parallelism requires a NVLlamaConfig.tp_mesh."
197+ assert self .tp_mesh .size () == config .tp_size , (
198+ f"[NVLlamaModel] DeviceMesh TP size ({ self .tp_mesh .size ()} ) "
199+ f"does not match configured TP size ({ config .tp_size } )." ,
200+ )
201+ # NOTE(@cspades): Because the TELinear head is weight-tied to torch.nn.Embedding
202+ # during HuggingFace post-init, this automatically converts the TELinear head
203+ # weight into a DTensor with the correct sharding placements prior to FSDP2
204+ # fully_shard(), so there is no need to call TELinear.set_device_mesh().
205+ parallelize_module (
206+ self .embed_tokens ,
207+ self .tp_mesh ,
208+ # Un-sharded output activations for compatible input to TETransformer.
209+ # NOTE(@cspades): ColwiseParallel -> torch.nn.Embedding -> Shard(dim=1)
210+ # RowwiseParallel doesn't support output_layouts=Replicate() with
211+ # torch.compile: https://github.com/pytorch/torchtitan/issues/534
212+ ColwiseParallel (input_layouts = Replicate (), output_layouts = Replicate ()),
213+ )
214+
183215 def _init_method (x ):
184216 torch .nn .init .normal_ (x , mean = 0.0 , std = config .initializer_range )
185217
@@ -207,6 +239,11 @@ def _init_method(x):
207239 device = "meta" if torch .get_default_device () == torch .device ("meta" ) else "cuda" ,
208240 init_method = _init_method ,
209241 output_layer_init_method = _init_method ,
242+ set_parallel_mode = config .tensor_parallel ,
243+ sequence_parallel = config .sequence_parallel ,
244+ tp_size = config .tp_size ,
245+ tp_mesh = self .tp_mesh ,
246+ weight_mesh = self .weight_mesh ,
210247 )
211248 ]
212249
@@ -217,6 +254,8 @@ def _init_method(x):
217254 dtype = config .dtype ,
218255 device = "meta" if torch .get_default_device () == torch .device ("meta" ) else "cuda" ,
219256 )
257+ # Norm modules are non-Base TransformerEngine modules that require a manual call for TP.
258+ self .norm .set_device_mesh (tp_mesh = self .tp_mesh , weight_mesh = self .weight_mesh )
220259
221260 # We use TE's RotaryPositionEmbedding, but we ensure that we use the same inv_freq as the original
222261 # LlamaRotaryEmbedding.
@@ -393,17 +432,30 @@ def __init__(
393432 config ,
394433 fp8_recipe : transformer_engine .common .recipe .Recipe | None = None ,
395434 fp4_recipe : transformer_engine .common .recipe .Recipe | None = None ,
435+ nvte_tp_mesh : torch .distributed .DeviceMesh | None = None ,
436+ nvte_weight_mesh : torch .distributed .DeviceMesh | None = None ,
396437 ):
397438 """Initialize the NVLlamaForCausalLM model.
398439
399440 Args:
400441 config: The configuration of the model.
401442 fp8_recipe: The FP8 recipe for the model.
402443 fp4_recipe: The FP4 recipe for the model.
444+ nvte_tp_mesh: TP DeviceMesh for the model.
445+ nvte_weight_mesh: Weight-sharding DeviceMesh for the model.
403446 """
404447 super ().__init__ (config )
405- self .model = NVLlamaModel (config , fp8_recipe = fp8_recipe , fp4_recipe = fp4_recipe )
448+ self .model = NVLlamaModel (
449+ config ,
450+ fp8_recipe = fp8_recipe ,
451+ fp4_recipe = fp4_recipe ,
452+ nvte_tp_mesh = nvte_tp_mesh ,
453+ nvte_weight_mesh = nvte_weight_mesh ,
454+ )
455+ self .config = config
406456 self .vocab_size = config .vocab_size
457+ self .tp_mesh = nvte_tp_mesh
458+ self .weight_mesh = nvte_weight_mesh
407459 with transformer_engine .pytorch .quantized_model_init (enabled = False ):
408460 self .lm_head = transformer_engine .pytorch .Linear (
409461 config .hidden_size ,
@@ -412,9 +464,19 @@ def __init__(
412464 params_dtype = config .dtype ,
413465 device = "meta" if torch .get_default_device () == torch .device ("meta" ) else "cuda" ,
414466 init_method = lambda x : torch .nn .init .normal_ (x , mean = 0.0 , std = config .initializer_range ),
467+ parallel_mode = "row" if config .tensor_parallel else None ,
468+ # Sequence parallelism would scatter the output, which the final layer never needs.
469+ # The output is all-reduced instead, as required.
470+ sequence_parallel = False ,
471+ tp_size = config .tp_size ,
415472 )
473+ if config .tensor_parallel :
474+ # If using tensor parallelism, the head weights have already been tied
475+ # to the embedding weights. Just set the tensor parallel group for TE.
476+ # No parameter quantization either, so no need for weight_mesh.
477+ self .lm_head .set_tensor_parallel_group (self .tp_mesh .get_group ())
416478
417- # Initialize weights and apply final processing
479+ # Initialize weights and apply final processing. Ties weights.
418480 self .post_init ()
419481
420482 def forward (
@@ -467,6 +529,13 @@ def forward(
467529 # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
468530 slice_indices = slice (- logits_to_keep , None ) if isinstance (logits_to_keep , int ) else logits_to_keep
469531
532+ if self .config .tensor_parallel :
533+ # If using TP, shard your activation across the TP group,
534+ # to support row-wise tensor parallelism in the LM head.
535+ tp_rank = self .tp_mesh .get_local_rank ()
536+ tp_stride = hidden_states .shape [- 1 ] // self .config .tp_size
537+ hidden_states = hidden_states [:, :, tp_rank * tp_stride : (tp_rank + 1 ) * tp_stride ]
538+
470539 with transformer_engine .pytorch .autocast (enabled = False ):
471540 if hidden_states .ndim == 3 :
472541 logits = self .lm_head (hidden_states [:, slice_indices , :])
0 commit comments