Skip to content

Commit f539bfa

Browse files
committed
Pipeclean TP with THD, add TP unit tests.
Signed-off-by: Cory Ye <cye@nvidia.com>
1 parent edaef59 commit f539bfa

6 files changed

Lines changed: 181 additions & 3 deletions

File tree

bionemo-recipes/models/llama3/modeling_llama_te.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -526,9 +526,10 @@ def forward(
526526
if self.config.tensor_parallel:
527527
# If using TP, shard your activation across the TP group,
528528
# to support row-wise tensor parallelism in the LM head.
529+
# Use ... to support both BSHD (3D) and THD (2D) hidden states.
529530
tp_rank = self.tp_mesh.get_local_rank()
530531
tp_stride = hidden_states.shape[-1] // self.config.tp_size
531-
hidden_states = hidden_states[:, :, tp_rank * tp_stride : (tp_rank + 1) * tp_stride]
532+
hidden_states = hidden_states[..., tp_rank * tp_stride : (tp_rank + 1) * tp_stride]
532533

533534
with transformer_engine.pytorch.autocast(enabled=False):
534535
if hidden_states.ndim == 3:
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
defaults:
2+
- L0_sanity
3+
- _self_
4+
5+
tp_size: 2 # Tensor Parallel sharding factor
6+
cp_size: 1
7+
8+
use_sequence_packing: false
9+
10+
config_kwargs:
11+
attn_input_format: "bshd" # Alternatively "thd" on datacenter hardware.
12+
self_attn_mask_type: "causal" # Alternatively "padding_causal" for THD inputs.
13+
tensor_parallel: true # Tensor Parallelism for TE
14+
sequence_parallel: false # Sequence parallelism for LayerNorm on TP ranks.
15+
tp_size: ${tp_size} # Tensor Parallel Size

bionemo-recipes/recipes/llama3_native_te/modeling_llama_te.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -532,9 +532,10 @@ def forward(
532532
if self.config.tensor_parallel:
533533
# If using TP, shard your activation across the TP group,
534534
# to support row-wise tensor parallelism in the LM head.
535+
# Use ... to support both BSHD (3D) and THD (2D) hidden states.
535536
tp_rank = self.tp_mesh.get_local_rank()
536537
tp_stride = hidden_states.shape[-1] // self.config.tp_size
537-
hidden_states = hidden_states[:, :, tp_rank * tp_stride : (tp_rank + 1) * tp_stride]
538+
hidden_states = hidden_states[..., tp_rank * tp_stride : (tp_rank + 1) * tp_stride]
538539

539540
with transformer_engine.pytorch.autocast(enabled=False):
540541
if hidden_states.ndim == 3:

bionemo-recipes/recipes/llama3_native_te/tests/test_train.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -480,6 +480,61 @@ def test_sanity_ddp_fp8_stats_logging(tmp_path, recipe_path):
480480
assert stats_log.stat().st_size > 0, "Statistics log file is empty"
481481

482482

483+
def test_sanity_nd_parallel_tp1_bshd(tmp_path, recipe_path):
484+
"""Test ND-parallel training with tensor_parallel=True and tp_size=1 (trivial TP group), BSHD.
485+
486+
This test validates that all TP code paths in NVLlamaModel and NVLlamaForCausalLM execute
487+
correctly with a single-rank TP mesh:
488+
- parallelize_module on embed_tokens (ColwiseParallel)
489+
- TransformerLayer TP mode flags
490+
- lm_head row-parallel mode and set_tensor_parallel_group
491+
- Hidden-state activation slicing in NVLlamaForCausalLM.forward
492+
"""
493+
with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"):
494+
sanity_config = compose(
495+
config_name="L0_sanity_tp",
496+
overrides=[
497+
f"+wandb.dir={tmp_path}",
498+
f"checkpoint.ckpt_dir={tmp_path}",
499+
"num_train_steps=10",
500+
"tp_size=1",
501+
"checkpoint.resume_from_checkpoint=false",
502+
],
503+
)
504+
505+
final_loss = main_fsdp2_cp(sanity_config)
506+
gc.collect()
507+
torch.cuda.empty_cache()
508+
509+
assert torch.isfinite(torch.tensor(final_loss)), f"Final loss {final_loss} is not finite"
510+
511+
512+
def test_sanity_nd_parallel_tp1_sequence_parallel_bshd(tmp_path, recipe_path):
513+
"""Test ND-parallel training with tensor_parallel=True, sequence_parallel=True, tp_size=1, BSHD.
514+
515+
Validates that the sequence-parallel RMSNorm (set_device_mesh on the final norm) does not
516+
break forward/backward even when the TP group is a single rank.
517+
"""
518+
with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"):
519+
sanity_config = compose(
520+
config_name="L0_sanity_tp",
521+
overrides=[
522+
f"+wandb.dir={tmp_path}",
523+
f"checkpoint.ckpt_dir={tmp_path}",
524+
"num_train_steps=10",
525+
"tp_size=1",
526+
"config_kwargs.sequence_parallel=true",
527+
"checkpoint.resume_from_checkpoint=false",
528+
],
529+
)
530+
531+
final_loss = main_fsdp2_cp(sanity_config)
532+
gc.collect()
533+
torch.cuda.empty_cache()
534+
535+
assert torch.isfinite(torch.tensor(final_loss)), f"Final loss {final_loss} is not finite"
536+
537+
483538
@requires_fp8
484539
def test_sanity_fsdp2_fp8_stats_logging(tmp_path, recipe_path):
485540
"""Test that FP8 stats logging works with FSDP2."""

bionemo-recipes/recipes/llama3_native_te/tests/test_train_two_gpu.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,111 @@ def test_multi_gpu_train_te_fsdp2_cp_thd(tmp_path, recipe_path):
238238
)
239239

240240

241+
@requires_multi_gpu
242+
def test_multi_gpu_train_te_fsdp2_tp_bshd(tmp_path, recipe_path):
243+
"""Test FSDP2 with tensor parallelism on 2 GPUs using BSHD input format.
244+
245+
Validates:
246+
- The 1-D TP device mesh (dp=1, cp=1, tp=2) is created and used correctly
247+
- Embedding weights are ColwiseParallel-sharded across 2 TP ranks
248+
- TransformerLayer TP mode shards QKV/FFN weights across ranks
249+
- Row-wise parallel LM head with hidden-state slicing before forward
250+
"""
251+
run_train_cmd(
252+
[
253+
"torchrun",
254+
"--standalone",
255+
"--nproc_per_node=2",
256+
"train_fsdp2_nd_parallel.py",
257+
"--config-name",
258+
"L0_sanity_tp",
259+
"num_train_steps=10",
260+
f"checkpoint.ckpt_dir={tmp_path}",
261+
],
262+
recipe_path,
263+
)
264+
265+
266+
@requires_multi_gpu
267+
@requires_datacenter_hardware
268+
def test_multi_gpu_train_te_fsdp2_tp_thd(tmp_path, recipe_path):
269+
"""Test FSDP2 with tensor parallelism on 2 GPUs using THD (sequence-packed) input format.
270+
271+
Validates:
272+
- TP=2, CP=1 with sequence-packing / THD attention format
273+
- _unpad_input / _pad_input round-trip works alongside TP activation sharding
274+
- padding_causal mask type is compatible with row-wise parallel LM head
275+
"""
276+
run_train_cmd(
277+
[
278+
"torchrun",
279+
"--standalone",
280+
"--nproc_per_node=2",
281+
"train_fsdp2_nd_parallel.py",
282+
"--config-name",
283+
"L0_sanity_tp",
284+
"num_train_steps=10",
285+
f"checkpoint.ckpt_dir={tmp_path}",
286+
"use_sequence_packing=true",
287+
"config_kwargs.attn_input_format=thd",
288+
"config_kwargs.self_attn_mask_type=padding_causal",
289+
],
290+
recipe_path,
291+
)
292+
293+
294+
@requires_multi_gpu
295+
def test_multi_gpu_train_te_fsdp2_tp_sequence_parallel_bshd(tmp_path, recipe_path):
296+
"""Test FSDP2 with tensor parallelism + sequence parallelism on 2 GPUs, BSHD.
297+
298+
Validates that sequence parallelism (LayerNorm activations sharded across TP ranks)
299+
works alongside standard tensor parallelism without errors.
300+
"""
301+
run_train_cmd(
302+
[
303+
"torchrun",
304+
"--standalone",
305+
"--nproc_per_node=2",
306+
"train_fsdp2_nd_parallel.py",
307+
"--config-name",
308+
"L0_sanity_tp",
309+
"num_train_steps=10",
310+
f"checkpoint.ckpt_dir={tmp_path}",
311+
"config_kwargs.sequence_parallel=true",
312+
],
313+
recipe_path,
314+
)
315+
316+
317+
@requires_multi_gpu
318+
def test_multi_gpu_train_te_fsdp2_tp_bshd_with_checkpointing(tmp_path, recipe_path):
319+
"""Test FSDP2 TP training on 2 GPUs with checkpoint saving.
320+
321+
Validates:
322+
- Sharded FSDP2 checkpoints are written correctly while TP is active
323+
- The expected checkpoint directory structure is present after training
324+
"""
325+
run_train_cmd(
326+
[
327+
"torchrun",
328+
"--standalone",
329+
"--nproc_per_node=2",
330+
"train_fsdp2_nd_parallel.py",
331+
"--config-name",
332+
"L0_sanity_tp",
333+
"num_train_steps=10",
334+
f"checkpoint.ckpt_dir={tmp_path}",
335+
"checkpoint.save_every_n_steps=5",
336+
"checkpoint.resume_from_checkpoint=false",
337+
],
338+
recipe_path,
339+
)
340+
341+
ckpt_dir = tmp_path / "train_fsdp2"
342+
assert ckpt_dir.exists(), f"Checkpoint directory not created: {ckpt_dir}"
343+
assert (ckpt_dir / "step_5").exists(), "Checkpoint at step 5 not found"
344+
345+
241346
nsys_available = subprocess.run(["which", "nsys"], check=False, capture_output=True).returncode == 0
242347

243348

bionemo-recipes/recipes/opengenome2_llama_native_te/modeling_llama_te.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -532,9 +532,10 @@ def forward(
532532
if self.config.tensor_parallel:
533533
# If using TP, shard your activation across the TP group,
534534
# to support row-wise tensor parallelism in the LM head.
535+
# Use ... to support both BSHD (3D) and THD (2D) hidden states.
535536
tp_rank = self.tp_mesh.get_local_rank()
536537
tp_stride = hidden_states.shape[-1] // self.config.tp_size
537-
hidden_states = hidden_states[:, :, tp_rank * tp_stride : (tp_rank + 1) * tp_stride]
538+
hidden_states = hidden_states[..., tp_rank * tp_stride : (tp_rank + 1) * tp_stride]
538539

539540
with transformer_engine.pytorch.autocast(enabled=False):
540541
if hidden_states.ndim == 3:

0 commit comments

Comments
 (0)