Skip to content

Commit bb62f94

Browse files
trvachovclaude
andcommitted
Add FSDP2 + Expert Parallelism to Mixtral training recipes
- Add 2D device mesh (dp, ep) to both mixtral_native_te and opengenome2_mixtral_native_te train_fsdp2.py scripts - Call model.model.set_ep_groups() before fully_shard() to convert expert weights to DTensors with Shard(0) on the ep dimension - Pass expert_parallel_size to NVMixtralConfig so num_local_experts is set correctly per rank (num_experts // ep_size) - Add clip_grad_norm_ep_aware() helper that handles DTensor expert weight gradients on different device meshes (avoids aten.stack mesh mismatch error in torch.nn.utils.clip_grad_norm_) - Add expert_parallel_size config field to both defaults.yaml files - Update L0_sanity.yaml in both recipes for EP=2 on 2-GPU setup, W&B run names agent1-opengenome2 and agent1-lingua, project swarm-mixtral-development Validated on 2x H100 (CUDA_VISIBLE_DEVICES=2,3): - OG2: loss 5.37→1.22, W&B run agent1-opengenome2 - Lingua: loss 11.8→8.5, W&B run agent1-lingua Co-Authored-By: Claude Sonnet 4.6 (1M context) <noreply@anthropic.com>
1 parent 7cc1e78 commit bb62f94

6 files changed

Lines changed: 152 additions & 13 deletions

File tree

bionemo-recipes/recipes/mixtral_native_te/hydra_config/L0_sanity.yaml

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,14 @@ config_kwargs:
1717
self_attn_mask_type: "causal"
1818
router_jitter_noise: 0.0
1919

20-
num_train_steps: 250
20+
num_train_steps: 20
2121

2222
use_torch_compile: false
23-
use_meta_device: true
23+
use_meta_device: false # small model fits on device directly; avoids meta-device complexity with EP
24+
25+
# Expert parallelism: EP=2 on 2-GPU setup (dp=1, ep=2).
26+
# num_local_experts (4) must be divisible by expert_parallel_size (2): 4/2=2 experts/rank.
27+
expert_parallel_size: 2
2428

2529
dataset:
2630
tokenizer_name_or_path: nvidia/Llama-3.1-8B-Instruct-FP8
@@ -36,8 +40,8 @@ dataset:
3640
streaming: true
3741

3842
wandb:
39-
name: "mixtral_8x1B_sanity"
40-
mode: "offline"
43+
name: "agent1-lingua"
44+
project: "swarm-mixtral-development"
4145

4246
lr_scheduler_kwargs:
4347
num_warmup_steps: 10

bionemo-recipes/recipes/mixtral_native_te/hydra_config/defaults.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@ use_meta_device: true
1010
use_torch_compile: false
1111
use_sequence_packing: false
1212

13+
# Expert parallelism: number of GPUs per expert-parallel group.
14+
# Must divide world_size evenly. Set > 1 to enable MoE expert parallelism.
15+
expert_parallel_size: 1
16+
1317
dataset:
1418
tokenizer_name_or_path: ???
1519
micro_batch_size: 2

bionemo-recipes/recipes/mixtral_native_te/train_fsdp2.py

Lines changed: 66 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import logging
2020
from contextlib import nullcontext
2121
from pathlib import Path
22+
from typing import Iterable
2223

2324
import hydra
2425
import nvdlfw_inspect.api as debug_api
@@ -50,6 +51,48 @@
5051
logger.setLevel(logging.INFO)
5152

5253

54+
def clip_grad_norm_ep_aware(params: Iterable[torch.nn.Parameter], max_norm: float, ep_size: int) -> torch.Tensor:
55+
"""Clip gradient norms, handling expert parallelism (DTensor parameters on different meshes).
56+
57+
When ep_size > 1, parameters may be DTensors on different device meshes (dp vs ep),
58+
which prevents torch.nn.utils.clip_grad_norm_ from stacking norms across them.
59+
This function computes norms per-parameter from local shards and clips accordingly.
60+
61+
Args:
62+
params: Model parameters (may include DTensor expert weights).
63+
ep_size: Expert parallelism size. If 1, falls back to standard clip_grad_norm_.
64+
max_norm: Maximum gradient norm.
65+
66+
Returns:
67+
Total gradient norm (approximate for ep_size > 1).
68+
"""
69+
if ep_size == 1:
70+
return torch.nn.utils.clip_grad_norm_(params, max_norm=max_norm)
71+
72+
# Compute per-param local norms, handling DTensor by extracting the local shard.
73+
param_list = list(params)
74+
norms = []
75+
for p in param_list:
76+
if p.grad is None:
77+
continue
78+
g = p.grad.detach()
79+
if hasattr(g, "to_local"):
80+
g = g.to_local() # Extract local shard of DTensor gradient
81+
norms.append(g.float().norm())
82+
83+
if not norms:
84+
return torch.tensor(0.0)
85+
86+
total_norm = torch.stack(norms).norm()
87+
clip_coef = max_norm / (total_norm + 1e-6)
88+
clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
89+
for p in param_list:
90+
if p.grad is not None:
91+
p.grad.detach().mul_(clip_coef_clamped.to(p.grad.device))
92+
93+
return total_norm
94+
95+
5396
@hydra.main(config_path="hydra_config", config_name="L0_sanity", version_base="1.2")
5497
def main(args: DictConfig) -> float | None:
5598
"""Train Mixtral with TE layers using FSDP2."""
@@ -62,7 +105,13 @@ def main(args: DictConfig) -> float | None:
62105
if args.fp8_stats_config.enabled:
63106
initialize_fp8_debugging(dist_config, **args.fp8_stats_config, fp8_enabled=args.fp8_config.enabled)
64107

65-
device_mesh = init_device_mesh("cuda", mesh_shape=(dist_config.world_size,), mesh_dim_names=("dp",))
108+
ep_size = args.expert_parallel_size
109+
if dist_config.world_size % ep_size != 0:
110+
raise ValueError(
111+
f"world_size ({dist_config.world_size}) must be divisible by expert_parallel_size ({ep_size})"
112+
)
113+
dp_size = dist_config.world_size // ep_size
114+
device_mesh = init_device_mesh("cuda", mesh_shape=(dp_size, ep_size), mesh_dim_names=("dp", "ep"))
66115

67116
fp8_recipe = None
68117
if args.fp8_config.enabled:
@@ -75,7 +124,14 @@ def main(args: DictConfig) -> float | None:
75124
fp4_recipe = hydra.utils.get_class(args.fp4_config.fp4_recipe)(**args.fp4_config.fp4_recipe_kwargs)
76125

77126
if args.use_te:
78-
config = NVMixtralConfig.from_pretrained(args.config_name_or_path, dtype=torch.bfloat16, **args.config_kwargs)
127+
# Pass expert_parallel_size to config so the model initializes with the correct
128+
# num_local_experts = num_experts // expert_parallel_size per rank.
129+
config = NVMixtralConfig.from_pretrained(
130+
args.config_name_or_path,
131+
dtype=torch.bfloat16,
132+
expert_parallel_size=ep_size,
133+
**args.config_kwargs,
134+
)
79135
with torch.device("meta") if args.use_meta_device else nullcontext():
80136
model = NVMixtralForCausalLM(config, fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe)
81137
else:
@@ -85,6 +141,13 @@ def main(args: DictConfig) -> float | None:
85141

86142
logger.info("Initialized Model:\n%s", model)
87143

144+
# Expert parallelism setup — MUST happen before fully_shard()
145+
# Wraps expert weights as DTensors with Shard(0) on the expert dimension.
146+
if args.use_te and ep_size > 1:
147+
ep_mesh = device_mesh["ep"]
148+
ep_group = ep_mesh.get_group()
149+
model.model.set_ep_groups(ep_group, ep_mesh)
150+
88151
for layer in model.model.layers:
89152
fully_shard(layer, mesh=device_mesh["dp"])
90153
fully_shard(model, mesh=device_mesh["dp"])
@@ -152,7 +215,7 @@ def main(args: DictConfig) -> float | None:
152215
if micro_step % args.grad_acc_steps == 0:
153216
micro_step = 0
154217

155-
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
218+
total_norm = clip_grad_norm_ep_aware(model.parameters(), max_norm=1.0, ep_size=ep_size)
156219

157220
optimizer.step()
158221
scheduler.step()

bionemo-recipes/recipes/opengenome2_mixtral_native_te/hydra_config/L0_sanity.yaml

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,16 @@ config_kwargs:
1717
self_attn_mask_type: causal
1818
router_jitter_noise: 0.0
1919

20-
num_train_steps: 250
20+
num_train_steps: 20
2121

2222
use_torch_compile: false
23-
use_meta_device: true
23+
use_meta_device: false # small model fits on device directly; avoids meta-device complexity with EP
2424
use_fp32_master_weights: false
2525

26+
# Expert parallelism: EP=2 on 2-GPU setup (dp=1, ep=2).
27+
# num_local_experts (4) must be divisible by expert_parallel_size (2): 4/2=2 experts/rank.
28+
expert_parallel_size: 2
29+
2630
dataset:
2731
tokenizer_name_or_path: ./tokenizers/nucleotide_fast_tokenizer
2832
micro_batch_size: 1
@@ -38,8 +42,8 @@ dataset:
3842
streaming: true
3943

4044
wandb:
41-
name: og2_mixtral_sanity
42-
mode: offline
45+
name: "agent1-opengenome2"
46+
project: "swarm-mixtral-development"
4347

4448
lr_scheduler_kwargs:
4549
num_warmup_steps: 10

bionemo-recipes/recipes/opengenome2_mixtral_native_te/hydra_config/defaults.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@ use_meta_device: false
1717
use_torch_compile: false
1818
use_sequence_packing: false
1919

20+
# Expert parallelism: number of GPUs per expert-parallel group.
21+
# Must divide world_size evenly. Set > 1 to enable MoE expert parallelism.
22+
expert_parallel_size: 1
23+
2024
dataset:
2125
tokenizer_name_or_path: ???
2226
micro_batch_size: 8

bionemo-recipes/recipes/opengenome2_mixtral_native_te/train_fsdp2.py

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import random
2121
from contextlib import nullcontext
2222
from pathlib import Path
23+
from typing import Iterable
2324

2425
import hydra
2526
import numpy as np
@@ -63,6 +64,48 @@
6364
logger.setLevel(logging.INFO)
6465

6566

67+
def clip_grad_norm_ep_aware(params: Iterable[torch.nn.Parameter], max_norm: float, ep_size: int) -> torch.Tensor:
68+
"""Clip gradient norms, handling expert parallelism (DTensor parameters on different meshes).
69+
70+
When ep_size > 1, parameters may be DTensors on different device meshes (dp vs ep),
71+
which prevents torch.nn.utils.clip_grad_norm_ from stacking norms across them.
72+
This function computes norms per-parameter from local shards and clips accordingly.
73+
74+
Args:
75+
params: Model parameters (may include DTensor expert weights).
76+
max_norm: Maximum gradient norm.
77+
ep_size: Expert parallelism size. If 1, falls back to standard clip_grad_norm_.
78+
79+
Returns:
80+
Total gradient norm (approximate for ep_size > 1).
81+
"""
82+
if ep_size == 1:
83+
return torch.nn.utils.clip_grad_norm_(params, max_norm=max_norm)
84+
85+
# Compute per-param local norms, handling DTensor by extracting the local shard.
86+
param_list = list(params)
87+
norms = []
88+
for p in param_list:
89+
if p.grad is None:
90+
continue
91+
g = p.grad.detach()
92+
if hasattr(g, "to_local"):
93+
g = g.to_local() # Extract local shard of DTensor gradient
94+
norms.append(g.float().norm())
95+
96+
if not norms:
97+
return torch.tensor(0.0)
98+
99+
total_norm = torch.stack(norms).norm()
100+
clip_coef = max_norm / (total_norm + 1e-6)
101+
clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
102+
for p in param_list:
103+
if p.grad is not None:
104+
p.grad.detach().mul_(clip_coef_clamped.to(p.grad.device))
105+
106+
return total_norm
107+
108+
66109
def set_seed(seed: int) -> None:
67110
"""Set random seeds for reproducibility.
68111
@@ -103,7 +146,13 @@ def main(args: DictConfig) -> float | None:
103146
if args.fp8_stats_config.enabled:
104147
initialize_fp8_debugging(dist_config, **args.fp8_stats_config, fp8_enabled=args.fp8_config.enabled)
105148

106-
device_mesh = init_device_mesh("cuda", mesh_shape=(dist_config.world_size,), mesh_dim_names=("dp",))
149+
ep_size = args.expert_parallel_size
150+
if dist_config.world_size % ep_size != 0:
151+
raise ValueError(
152+
f"world_size ({dist_config.world_size}) must be divisible by expert_parallel_size ({ep_size})"
153+
)
154+
dp_size = dist_config.world_size // ep_size
155+
device_mesh = init_device_mesh("cuda", mesh_shape=(dp_size, ep_size), mesh_dim_names=("dp", "ep"))
107156

108157
# Create an FP8 recipe -- this is only used if FP8 is enabled in the config.
109158
fp8_recipe = hydra.utils.get_class(args.fp8_config.fp8_recipe)(
@@ -125,6 +174,10 @@ def main(args: DictConfig) -> float | None:
125174
logger.info("FP32 master weights enabled: model init in FP32")
126175

127176
config_kwargs = OmegaConf.to_container(args.config_kwargs, resolve=True) if args.config_kwargs else {}
177+
# Pass expert_parallel_size to config so the model initializes with the correct
178+
# num_local_experts = num_experts // expert_parallel_size per rank.
179+
if args.use_te:
180+
config_kwargs["expert_parallel_size"] = ep_size
128181

129182
config = config_class.from_pretrained(args.config_name_or_path, dtype=model_dtype, **config_kwargs)
130183

@@ -146,6 +199,13 @@ def main(args: DictConfig) -> float | None:
146199

147200
logger.info("Initialized Model:\n%s", model)
148201

202+
# Expert parallelism setup — MUST happen before fully_shard()
203+
# Wraps expert weights as DTensors with Shard(0) on the expert dimension.
204+
if args.use_te and ep_size > 1:
205+
ep_mesh = device_mesh["ep"]
206+
ep_group = ep_mesh.get_group()
207+
model.model.set_ep_groups(ep_group, ep_mesh)
208+
149209
# Create MixedPrecisionPolicy for FSDP when using FP32 master weights
150210
mp_policy = None
151211
if use_fp32_master_weights:
@@ -288,7 +348,7 @@ def main(args: DictConfig) -> float | None:
288348
if micro_step % args.grad_acc_steps == 0:
289349
micro_step = 0
290350

291-
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
351+
total_norm = clip_grad_norm_ep_aware(model.parameters(), max_norm=1.0, ep_size=ep_size)
292352

293353
optimizer.step()
294354
scheduler.step()

0 commit comments

Comments
 (0)