3131import transformer_engine .common .recipe
3232import transformer_engine .pytorch
3333import transformers
34+ from torch .distributed .tensor .parallel import ColwiseParallel , parallelize_module
35+ from torch .distributed .tensor .placement_types import Replicate
3436from transformer_engine .pytorch .attention import InferenceParams
3537from transformer_engine .pytorch .attention .inference import PagedKVCacheManager
3638from transformer_engine .pytorch .attention .rope import RotaryPositionEmbedding
@@ -58,6 +60,9 @@ class NVLlamaConfig(LlamaConfig):
5860 # "thd" = Total tokens (packed/unpadded), Head, Dimension (sequence packing format)
5961 attn_input_format : str = "thd"
6062 self_attn_mask_type : str = "padding_causal"
63+ tensor_parallel : bool = False
64+ sequence_parallel : bool = False
65+ tp_size : int = 1
6166
6267 def __init__ (
6368 self ,
@@ -148,20 +153,26 @@ def __init__(
148153 config : LlamaConfig ,
149154 fp8_recipe : transformer_engine .common .recipe .Recipe | None = None ,
150155 fp4_recipe : transformer_engine .common .recipe .Recipe | None = None ,
156+ nvte_tp_mesh : torch .distributed .DeviceMesh | None = None ,
157+ nvte_weight_mesh : torch .distributed .DeviceMesh | None = None ,
151158 ):
152159 """Initialize the NVLlama model.
153160
154161 Args:
155162 config: The configuration of the model.
156163 fp8_recipe: The FP8 recipe for the model.
157164 fp4_recipe: The FP4 recipe for the model.
165+ nvte_tp_mesh: TP DeviceMesh for the model.
166+ nvte_weight_mesh: Weight-sharding DeviceMesh for the model.
158167 """
159168 super ().__init__ (config )
160169 self .config = config
161170 self .padding_idx = config .pad_token_id
162171 self .vocab_size = config .vocab_size
163172 self ._fp8_recipe : transformer_engine .common .recipe .Recipe | None = fp8_recipe
164173 self ._fp4_recipe : transformer_engine .common .recipe .Recipe | None = fp4_recipe
174+ self .tp_mesh = nvte_tp_mesh
175+ self .weight_mesh = nvte_weight_mesh
165176
166177 if self .config .layer_precision is None :
167178 if fp8_recipe is not None and fp4_recipe is not None :
@@ -180,6 +191,31 @@ def __init__(
180191
181192 self .embed_tokens = nn .Embedding (config .vocab_size , config .hidden_size , self .padding_idx , dtype = config .dtype )
182193
# Tensor-parallelize torch.nn.Embedding. Combines DTensor-based TP with TE-based TP.
if config.tensor_parallel:
    # Validate with explicit raises. The previous `assert (condition, message)` form
    # asserted a non-empty tuple, which is always truthy, so neither check could fail.
    if self.tp_mesh is None:
        raise ValueError(
            "[NVLlamaModel] Tensor parallelism requires a TP DeviceMesh (nvte_tp_mesh)."
        )
    if self.tp_mesh.size() != config.tp_size:
        raise ValueError(
            f"[NVLlamaModel] DeviceMesh TP size ({self.tp_mesh.size()}) "
            f"does not match configured TP size ({config.tp_size})."
        )
    # NOTE(@cspades): Because the TELinear head is weight-tied to torch.nn.Embedding
    # during HuggingFace post-init, this will automatically convert the TELinear head
    # weight into a DTensor with the correct sharding placements prior to FSDP2
    # fully_shard(), and no need to call TELinear.set_device_mesh().
    parallelize_module(
        self.embed_tokens,
        self.tp_mesh,
        # Un-sharded output activations for compatible input to TETransformer.
        # NOTE(@cspades): ColwiseParallel -> torch.nn.Embedding -> Shard(dim=1)
        # RowwiseParallel doesn't support output_layouts=Replicate() with
        # torch.compile: https://github.com/pytorch/torchtitan/issues/534
        ColwiseParallel(input_layouts=Replicate(), output_layouts=Replicate()),
    )
218+
183219 def _init_method (x ):
184220 torch .nn .init .normal_ (x , mean = 0.0 , std = config .initializer_range )
185221
@@ -207,6 +243,11 @@ def _init_method(x):
207243 device = "meta" if torch .get_default_device () == torch .device ("meta" ) else "cuda" ,
208244 init_method = _init_method ,
209245 output_layer_init_method = _init_method ,
246+ set_parallel_mode = config .tensor_parallel ,
247+ sequence_parallel = config .sequence_parallel ,
248+ tp_size = config .tp_size ,
249+ tp_mesh = self .tp_mesh ,
250+ weight_mesh = self .weight_mesh ,
210251 )
211252 ]
212253
@@ -217,6 +258,8 @@ def _init_method(x):
217258 dtype = config .dtype ,
218259 device = "meta" if torch .get_default_device () == torch .device ("meta" ) else "cuda" ,
219260 )
261+ # Norm modules are non-Base TransformerEngine modules that require a manual call for TP.
262+ self .norm .set_device_mesh (tp_mesh = self .tp_mesh , weight_mesh = self .weight_mesh )
220263
221264 # We use TE's RotaryPositionEmbedding, but we ensure that we use the same inv_freq as the original
222265 # LlamaRotaryEmbedding.
@@ -393,17 +436,30 @@ def __init__(
393436 config ,
394437 fp8_recipe : transformer_engine .common .recipe .Recipe | None = None ,
395438 fp4_recipe : transformer_engine .common .recipe .Recipe | None = None ,
439+ nvte_tp_mesh : torch .distributed .DeviceMesh | None = None ,
440+ nvte_weight_mesh : torch .distributed .DeviceMesh | None = None ,
396441 ):
397442 """Initialize the NVLlamaForCausalLM model.
398443
399444 Args:
400445 config: The configuration of the model.
401446 fp8_recipe: The FP8 recipe for the model.
402447 fp4_recipe: The FP4 recipe for the model.
448+ nvte_tp_mesh: TP DeviceMesh for the model.
449+ nvte_weight_mesh: Weight-sharding DeviceMesh for the model.
403450 """
404451 super ().__init__ (config )
405- self .model = NVLlamaModel (config , fp8_recipe = fp8_recipe , fp4_recipe = fp4_recipe )
452+ self .model = NVLlamaModel (
453+ config ,
454+ fp8_recipe = fp8_recipe ,
455+ fp4_recipe = fp4_recipe ,
456+ nvte_tp_mesh = nvte_tp_mesh ,
457+ nvte_weight_mesh = nvte_weight_mesh ,
458+ )
459+ self .config = config
406460 self .vocab_size = config .vocab_size
461+ self .tp_mesh = nvte_tp_mesh
462+ self .weight_mesh = nvte_weight_mesh
407463 with transformer_engine .pytorch .quantized_model_init (enabled = False ):
408464 self .lm_head = transformer_engine .pytorch .Linear (
409465 config .hidden_size ,
@@ -412,9 +468,19 @@ def __init__(
412468 params_dtype = config .dtype ,
413469 device = "meta" if torch .get_default_device () == torch .device ("meta" ) else "cuda" ,
414470 init_method = lambda x : torch .nn .init .normal_ (x , mean = 0.0 , std = config .initializer_range ),
471+ parallel_mode = "row" if config .tensor_parallel else None ,
472+ # This scatters your output, not ever needed for final layer.
473+ # Will all-reduce the output instead, as required.
474+ sequence_parallel = False ,
475+ tp_size = config .tp_size ,
415476 )
477+ if config .tensor_parallel :
478+ # If using tensor parallelism, the head weights have already been tied
479+ # to the embedding weights. Just set the tensor parallel group for TE.
480+ # No parameter quantization either, so no need for weight_mesh.
481+ self .lm_head .set_tensor_parallel_group (self .tp_mesh .get_group ())
416482
417- # Initialize weights and apply final processing
483+ # Initialize weights and apply final processing. Ties weights.
418484 self .post_init ()
419485
420486 def forward (
@@ -467,6 +533,13 @@ def forward(
467533 # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
468534 slice_indices = slice (- logits_to_keep , None ) if isinstance (logits_to_keep , int ) else logits_to_keep
469535
if self.config.tensor_parallel:
    # If using TP, shard the activation across the TP group along the hidden
    # (last) dimension, to support row-wise tensor parallelism in the LM head.
    # Ellipsis indexing slices only the last dimension, so this stays valid for
    # any tensor rank instead of assuming a 3-D input; the result is unchanged
    # for 3-D tensors.
    tp_rank = self.tp_mesh.get_local_rank()
    tp_stride = hidden_states.shape[-1] // self.config.tp_size
    tp_start = tp_rank * tp_stride
    hidden_states = hidden_states[..., tp_start : tp_start + tp_stride]
542+
470543 with transformer_engine .pytorch .autocast (enabled = False ):
471544 if hidden_states .ndim == 3 :
472545 logits = self .lm_head (hidden_states [:, slice_indices , :])
0 commit comments