Skip to content

Commit 5fbb559

Browse files
huvunvidia (Huy Vu2) and co-author authored
Enable nemo-ci tests (short runs - perf and non-perf) for Wan + Updating recipes names (#3179)
Co-authored-by: Huy Vu2 <huvu@login-eos02.eos.clusters.nvidia.com>
1 parent ad27e2c commit 5fbb559

19 files changed

Lines changed: 319 additions & 211 deletions

File tree

examples/diffusion/recipes/wan/README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -146,15 +146,15 @@ WAN uses different flow-matching hyperparameters for pretraining vs fine-tuning.
146146

147147
```bash
148148
uv run torchrun --nproc_per_node=8 scripts/training/run_recipe.py \
149-
--recipe wan_1_3B_pretrain_config \
149+
--recipe wan_1_3b_pretrain_config \
150150
--step_func wan_step
151151
```
152152

153153
### WAN 1.3B — Real data (WebDataset path):
154154

155155
```bash
156156
uv run torchrun --nproc_per_node=8 scripts/training/run_recipe.py \
157-
--recipe wan_1_3B_pretrain_config \
157+
--recipe wan_1_3b_pretrain_config \
158158
--step_func wan_step \
159159
dataset.path=${WORKSPACE}/datasets/wan
160160
```
@@ -163,7 +163,7 @@ uv run torchrun --nproc_per_node=8 scripts/training/run_recipe.py \
163163

164164
```bash
165165
uv run torchrun --nproc_per_node=$NUM_GPUS scripts/training/run_recipe.py \
166-
--recipe wan_1_3B_pretrain_config \
166+
--recipe wan_1_3b_pretrain_config \
167167
--step_func wan_step \
168168
dataset.path=${WORKSPACE}/datasets/wan \
169169
train.global_batch_size=8 \

examples/diffusion/recipes/wan/conf/gb200_perf_pretrain_mock.yaml

Lines changed: 0 additions & 33 deletions
This file was deleted.

examples/diffusion/recipes/wan/conf/gb300_perf_pretrain_mock.yaml

Lines changed: 0 additions & 33 deletions
This file was deleted.

examples/diffusion/recipes/wan/conf/h100_perf_pretrain_mock.yaml

Lines changed: 0 additions & 37 deletions
This file was deleted.

examples/diffusion/recipes/wan/prepare_dataset/openvid1M_dataset/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ CHECKPOINT_DIR=<path/to/save/checkpoints>
7373
EXP_NAME=<experiment_name>
7474

7575
NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=8 scripts/training/run_recipe.py \
76-
--recipe wan_1_3B_pretrain_config \
76+
--recipe wan_1_3b_pretrain_config \
7777
--step_func wan_step \
7878
model.tensor_model_parallel_size=1 \
7979
model.pipeline_model_parallel_size=1 \

scripts/performance/argument_parser.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ def parse_cli_args():
170170
parser.add_argument(
171171
"--domain",
172172
type=lower_str,
173-
choices=["llm", "vlm", "qwen3vl"],
173+
choices=["llm", "vlm", "qwen3vl", "diffusion"],
174174
help="Domain to use for experiment.",
175175
default="llm",
176176
)
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# Gate every export on megatron.bridge being importable: both submodule
# imports below live under the HAVE_MEGATRON_BRIDGE guard, so the names
# they provide must only be advertised in ``__all__`` under the same
# guard.  (The original listed the WAN_14B_PRETRAIN_CONFIG_* constants in
# ``__all__`` unconditionally, which made ``from package import *`` raise
# AttributeError whenever megatron.bridge was not installed.)
try:
    import megatron.bridge  # noqa: F401

    HAVE_MEGATRON_BRIDGE = True
except ModuleNotFoundError:
    # megatron.bridge is an optional dependency; degrade to an empty API.
    HAVE_MEGATRON_BRIDGE = False

__all__: list[str] = []

if HAVE_MEGATRON_BRIDGE:
    from .wan_diffusion_pretrain import (
        wan_14b_pretrain_config_gb200,
        wan_14b_pretrain_config_h100,
    )
    from .wan_workload_base_configs import (
        WAN_14B_PRETRAIN_CONFIG_GB200_BF16_V1,
        WAN_14B_PRETRAIN_CONFIG_H100_BF16_V1,
    )

    # Same final contents/order as before when the bridge is present:
    # preset constants first, recipe builder functions second.
    __all__.extend(
        [
            "WAN_14B_PRETRAIN_CONFIG_GB200_BF16_V1",
            "WAN_14B_PRETRAIN_CONFIG_H100_BF16_V1",
            "wan_14b_pretrain_config_gb200",
            "wan_14b_pretrain_config_h100",
        ]
    )
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import logging
16+
17+
from utils.overrides import set_workload_base_configs
18+
from utils.utils import get_workload_base_config
19+
20+
from megatron.bridge.diffusion.recipes.wan.wan import wan_14b_pretrain_config
21+
from megatron.bridge.training.config import ConfigContainer
22+
23+
24+
logger = logging.getLogger(__name__)
25+
26+
27+
# Wan 14B pretrain configs ---------------------------------------------------
28+
29+
30+
def wan_14b_pretrain_config_gb200(
    precision: str = "bf16", mock: bool = True, config_variant: str = "v1"
) -> ConfigContainer:
    """Build the GB200 Wan 14B pretrain recipe (TP=1, CP=4, GBS=64).

    Args:
        precision: Compute dtype name; upper-cased when selecting the
            workload base config.
        mock: Present for signature parity with other recipe builders;
            not referenced in this function body.
        config_variant: Workload base-config variant tag (e.g. "v1").

    Returns:
        The Wan 14B pretrain ``ConfigContainer`` with the GB200 workload
        base settings applied.
    """
    # Fetch the per-GPU performance preset for this model/task combination.
    workload_overrides = get_workload_base_config(
        model_family_name="wan",
        model_recipe_name="wan_14b",
        gpu="gb200",
        compute_dtype=precision.upper(),
        task="pretrain",
        config_variant=config_variant,
    )
    # Start from the stock recipe, then overlay the preset values.
    recipe = wan_14b_pretrain_config()
    set_workload_base_configs(recipe, workload_overrides)
    return recipe
45+
46+
47+
def wan_14b_pretrain_config_h100(
    precision: str = "bf16", mock: bool = True, config_variant: str = "v1"
) -> ConfigContainer:
    """H100, Wan 14B pretrain: TP=2, CP=4, activation recompute (8 layers).

    NOTE(review): the original summary said "GBS=128", but the companion
    WAN_14B_PRETRAIN_CONFIG_H100_BF16_V1 preset does not override the base
    global_batch_size of 64 — confirm which batch size is intended.

    Args:
        precision: Compute dtype name; upper-cased when looking up the
            workload base config.
        mock: Accepted for signature parity with other recipe builders;
            not referenced in this function body.
        config_variant: Workload base-config variant tag (e.g. "v1").

    Returns:
        ConfigContainer: the Wan 14B pretrain recipe after the H100
        workload base settings have been applied to it.
    """
    # Look up the H100 performance preset for this model/task combination.
    base_cfg = get_workload_base_config(
        model_family_name="wan",
        model_recipe_name="wan_14b",
        gpu="h100",
        compute_dtype=precision.upper(),
        task="pretrain",
        config_variant=config_variant,
    )
    # Build the stock recipe, then overlay the preset onto it in place.
    cfg = wan_14b_pretrain_config()
    set_workload_base_configs(cfg, base_cfg)
    return cfg
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Parallelism presets for Wan 14B performance configs.
16+
17+
Config naming convention:
18+
{MODEL}_{SIZE}_{TASK}_CONFIG_{GPU}_{PRECISION}_{VERSION}
19+
20+
All configs use bf16 precision (diffusion training does not use fp8).
21+
Parallelism settings are sourced from the per-GPU YAML perf configs in
22+
examples/diffusion/recipes/wan/conf/.
23+
"""
24+
25+
from dataclasses import replace
26+
27+
from utils.utils import WorkloadBaseConfig
28+
29+
30+
# Shared defaults for every Wan 14B preset; GPU-specific presets below
# override only the fields that differ via dataclasses.replace, so the
# batch sizes here (GBS=64, MBS=1) carry through unless overridden.
BASE_WAN_14B_CONFIG = WorkloadBaseConfig(
    num_gpus=8,  # placeholder scale; each preset sets its real GPU count
    global_batch_size=64,
    micro_batch_size=1,
)

# =============================================================================
# Wan 14B pretrain presets
# =============================================================================

# GB200: 16 GPUs (4 nodes), TP=1, CP=4 — implied DP=4 at GBS=64 (inherited
# from BASE_WAN_14B_CONFIG).
WAN_14B_PRETRAIN_CONFIG_GB200_BF16_V1 = replace(
    BASE_WAN_14B_CONFIG,
    num_gpus=16,
    tensor_model_parallel_size=1,
    context_parallel_size=4,
)

# H100: 32 GPUs (4 nodes), TP=2, CP=4 — implied DP=4 at GBS=64 (inherited
# from BASE_WAN_14B_CONFIG); recompute_num_layers=8 enables activation
# recompute to fit the 14B model in H100 memory.
WAN_14B_PRETRAIN_CONFIG_H100_BF16_V1 = replace(
    BASE_WAN_14B_CONFIG,
    num_gpus=32,
    tensor_model_parallel_size=2,
    context_parallel_size=4,
    recompute_num_layers=8,
)

0 commit comments

Comments
 (0)