Skip to content

Commit b8c36fb

Browse files
committed
Update OG2 Llama model.
Signed-off-by: Cory Ye <cye@nvidia.com>
1 parent dfa9c64 commit b8c36fb

1 file changed

Lines changed: 71 additions & 2 deletions

File tree

bionemo-recipes/recipes/opengenome2_llama_native_te/modeling_llama_te.py

Lines changed: 71 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131
import transformer_engine.common.recipe
3232
import transformer_engine.pytorch
3333
import transformers
34+
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module
35+
from torch.distributed.tensor.placement_types import Replicate
3436
from transformer_engine.pytorch.attention import InferenceParams
3537
from transformer_engine.pytorch.attention.inference import PagedKVCacheManager
3638
from transformer_engine.pytorch.attention.rope import RotaryPositionEmbedding
@@ -58,6 +60,9 @@ class NVLlamaConfig(LlamaConfig):
5860
# "thd" = Total tokens (packed/unpadded), Head, Dimension (sequence packing format)
5961
attn_input_format: str = "thd"
6062
self_attn_mask_type: str = "padding_causal"
63+
tensor_parallel: bool = False
64+
sequence_parallel: bool = False
65+
tp_size: int = 1
6166

6267
def __init__(
6368
self,
@@ -148,20 +153,26 @@ def __init__(
148153
config: LlamaConfig,
149154
fp8_recipe: transformer_engine.common.recipe.Recipe | None = None,
150155
fp4_recipe: transformer_engine.common.recipe.Recipe | None = None,
156+
nvte_tp_mesh: torch.distributed.DeviceMesh | None = None,
157+
nvte_weight_mesh: torch.distributed.DeviceMesh | None = None,
151158
):
152159
"""Initialize the NVLlama model.
153160
154161
Args:
155162
config: The configuration of the model.
156163
fp8_recipe: The FP8 recipe for the model.
157164
fp4_recipe: The FP4 recipe for the model.
165+
nvte_tp_mesh: Tensor-parallel (TP) DeviceMesh for the model. Required when config.tensor_parallel is True.
166+
nvte_weight_mesh: Weight-sharding DeviceMesh for the model.
158167
"""
159168
super().__init__(config)
160169
self.config = config
161170
self.padding_idx = config.pad_token_id
162171
self.vocab_size = config.vocab_size
163172
self._fp8_recipe: transformer_engine.common.recipe.Recipe | None = fp8_recipe
164173
self._fp4_recipe: transformer_engine.common.recipe.Recipe | None = fp4_recipe
174+
self.tp_mesh = nvte_tp_mesh
175+
self.weight_mesh = nvte_weight_mesh
165176

166177
if self.config.layer_precision is None:
167178
if fp8_recipe is not None and fp4_recipe is not None:
@@ -180,6 +191,27 @@ def __init__(
180191

181192
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx, dtype=config.dtype)
182193

194+
# Tensor-parallelize torch.nn.Embedding. Combines DTensor-based TP with TE-based TP.
195+
if config.tensor_parallel:
196+
assert self.tp_mesh is not None, "[NVLlamaModel] Tensor parallelism requires a NVLlamaConfig.tp_mesh."
197+
assert self.tp_mesh.size() == config.tp_size, (
198+
f"[NVLlamaModel] DeviceMesh TP size ({self.tp_mesh.size()}) "
199+
f"does not match configured TP size ({config.tp_size}).",
200+
)
201+
# NOTE(@cspades): Because the TELinear head is weight-tied to torch.nn.Embedding
202+
# during HuggingFace post-init, this will automatically convert the TELinear head
203+
# weight into a DTensor with the correct sharding placements prior to FSDP2
204+
# fully_shard(), so there is no need to call TELinear.set_device_mesh().
205+
parallelize_module(
206+
self.embed_tokens,
207+
self.tp_mesh,
208+
# Replicate (un-shard) the output activations so they are a compatible input to TETransformer.
209+
# NOTE(@cspades): ColwiseParallel -> torch.nn.Embedding -> Shard(dim=1)
210+
# RowwiseParallel doesn't support output_layouts=Replicate() with
211+
# torch.compile: https://github.com/pytorch/torchtitan/issues/534
212+
ColwiseParallel(input_layouts=Replicate(), output_layouts=Replicate()),
213+
)
214+
183215
def _init_method(x):
184216
torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range)
185217

@@ -207,6 +239,11 @@ def _init_method(x):
207239
device="meta" if torch.get_default_device() == torch.device("meta") else "cuda",
208240
init_method=_init_method,
209241
output_layer_init_method=_init_method,
242+
set_parallel_mode=config.tensor_parallel,
243+
sequence_parallel=config.sequence_parallel,
244+
tp_size=config.tp_size,
245+
tp_mesh=self.tp_mesh,
246+
weight_mesh=self.weight_mesh,
210247
)
211248
]
212249

@@ -217,6 +254,8 @@ def _init_method(x):
217254
dtype=config.dtype,
218255
device="meta" if torch.get_default_device() == torch.device("meta") else "cuda",
219256
)
257+
# Norm modules are non-Base TransformerEngine modules that require a manual call for TP.
258+
self.norm.set_device_mesh(tp_mesh=self.tp_mesh, weight_mesh=self.weight_mesh)
220259

221260
# We use TE's RotaryPositionEmbedding, but we ensure that we use the same inv_freq as the original
222261
# LlamaRotaryEmbedding.
@@ -393,17 +432,30 @@ def __init__(
393432
config,
394433
fp8_recipe: transformer_engine.common.recipe.Recipe | None = None,
395434
fp4_recipe: transformer_engine.common.recipe.Recipe | None = None,
435+
nvte_tp_mesh: torch.distributed.DeviceMesh | None = None,
436+
nvte_weight_mesh: torch.distributed.DeviceMesh | None = None,
396437
):
397438
"""Initialize the NVLlamaForCausalLM model.
398439
399440
Args:
400441
config: The configuration of the model.
401442
fp8_recipe: The FP8 recipe for the model.
402443
fp4_recipe: The FP4 recipe for the model.
444+
nvte_tp_mesh: TP DeviceMesh for the model.
445+
nvte_weight_mesh: Weight-sharding DeviceMesh for the model.
403446
"""
404447
super().__init__(config)
405-
self.model = NVLlamaModel(config, fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe)
448+
self.model = NVLlamaModel(
449+
config,
450+
fp8_recipe=fp8_recipe,
451+
fp4_recipe=fp4_recipe,
452+
nvte_tp_mesh=nvte_tp_mesh,
453+
nvte_weight_mesh=nvte_weight_mesh,
454+
)
455+
self.config = config
406456
self.vocab_size = config.vocab_size
457+
self.tp_mesh = nvte_tp_mesh
458+
self.weight_mesh = nvte_weight_mesh
407459
with transformer_engine.pytorch.quantized_model_init(enabled=False):
408460
self.lm_head = transformer_engine.pytorch.Linear(
409461
config.hidden_size,
@@ -412,9 +464,19 @@ def __init__(
412464
params_dtype=config.dtype,
413465
device="meta" if torch.get_default_device() == torch.device("meta") else "cuda",
414466
init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range),
467+
parallel_mode="row" if config.tensor_parallel else None,
468+
# Sequence parallelism would scatter the output, which is never needed for the final layer.
469+
# The output is all-reduced instead, as required.
470+
sequence_parallel=False,
471+
tp_size=config.tp_size,
415472
)
473+
if config.tensor_parallel:
474+
# If using tensor parallelism, the head weights have already been tied
475+
# to the embedding weights. Just set the tensor parallel group for TE.
476+
# No parameter quantization either, so no need for weight_mesh.
477+
self.lm_head.set_tensor_parallel_group(self.tp_mesh.get_group())
416478

417-
# Initialize weights and apply final processing
479+
# Initialize weights and apply final processing. Ties weights.
418480
self.post_init()
419481

420482
def forward(
@@ -467,6 +529,13 @@ def forward(
467529
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
468530
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
469531

532+
if self.config.tensor_parallel:
533+
# If using TP, shard your activation across the TP group,
534+
# to support row-wise tensor parallelism in the LM head.
535+
tp_rank = self.tp_mesh.get_local_rank()
536+
tp_stride = hidden_states.shape[-1] // self.config.tp_size
537+
hidden_states = hidden_states[:, :, tp_rank * tp_stride : (tp_rank + 1) * tp_stride]
538+
470539
with transformer_engine.pytorch.autocast(enabled=False):
471540
if hidden_states.ndim == 3:
472541
logits = self.lm_head(hidden_states[:, slice_indices, :])

0 commit comments

Comments
 (0)