Skip to content

Commit 2759c62

Browse files
committed
Enable TransformerEngine-backed Tensor Parallelism with Llama3.
Signed-off-by: Cory Ye <cye@nvidia.com>
1 parent 46112e7 commit 2759c62

5 files changed

Lines changed: 165 additions & 24 deletions

File tree

bionemo-recipes/recipes/llama3_native_te/checkpoint.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -233,9 +233,28 @@ class AppState(Stateful):
233233
epoch: int = 0
234234

235235
def state_dict(self):
236-
"""Get the state dict for the model, optimizer, scheduler, and step."""
236+
"""
237+
Get the state dict for the model, optimizer, scheduler, and step.
238+
This factory both retrieves the model state dictionary when saving
239+
checkpoints and initializes a destination for the state read from
240+
DCP checkpoint files when loading checkpoints.
241+
"""
237242
model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
238-
model_state_dict = {k: v for k, v in model_state_dict.items() if not k.endswith("_extra_state")}
243+
for fqn in list(model_state_dict.keys()):
244+
# Get the model parameter.
245+
model_param = model_state_dict[fqn]
246+
if isinstance(model_param, DTensor):
247+
model_param = model_param.to_local()
248+
if model_param.numel() == 0 and fqn in optimizer_state_dict['state']:
249+
# Empty model parameter. Clear the associated optimizer state
250+
# when initializing the optimizer state upon DCP load, because
251+
# empty optimizer state DTensors are not checkpointed with DCP,
252+
# yet get_state_dict / _init_optim_state produce empty Tensors.
253+
# TransformerEngine uses empty Tensors for dummy Parameters.
254+
optimizer_state_dict['state'][fqn] = {}
255+
if fqn.endswith("_extra_state"):
256+
# Evict `_extra_state` quantization data from model checkpoint.
257+
model_state_dict.pop(fqn)
239258
return {
240259
"model": model_state_dict,
241260
"optim": optimizer_state_dict,
@@ -245,12 +264,18 @@ def state_dict(self):
245264
}
246265

247266
def load_state_dict(self, state_dict: dict):
248-
"""Load the state dict for the model, optimizer, scheduler, and step."""
267+
"""
268+
Load the state dict for the model, optimizer, scheduler, and step.
269+
Given the checkpoint-loaded state_dict, set the state of the model,
270+
optimizer, scheduler, step, and epoch to the values in state_dict.
271+
"""
249272
set_state_dict(
250273
self.model,
251274
self.optimizer,
252275
model_state_dict=state_dict["model"],
253276
optim_state_dict=state_dict["optim"],
277+
# Non-strict checkpoint loading ignores empty optimizer states,
278+
# skips loading non-FP8 checkpoint weights (e.g. _extra_state).
254279
options=StateDictOptions(strict=False),
255280
)
256281
self.scheduler.load_state_dict(state_dict["scheduler"])
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
defaults:
2+
- L0_sanity
3+
- _self_
4+
5+
tp_size: 2 # TP Sharding
6+
cp_size: 2 # FSDP-CP Sharding
7+
8+
dataset:
9+
# CP2 * (8 for FP8 Activations, 16 for FP8 Parameters)
10+
pad_sequences_to_be_divisible_by: 32
11+
12+
fp8_config:
13+
enabled: true
14+
fp8_recipe: transformer_engine.common.recipe.DelayedScaling
15+
fp8_format: "HYBRID"
16+
fp8_recipe_kwargs: {}
17+
18+
checkpoint:
19+
ckpt_dir: ./fsdp_nd_ckpts
20+
save_final_model: true
21+
22+
config_kwargs:
23+
attn_input_format: "bshd" # Alternatively "thd" on datacenter hardware.
24+
self_attn_mask_type: "causal" # Alternatively "padding_causal" for THD inputs.
25+
tensor_parallel: true # Tensor Parallelism for TE
26+
sequence_parallel: true # Sequence parallelism for LayerNorm on TP ranks.
27+
tp_size: ${tp_size} # Tensor Parallel Size

bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@ use_meta_device: true
1212
# We leave this off by default since we don't see much of a performance improvement with TE layers.
1313
use_torch_compile: false
1414

15+
# Default parallelism sizes.
16+
tp_size: 1
17+
cp_size: 1
18+
1519
use_sequence_packing: false
1620

1721
dataset:

bionemo-recipes/recipes/llama3_native_te/modeling_llama_te.py

Lines changed: 75 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131
import transformer_engine.common.recipe
3232
import transformer_engine.pytorch
3333
import transformers
34+
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module
35+
from torch.distributed.tensor.placement_types import Replicate
3436
from transformer_engine.pytorch.attention import InferenceParams
3537
from transformer_engine.pytorch.attention.inference import PagedKVCacheManager
3638
from transformer_engine.pytorch.attention.rope import RotaryPositionEmbedding
@@ -58,6 +60,9 @@ class NVLlamaConfig(LlamaConfig):
5860
# "thd" = Total tokens (packed/unpadded), Head, Dimension (sequence packing format)
5961
attn_input_format: str = "thd"
6062
self_attn_mask_type: str = "padding_causal"
63+
tensor_parallel: bool = False
64+
sequence_parallel: bool = False
65+
tp_size: int = 1
6166

6267
def __init__(
6368
self,
@@ -148,20 +153,26 @@ def __init__(
148153
config: LlamaConfig,
149154
fp8_recipe: transformer_engine.common.recipe.Recipe | None = None,
150155
fp4_recipe: transformer_engine.common.recipe.Recipe | None = None,
156+
nvte_tp_mesh: torch.distributed.DeviceMesh | None = None,
157+
nvte_weight_mesh: torch.distributed.DeviceMesh | None = None,
151158
):
152159
"""Initialize the NVLlama model.
153160
154161
Args:
155162
config: The configuration of the model.
156163
fp8_recipe: The FP8 recipe for the model.
157164
fp4_recipe: The FP4 recipe for the model.
165+
nvte_tp_mesh: TP DeviceMesh for the model.
166+
nvte_weight_mesh: Weight-sharding DeviceMesh for the model.
158167
"""
159168
super().__init__(config)
160169
self.config = config
161170
self.padding_idx = config.pad_token_id
162171
self.vocab_size = config.vocab_size
163172
self._fp8_recipe: transformer_engine.common.recipe.Recipe | None = fp8_recipe
164173
self._fp4_recipe: transformer_engine.common.recipe.Recipe | None = fp4_recipe
174+
self.tp_mesh = nvte_tp_mesh
175+
self.weight_mesh = nvte_weight_mesh
165176

166177
if self.config.layer_precision is None:
167178
if fp8_recipe is not None and fp4_recipe is not None:
@@ -180,6 +191,31 @@ def __init__(
180191

181192
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx, dtype=config.dtype)
182193

194+
# Tensor-parallelize torch.nn.Embedding. Combines DTensor-based TP with TE-based TP.
195+
if config.tensor_parallel:
196+
assert (
197+
self.tp_mesh is not None,
198+
"[NVLlamaModel] Tensor parallelism requires a NVLlamaConfig.tp_mesh."
199+
)
200+
assert (
201+
self.tp_mesh.size() == config.tp_size,
202+
f"[NVLlamaModel] DeviceMesh TP size ({self.tp_mesh.size()}) "
203+
f"does not match configured TP size ({config.tp_size})."
204+
)
205+
# NOTE(@cspades): Because the TELinear head is weight-tied to torch.nn.Embedding
206+
# during HuggingFace post-init, this will automatically convert the TELinear head
207+
# weight into a DTensor with the correct sharding placements prior to FSDP2
208+
# fully_shard(), and no need to call TELinear.set_device_mesh().
209+
parallelize_module(
210+
self.embed_tokens,
211+
self.tp_mesh,
212+
# Un-sharded output activations for compatible input to TETransformer.
213+
# NOTE(@cspades): ColwiseParallel -> torch.nn.Embedding -> Shard(dim=1)
214+
# RowwiseParallel doesn't support output_layouts=Replicate() with
215+
# torch.compile: https://github.com/pytorch/torchtitan/issues/534
216+
ColwiseParallel(input_layouts=Replicate(), output_layouts=Replicate())
217+
)
218+
183219
def _init_method(x):
184220
torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range)
185221

@@ -207,6 +243,11 @@ def _init_method(x):
207243
device="meta" if torch.get_default_device() == torch.device("meta") else "cuda",
208244
init_method=_init_method,
209245
output_layer_init_method=_init_method,
246+
set_parallel_mode=config.tensor_parallel,
247+
sequence_parallel=config.sequence_parallel,
248+
tp_size=config.tp_size,
249+
tp_mesh=self.tp_mesh,
250+
weight_mesh=self.weight_mesh,
210251
)
211252
]
212253

@@ -217,6 +258,8 @@ def _init_method(x):
217258
dtype=config.dtype,
218259
device="meta" if torch.get_default_device() == torch.device("meta") else "cuda",
219260
)
261+
# Norm modules are non-Base TransformerEngine modules that require a manual call for TP.
262+
self.norm.set_device_mesh(tp_mesh=self.tp_mesh, weight_mesh=self.weight_mesh)
220263

221264
# We use TE's RotaryPositionEmbedding, but we ensure that we use the same inv_freq as the original
222265
# LlamaRotaryEmbedding.
@@ -393,17 +436,30 @@ def __init__(
393436
config,
394437
fp8_recipe: transformer_engine.common.recipe.Recipe | None = None,
395438
fp4_recipe: transformer_engine.common.recipe.Recipe | None = None,
439+
nvte_tp_mesh: torch.distributed.DeviceMesh | None = None,
440+
nvte_weight_mesh: torch.distributed.DeviceMesh | None = None,
396441
):
397442
"""Initialize the NVLlamaForCausalLM model.
398443
399444
Args:
400445
config: The configuration of the model.
401446
fp8_recipe: The FP8 recipe for the model.
402447
fp4_recipe: The FP4 recipe for the model.
448+
nvte_tp_mesh: TP DeviceMesh for the model.
449+
nvte_weight_mesh: Weight-sharding DeviceMesh for the model.
403450
"""
404451
super().__init__(config)
405-
self.model = NVLlamaModel(config, fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe)
452+
self.model = NVLlamaModel(
453+
config,
454+
fp8_recipe=fp8_recipe,
455+
fp4_recipe=fp4_recipe,
456+
nvte_tp_mesh=nvte_tp_mesh,
457+
nvte_weight_mesh=nvte_weight_mesh,
458+
)
459+
self.config = config
406460
self.vocab_size = config.vocab_size
461+
self.tp_mesh = nvte_tp_mesh
462+
self.weight_mesh = nvte_weight_mesh
407463
with transformer_engine.pytorch.quantized_model_init(enabled=False):
408464
self.lm_head = transformer_engine.pytorch.Linear(
409465
config.hidden_size,
@@ -412,9 +468,19 @@ def __init__(
412468
params_dtype=config.dtype,
413469
device="meta" if torch.get_default_device() == torch.device("meta") else "cuda",
414470
init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range),
471+
parallel_mode="row" if config.tensor_parallel else None,
472+
# This scatters your output, not ever needed for final layer.
473+
# Will all-reduce the output instead, as required.
474+
sequence_parallel=False,
475+
tp_size=config.tp_size,
415476
)
477+
if config.tensor_parallel:
478+
# If using tensor parallelism, the head weights have already been tied
479+
# to the embedding weights. Just set the tensor parallel group for TE.
480+
# No parameter quantization either, so no need for weight_mesh.
481+
self.lm_head.set_tensor_parallel_group(self.tp_mesh.get_group())
416482

417-
# Initialize weights and apply final processing
483+
# Initialize weights and apply final processing. Ties weights.
418484
self.post_init()
419485

420486
def forward(
@@ -467,6 +533,13 @@ def forward(
467533
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
468534
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
469535

536+
if self.config.tensor_parallel:
537+
# If using TP, shard your activation across the TP group,
538+
# to support row-wise tensor parallelism in the LM head.
539+
tp_rank = self.tp_mesh.get_local_rank()
540+
tp_stride = hidden_states.shape[-1] // self.config.tp_size
541+
hidden_states = hidden_states[:, :, tp_rank*tp_stride:(tp_rank + 1)*tp_stride]
542+
470543
with transformer_engine.pytorch.autocast(enabled=False):
471544
if hidden_states.ndim == 3:
472545
logits = self.lm_head(hidden_states[:, slice_indices, :])

bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py renamed to bionemo-recipes/recipes/llama3_native_te/train_fsdp2_nd_parallel.py

Lines changed: 31 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,16 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
"""FSDP2 with Context Parallelism training script for Llama 3 with TransformerEngine.
16+
"""FSDP2 with Tensor & Context Parallelism training script for Llama 3 with TransformerEngine.
1717
18-
Combines Fully Sharded Data Parallel v2 with Context Parallelism (CP), where each sequence is
19-
split across multiple GPUs along the sequence dimension. This is useful for training with very long
20-
sequences that do not fit into a single GPU's memory even with FSDP2 alone. Only supports
21-
TE-accelerated models (NVLlamaForCausalLM).
18+
Combines Fully Sharded Data Parallel v2 with Tensor Parallelism (TP) and Context Parallelism (CP).
19+
In Context Parallelism, each sequence is split across multiple GPUs along the sequence dimension,
20+
which is useful for training on extremely long sequences that exhaust activation memory.
21+
In Tensor Parallelism, weights and activations are sharded on the hidden dim across multiple GPUs,
22+
which is useful for sharding model weights and activations unlike FSDP which only shards weights.
23+
Only supports TE-accelerated models (NVLlamaForCausalLM).
2224
23-
For standard FSDP2 training without context parallelism, use ``train_fsdp2.py`` instead.
25+
For standard FSDP2 training without N-D parallelism, use ``train_fsdp2.py`` instead.
2426
"""
2527

2628
import gc
@@ -59,7 +61,7 @@
5961

6062
@hydra.main(config_path="hydra_config", config_name="L0_sanity_cp", version_base="1.2")
6163
def main(args: DictConfig) -> float | None:
62-
"""Train Llama3 with TE layers using FSDP2 with Context Parallelism.
64+
"""Train Llama3 with TE layers using FSDP2, CP, and TP.
6365
6466
Returns:
6567
float: The loss value for the final batch.
@@ -73,8 +75,8 @@ def main(args: DictConfig) -> float | None:
7375

7476
device_mesh = init_device_mesh(
7577
"cuda",
76-
mesh_shape=(dist_config.world_size // args.cp_size, args.cp_size),
77-
mesh_dim_names=("dp", "cp"),
78+
mesh_shape=(dist_config.world_size // (args.cp_size * args.tp_size), args.cp_size, args.tp_size),
79+
mesh_dim_names=("dp", "cp", "tp"),
7880
)
7981
logger.info("Created device mesh: %s", device_mesh)
8082

@@ -94,11 +96,20 @@ def main(args: DictConfig) -> float | None:
9496
config = NVLlamaConfig.from_pretrained(args.config_name_or_path, dtype=torch.bfloat16, **args.config_kwargs)
9597

9698
with torch.device("meta") if args.use_meta_device else nullcontext():
97-
model = NVLlamaForCausalLM(config, fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe)
99+
model = NVLlamaForCausalLM(
100+
config,
101+
fp8_recipe=fp8_recipe,
102+
fp4_recipe=fp4_recipe,
103+
nvte_tp_mesh=device_mesh["tp"] if config.tensor_parallel else None,
104+
# nvte_weight_mesh is only required for Float8CurrentScaling parameters.
105+
nvte_weight_mesh=device_mesh["dp", "cp", "tp"]._flatten("weight_mesh") if config.tensor_parallel else None,
106+
)
98107

99108
logger.info("Initialized Model:\n%s", model)
100109

101110
# --- Distributed Wrapping (FSDP2 + CP) ---
111+
112+
# Create a flattened mesh for FSDP2-CP sharding. This will shard the model across both the DP and CP ranks.
102113
cp_dp_mesh = device_mesh["dp", "cp"]._flatten(mesh_dim_name="dp_shard_cp")
103114

104115
# Shard the transformer layers with FSDP. For Llama3, the transformer stack is in model.model.layers.
@@ -107,7 +118,7 @@ def main(args: DictConfig) -> float | None:
107118
fully_shard(layer, mesh=cp_dp_mesh)
108119
fully_shard(model, mesh=cp_dp_mesh)
109120

110-
# Attach the CP group to the model.
121+
# Attach the CP ProcessGroup to the TransformerEngine model.
111122
for layer in model.model.layers:
112123
layer.set_context_parallel_group(
113124
device_mesh["cp"].get_group(),
@@ -136,9 +147,12 @@ def main(args: DictConfig) -> float | None:
136147
logger.info("pad_sequences_to_be_divisible_by is not provided, using cp_mesh.size() * 2")
137148
OmegaConf.update(args, "dataset.pad_sequences_to_be_divisible_by", device_mesh["cp"].size() * 2)
138149

139-
# We only create the dataloader on rank 0, which is responsible for loading data for all CP (and eventually TP)
140-
# ranks. This ensures that the data remains synchronized, even if we're using a non-deterministic data pipeline.
141-
if device_mesh["cp"].get_local_rank() == 0:
150+
# We only create the dataloader on rank 0, which is responsible for loading data for all CP (and TP) ranks.
151+
# This ensures that the data remains synchronized, even if we're using a non-deterministic data pipeline.
152+
cp_tp_mesh = device_mesh["cp", "tp"]._flatten(mesh_dim_name="cp_tp")
153+
if cp_tp_mesh.get_local_rank() == 0:
154+
# We only create the dataloader on CP-TP Rank 0 and pass it to a ContextParallelDataLoaderWrapper
155+
# that will shard, replicate, and distribute the data across the flattened CP and TP group.
142156
if args.use_sequence_packing:
143157
train_dataloader, dataset_or_sampler = create_thd_dataloader(dist_config, **args.dataset)
144158
else:
@@ -155,8 +169,8 @@ def main(args: DictConfig) -> float | None:
155169
train_dataloader = None
156170
dataset_or_sampler = None
157171

158-
# On all ranks, we create a ContextParallelDataLoaderWrapper that broadcasts the data from cp rank 0.
159-
train_dataloader = ContextParallelDataLoaderWrapper(train_dataloader, device_mesh["cp"])
172+
# Deliver CP-sharded replicates to a flattened CP-TP mesh.
173+
train_dataloader = ContextParallelDataLoaderWrapper(train_dataloader, cp_tp_mesh)
160174

161175
# --- Checkpoint Resume ---
162176
ckpt_path = Path(args.checkpoint.ckpt_dir) / "train_fsdp2" if args.checkpoint.ckpt_dir else None
@@ -169,7 +183,6 @@ def main(args: DictConfig) -> float | None:
169183
ckpt_path=ckpt_path,
170184
dist_config=dist_config,
171185
dataloader=train_dataloader,
172-
process_group=cp_dp_mesh.get_group(),
173186
)
174187
logger.info("Checkpoint loaded, resuming from step %s, epoch %s", start_step, epoch)
175188
else:
@@ -234,7 +247,6 @@ def main(args: DictConfig) -> float | None:
234247
epoch=epoch,
235248
dist_config=dist_config,
236249
dataloader=train_dataloader if args.dataset.use_stateful_dataloader else None,
237-
process_group=cp_dp_mesh.get_group(),
238250
max_checkpoints=args.checkpoint.max_checkpoints,
239251
async_save=args.checkpoint.async_save,
240252
)
@@ -267,4 +279,4 @@ def main(args: DictConfig) -> float | None:
267279

268280

269281
if __name__ == "__main__":
270-
main()
282+
main()

0 commit comments

Comments
 (0)