Commit 295ed82

feat: FSDP2 with weight prefetching and async TP optimization

Signed-off-by: Zhiyu Li <zhiyul@NVIDIA.com>
1 parent b9a2154

35 files changed: 1,237 additions and 1,884 deletions

examples/llm_finetune/llama3_3/custom_llama3_3_70b_instruct_peft_benchmark.yaml

Lines changed: 8 additions & 14 deletions

@@ -8,7 +8,6 @@ recipe: TrainFinetuneRecipeForNextTokenPrediction
 
 seed: 42
 
-# NEW: Add benchmark section
 benchmark:
   warmup_steps: 5
   peak_tflops: 989 # H100: 989, A100: 312
@@ -19,7 +18,7 @@ benchmark:
 
 step_scheduler:
   global_batch_size: 32
-  local_batch_size: 4
+  local_batch_size: 2
   ckpt_every_steps: 50
   val_every_steps: 1000
   max_steps: 10
@@ -53,21 +52,17 @@ checkpoint:
 
 distributed:
   strategy: fsdp2
-  dp_size: none
+  dp_size: null
   tp_size: 2
   cp_size: 1
-  pp_size: 4
 
-  sequence_parallel: false
+  sequence_parallel: true
   activation_checkpointing: true
-
-  pipeline:
-    pp_schedule: interleaved1f1b
-    pp_microbatch_size: 1
-    layers_per_stage: 2
-    scale_grads_in_schedule: false
-    round_virtual_stages_to_pp_multiple: up
-    dtype: bf16
+  enable_async_tensor_parallel: true
+  enable_fsdp2_prefetch: true
+  enable_compile: true
+  defer_fsdp_grad_sync: false
+  patch_is_packed_sequence: true # Patch transformers._is_packed_sequence to always return False: removes CPU-GPU sync per attention layer and ensures static shapes for torch.compile. Safe for non-packed (standard) training only.
 
 loss_fn:
   _target_: nemo_automodel.components.loss.masked_ce.MaskedCrossEntropy
@@ -82,7 +77,6 @@ dataset:
 dataloader:
   _target_: torch.utils.data.DataLoader
   batch_size: null # Dataset already yields batches
-  # Note: model_config will be auto-injected by train_ft.py for PP models
 
 optimizer:
   _target_: torch.optim.Adam
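
The patch_is_packed_sequence comment above describes monkeypatching transformers' packed-sequence probe so it never touches device data. A minimal sketch of the idea, assuming the helper lives in transformers.masking_utils (the module path and patch site are assumptions, not this repo's code):

# Hypothetical illustration of the patch_is_packed_sequence idea.
import transformers.masking_utils as masking_utils

def _always_unpacked(*args, **kwargs) -> bool:
    # The stock check inspects position_ids on the GPU, forcing a device->host
    # sync in every attention layer and producing data-dependent control flow
    # that breaks torch.compile's static-shape assumptions. For standard
    # (non-packed) batches it always evaluates to False, so short-circuit it.
    return False

masking_utils._is_packed_sequence = _always_unpacked  # only safe for non-packed data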

examples/llm_finetune/llama3_3/custom_llama3_3_70b_instruct_peft_benchmark_2nodes.yaml

Lines changed: 10 additions & 14 deletions

@@ -19,14 +19,14 @@ benchmark:
 
 step_scheduler:
   global_batch_size: 32
-  local_batch_size: 8
+  local_batch_size: 2
   ckpt_every_steps: 50
   val_every_steps: 1000
   max_steps: 10
 
 dist_env:
   backend: nccl
-  timeout_minutes: 1
+  timeout_minutes: 10
 
 rng:
   _target_: nemo_automodel.components.training.rng.StatefulRNG
@@ -53,21 +53,18 @@ checkpoint:
 
 distributed:
   strategy: fsdp2
-  dp_size: 2
+  dp_size: null
   tp_size: 2
   cp_size: 1
-  pp_size: 4
+  pp_size: 1
 
-  sequence_parallel: false
+  sequence_parallel: true
   activation_checkpointing: true
-
-  pipeline:
-    pp_schedule: interleaved1f1b
-    pp_microbatch_size: 1
-    layers_per_stage: 2
-    scale_grads_in_schedule: false
-    round_virtual_stages_to_pp_multiple: up
-    dtype: bf16
+  enable_async_tensor_parallel: true
+  enable_fsdp2_prefetch: true
+  enable_compile: true
+  defer_fsdp_grad_sync: false
+  patch_is_packed_sequence: true # Patch transformers._is_packed_sequence to always return False: removes CPU-GPU sync per attention layer and ensures static shapes for torch.compile. Safe for non-packed (standard) training only.
 
 loss_fn:
   _target_: nemo_automodel.components.loss.masked_ce.MaskedCrossEntropy
@@ -82,7 +79,6 @@ dataset:
 dataloader:
   _target_: torch.utils.data.DataLoader
   batch_size: null # Dataset already yields batches
-  # Note: model_config will be auto-injected by train_ft.py for PP models
 
 optimizer:
   _target_: torch.optim.Adam
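
With dp_size: null, the data-parallel degree is derived from the world size and the remaining parallel dimensions. A minimal sketch of that arithmetic (helper name is illustrative, not the recipe's code): for this 2-node config, 16 ranks with tp=2, cp=1, pp=1 gives dp=8.

# Illustrative only: how dp_size falls out when it is left as null in the config.
def derive_dp_size(world_size: int, tp_size: int, cp_size: int, pp_size: int) -> int:
    denom = tp_size * cp_size * pp_size
    assert world_size % denom == 0, "world size must be divisible by tp*cp*pp"
    return world_size // denom

assert derive_dp_size(world_size=16, tp_size=2, cp_size=1, pp_size=1) == 8  # 2 nodes x 8 GPUs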

examples/llm_finetune/qwen/custom_qwen2_5_32b_peft_benchmark.yaml

Lines changed: 9 additions & 14 deletions

@@ -19,14 +19,14 @@ benchmark:
 
 step_scheduler:
   global_batch_size: 32
-  local_batch_size: 8
+  local_batch_size: 2
   ckpt_every_steps: 50
   val_every_steps: 1000
   max_steps: 10
 
 dist_env:
   backend: nccl
-  timeout_minutes: 1
+  timeout_minutes: 10
 
 rng:
   _target_: nemo_automodel.components.training.rng.StatefulRNG
@@ -53,21 +53,17 @@ checkpoint:
 
 distributed:
   strategy: fsdp2
-  dp_size: none
-  tp_size: 1
+  dp_size: null
+  tp_size: 2
   cp_size: 1
-  pp_size: 4
+  pp_size: 1
 
   sequence_parallel: false
   activation_checkpointing: true
-
-  pipeline:
-    pp_schedule: interleaved1f1b
-    pp_microbatch_size: 1
-    layers_per_stage: 2
-    scale_grads_in_schedule: false
-    round_virtual_stages_to_pp_multiple: up
-    dtype: bf16
+  enable_async_tensor_parallel: false
+  enable_fsdp2_prefetch: true
+  enable_compile: true
+  patch_is_packed_sequence: true # Patch transformers._is_packed_sequence to always return False: removes CPU-GPU sync per attention layer and ensures static shapes for torch.compile. Safe for non-packed (standard) training only.
 
 loss_fn:
   _target_: nemo_automodel.components.loss.masked_ce.MaskedCrossEntropy
@@ -82,7 +78,6 @@ dataset:
 dataloader:
   _target_: torch.utils.data.DataLoader
   batch_size: null # Dataset already yields batches
-  # Note: model_config will be auto-injected by train_ft.py for PP models
 
 optimizer:
   _target_: torch.optim.Adam
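
enable_fsdp2_prefetch corresponds to the commit's explicit weight-prefetching path. A hedged sketch of how explicit FSDP2 prefetching is typically wired with plain PyTorch APIs; the model.model.layers attribute path and a prefetch depth of one block are assumptions, and the recipe's own implementation may differ. FSDP2 already prefetches the next all-gather implicitly, so the explicit form mainly matters when prefetching further ahead or across non-sequential modules.

# Sketch of explicit FSDP2 weight prefetching, assuming a HF-style decoder stack.
from torch.distributed.fsdp import fully_shard

for block in model.model.layers:
    fully_shard(block)          # each transformer block becomes its own FSDP2 unit
fully_shard(model)              # root wrapper for the remaining parameters

blocks = list(model.model.layers)
for i, block in enumerate(blocks):
    # Kick off the next block's all-gather while the current block computes,
    # and the previous block's all-gather during backward.
    if i + 1 < len(blocks):
        block.set_modules_to_forward_prefetch([blocks[i + 1]])
    if i > 0:
        block.set_modules_to_backward_prefetch([blocks[i - 1]])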

examples/llm_finetune/qwen/qwen2_5_32b_peft_benchmark_2nodes.yaml renamed to examples/llm_finetune/qwen/custom_qwen2_5_32b_peft_benchmark_2nodes.yaml

Lines changed: 12 additions & 16 deletions

@@ -8,8 +8,8 @@
 
 recipe: TrainFinetuneRecipeForNextTokenPrediction
 
-seed: 42
-
+seed: 42
+
 # Benchmark section
 benchmark:
   warmup_steps: 5
@@ -21,14 +21,14 @@ benchmark:
 
 step_scheduler:
   global_batch_size: 32
-  local_batch_size: 8
+  local_batch_size: 2
   ckpt_every_steps: 50
   val_every_steps: 1000
   max_steps: 10
 
 dist_env:
   backend: nccl
-  timeout_minutes: 1
+  timeout_minutes: 10
 
 rng:
   _target_: nemo_automodel.components.training.rng.StatefulRNG
@@ -55,21 +55,18 @@ checkpoint:
 
 distributed:
   strategy: fsdp2
-  dp_size: 4
-  tp_size: 1
+  dp_size: null
+  tp_size: 2
   cp_size: 1
-  pp_size: 4
+  pp_size: 1
 
   sequence_parallel: false
   activation_checkpointing: true
-
-  pipeline:
-    pp_schedule: interleaved1f1b
-    pp_microbatch_size: 1
-    layers_per_stage: 2
-    scale_grads_in_schedule: false
-    round_virtual_stages_to_pp_multiple: up
-    dtype: bf16
+  enable_async_tensor_parallel: false
+  enable_fsdp2_prefetch: true
+  enable_compile: true
+  defer_fsdp_grad_sync: false
+  patch_is_packed_sequence: true # Patch transformers._is_packed_sequence to always return False: removes CPU-GPU sync per attention layer and ensures static shapes for torch.compile. Safe for non-packed (standard) training only.
 
 loss_fn:
   _target_: nemo_automodel.components.loss.masked_ce.MaskedCrossEntropy
@@ -84,7 +81,6 @@ dataset:
 dataloader:
   _target_: torch.utils.data.DataLoader
   batch_size: null # Dataset already yields batches
-  # Note: model_config will be auto-injected by train_ft.py for PP models
 
 optimizer:
   _target_: torch.optim.Adam
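
enable_async_tensor_parallel stays false in the Qwen configs but is turned on in the Llama 3.3 70B configs above. For background, a hedged sketch of the standard PyTorch recipe for async (micro-pipelined) tensor parallelism, as used by torchtitan; whether this repo enables it through exactly these knobs is an assumption, and both are private/experimental APIs.

# Assumed enablement path for async TP; the TP-parallelized blocks must run under
# torch.compile for the pass to apply (consistent with enable_compile: true above).
import torch
from torch.distributed._symmetric_memory import enable_symm_mem_for_group

def enable_async_tp(tp_mesh) -> None:
    # tp_mesh: the tensor-parallel DeviceMesh (assumed to be available from the FSDP2 manager).
    torch._inductor.config._micro_pipeline_tp = True            # let inductor pipeline TP collectives with matmuls
    enable_symm_mem_for_group(tp_mesh.get_group().group_name)   # back the collectives with symmetric memory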
examples/llm_pretrain/custom_llama3_1_70b_pretrain_benchmark_8nodes.yaml

Lines changed: 93 additions & 0 deletions

@@ -0,0 +1,93 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# To run (from the repository root):
+# Intended layout: 64 ranks (8 nodes x 8 GPUs; matches tp=1, cp=2, auto dp=32 in distributed.*).
+# The interactive `automodel` CLI only launches a single node. For 8 nodes use torchrun or your
+# cluster's multi-process wrapper, e.g.:
+#   torchrun --nnodes=8 --nproc_per_node=8 --node_rank=<rank> --master_addr=<host> --master_port=<port> \
+#     -m nemo_automodel.cli.app \
+#     examples/llm_pretrain/custom_llama3_1_70b_pretrain_benchmark_8nodes.yaml
+
+recipe: PretrainRecipeForNextTokenPrediction
+
+seed: 42
+
+benchmark:
+  warmup_steps: 5
+  peak_tflops: 989 # H100: 989, A100: 312
+  nsys_start: -1
+  nsys_end: -1
+  nsys_ranks: []
+  num_nodes: 8
+
+step_scheduler:
+  global_batch_size: 128
+  local_batch_size: 1 # dp=32: grad_acc=4 steps per rank
+  ckpt_every_steps: 2000
+  num_epochs: 1
+  max_steps: 10
+
+dist_env:
+  backend: nccl
+  timeout_minutes: 60
+
+model:
+  _target_: nemo_automodel.NeMoAutoModelForCausalLM.from_pretrained
+  pretrained_model_name_or_path: meta-llama/Llama-3.1-70B
+  torch_dtype: bf16
+  trust_remote_code: True
+  backend:
+    _target_: nemo_automodel.components.models.common.BackendConfig
+    rms_norm: torch_fp32
+
+checkpoint:
+  enabled: False
+
+dataset:
+  _target_: nemo_automodel.components.datasets.llm.mock_iterable_dataset.MockIterableDataset
+  vocab_size: 100
+  seq_len: 8192
+  num_samples: 1000000
+  batch_size: 1 # Must match step_scheduler.local_batch_size
+
+dataloader:
+  _target_: torch.utils.data.DataLoader
+  batch_size: null
+
+loss_fn:
+  _target_: nemo_automodel.components.loss.masked_ce.MaskedCrossEntropy
+
+optimizer:
+  _target_: transformer_engine.pytorch.optimizers.FusedAdam
+  lr: 0.0002
+  betas: [0.9, 0.95]
+  weight_decay: 0.1
+  adam_w_mode: True
+
+distributed:
+  strategy: fsdp2
+  dp_size: null # auto: 64 / (tp1 * cp2 * pp1) = 32
+  tp_size: 1
+  cp_size: 2
+  pp_size: 1
+  sequence_parallel: False
+  activation_checkpointing: True
+  enable_async_tensor_parallel: False
+  enable_fsdp2_prefetch: True
+  enable_compile: True
+  patch_is_packed_sequence: True # Patch transformers._is_packed_sequence to always return False: removes CPU-GPU sync per attention layer and ensures static shapes for torch.compile. Safe for non-packed (standard) training only.
+  defer_fsdp_grad_sync: False # Must be False with GA>1: True → delayed resharding → OOM
+  defer_rs_grad_accum: True # GA=4: replaces 4× RS with 1× AllReduce (saves 2× RS bandwidth)
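
The last two flags concern how gradient accumulation (GA=4 here: 128 / (1 × 32)) interacts with FSDP2's backward collectives. The loop below is an illustration with plain FSDP2 APIs of deferring gradient reduction to the final micro-batch, which is the trade-off defer_rs_grad_accum names; the memory caveat echoes the defer_fsdp_grad_sync comment above, while the loader structure and HF-style loss access are assumptions, not this recipe's code.

# Illustration only: defer gradient reduction until the last micro-batch of a step.
# Assumes `model` is the fully_shard()-wrapped root module and `train_loader`
# yields a list of micro-batch dicts per optimizer step.
num_microbatches = 4  # global_batch_size 128 / (local_batch_size 1 * dp 32)

for micro_batches in train_loader:
    for i, batch in enumerate(micro_batches):
        is_last = i == num_microbatches - 1
        # False -> skip the per-micro-batch reduce-scatter; gradients accumulate
        # locally (in unsharded form, hence the memory cost) and are reduced
        # once on the final backward of the step.
        model.set_requires_gradient_sync(is_last)
        loss = model(**batch).loss / num_microbatches
        loss.backward()
    optimizer.step()
    optimizer.zero_grad()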

nemo_automodel/_transformers/infrastructure.py

Lines changed: 12 additions & 0 deletions

@@ -422,6 +422,7 @@ def apply_model_infrastructure(
         autopipeline=autopipeline,
         tp_size=mesh.tp_size,
         ep_size=mesh.ep_size,
+        dp_shard_size=mesh.dp_shard_size,
         pretrained_model_name_or_path=pretrained_model_name_or_path,
         load_base_model=load_base_model,
         peft_config=peft_config,
@@ -513,6 +514,17 @@ def apply_model_infrastructure(
         load_base_model=load_base_model,
     )
 
+    # Apply per-layer torch.compile after checkpoint loading so that the _orig_mod key prefix
+    # introduced by torch.compile doesn't conflict with HF checkpoint key names.
+    if isinstance(model_wrapper, FSDP2Manager) and (
+        model_wrapper.enable_compile or model_wrapper.enable_async_tensor_parallel
+    ):
+        from nemo_automodel.components.distributed.parallelizer import _apply_per_layer_compile
+
+        model_parts = model.parts if hasattr(model, "parts") else [model]
+        for mp in model_parts:
+            _apply_per_layer_compile(mp)
+
     # Freeze parameters after checkpoint loading and parallelization
     # This catches params created during parallelization (e.g., GroupedExpertsTE in init_token_dispatcher)
     if peft_config is not None:
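
The hunk above calls _apply_per_layer_compile, whose body is not part of this diff. A hypothetical sketch of what a per-layer compile helper can look like, assuming a HF-style model.model.layers stack: compiling blocks individually keeps graph breaks at layer boundaries and confines the _orig_mod.* state-dict prefix to per-layer submodules, which is why it has to run after the HF checkpoint load.

# Hypothetical stand-in for _apply_per_layer_compile; the real helper in
# nemo_automodel.components.distributed.parallelizer may differ.
import torch
import torch.nn as nn

def apply_per_layer_compile(model: nn.Module) -> None:
    layers = getattr(getattr(model, "model", model), "layers", None)  # HF-style decoder stack (assumption)
    if layers is None:
        return
    for idx in range(len(layers)):
        layers[idx] = torch.compile(layers[idx])  # nn.ModuleList supports in-place item replacement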

nemo_automodel/_transformers/utils.py

Lines changed: 7 additions & 5 deletions

@@ -26,14 +26,15 @@ def _should_load_before_shard(
     autopipeline: Optional[object],
     tp_size: int,
     ep_size: int,
+    dp_shard_size: int = 1,
     pretrained_model_name_or_path: str,
     load_base_model: bool,
     peft_config: Optional[object],
 ) -> bool:
     """Decide whether to load the checkpoint before FSDP/TP/EP sharding.
 
-    Load-before-shard is only safe when running single-GPU (no PP, TP, or EP)
-    and a checkpoint actually needs loading.
+    Load-before-shard is only safe when running single-GPU (no PP, TP, EP, or
+    DP sharding) and a checkpoint actually needs loading.
     With any model parallelism the post-shard load path must be used to avoid
     NCCL collective mismatches or key/device inconsistencies.
 
@@ -43,12 +44,13 @@ def _should_load_before_shard(
     no_pp = autopipeline is None
     no_tp = tp_size <= 1
     no_ep = ep_size <= 1
+    no_dp_shard = dp_shard_size <= 1
     no_peft = peft_config is None
     need_checkpoint_load = bool(pretrained_model_name_or_path and load_base_model)
-    result = no_pp and no_tp and no_ep and no_peft and need_checkpoint_load
+    result = no_pp and no_tp and no_ep and no_dp_shard and no_peft and need_checkpoint_load
     logger.debug(
-        "[_should_load_before_shard] no_pp={} no_tp={} no_ep={} no_peft={} need_load={} -> {}".format(
-            no_pp, no_tp, no_ep, no_peft, need_checkpoint_load, result
+        "[_should_load_before_shard] no_pp={} no_tp={} no_ep={} no_dp_shard={} no_peft={} need_load={} -> {}".format(
+            no_pp, no_tp, no_ep, no_dp_shard, no_peft, need_checkpoint_load, result
        )
     )
     return result
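
A usage illustration of the new gate (argument values are made up; the outcome follows directly from the predicate above): any FSDP shard degree above one now forces the post-shard load path, matching the dp_shard_size plumbing added in infrastructure.py.

# Illustration only: with FSDP2 sharding over 8 ranks, pre-shard loading is skipped.
early_load = _should_load_before_shard(
    autopipeline=None,                    # no pipeline parallelism
    tp_size=1,
    ep_size=1,
    dp_shard_size=8,                      # new gate added in this commit
    pretrained_model_name_or_path="meta-llama/Llama-3.1-70B",
    load_base_model=True,
    peft_config=None,
)
assert early_load is False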
