Skip to content

Commit 23f7df6

Browse files
svc-bionemoclaude
andcommitted
Add FSDP2 + Expert Parallelism tests for mixtral recipes
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> Signed-off-by: svc-bionemo <267129667+svc-bionemo@users.noreply.github.com>
1 parent 74968c6 commit 23f7df6

4 files changed

Lines changed: 727 additions & 0 deletions

File tree

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: LicenseRef-Apache2
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Shared test utilities for distributed (EP/FSDP) tests."""
17+
18+
import os
19+
import sys
20+
from dataclasses import dataclass, field
21+
from pathlib import Path
22+
23+
import torch
24+
25+
26+
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
27+
28+
from modeling_mixtral_te import NVMixtralConfig
29+
30+
31+
def create_small_mixtral_config(**overrides) -> NVMixtralConfig:
32+
"""Create a small Mixtral config suitable for testing."""
33+
defaults = {
34+
"hidden_size": 128,
35+
"intermediate_size": 256,
36+
"num_hidden_layers": 2,
37+
"num_attention_heads": 4,
38+
"num_key_value_heads": 2,
39+
"num_local_experts": 4,
40+
"num_experts_per_tok": 2,
41+
"max_position_embeddings": 128,
42+
"vocab_size": 1000,
43+
"attn_input_format": "bshd",
44+
"self_attn_mask_type": "causal",
45+
"router_jitter_noise": 0.0,
46+
}
47+
defaults.update(overrides)
48+
return NVMixtralConfig(**defaults)
49+
50+
51+
def get_dummy_batch(vocab_size: int, seq_len: int = 32, batch_size: int = 2, device: str = "cuda"):
52+
"""Create a simple dummy batch for testing."""
53+
torch.manual_seed(42)
54+
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
55+
attention_mask = torch.ones_like(input_ids)
56+
labels = input_ids.clone()
57+
return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
58+
59+
60+
@dataclass(frozen=True)
61+
class DistributedConfig:
62+
"""Distributed environment configuration."""
63+
64+
rank: int = field(default_factory=lambda: int(os.environ.setdefault("RANK", "0")))
65+
local_rank: int = field(default_factory=lambda: int(os.environ.setdefault("LOCAL_RANK", "0")))
66+
world_size: int = field(default_factory=lambda: int(os.environ.setdefault("WORLD_SIZE", "1")))
67+
_master_addr: str = field(default_factory=lambda: os.environ.setdefault("MASTER_ADDR", "localhost"))
68+
_master_port: str = field(default_factory=lambda: os.environ.setdefault("MASTER_PORT", "12355"))
69+
70+
def is_main_process(self) -> bool:
71+
"""Return True if this is the global rank 0 process."""
72+
return self.rank == 0
Lines changed: 289 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,289 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: LicenseRef-Apache2
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Tests for FSDP2 + Expert Parallelism (EP) in the mixtral_native_te recipe.
17+
18+
Verifies that FSDP2 and EP can be composed together:
19+
- FSDP=2, EP=1 (2 GPUs): Data-parallel sharding, all experts on each rank.
20+
- FSDP=1, EP=2 (2 GPUs): Expert-parallel training, no data parallelism.
21+
"""
22+
23+
import subprocess
24+
import sys
25+
from pathlib import Path
26+
27+
28+
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
29+
sys.path.insert(0, str(Path(__file__).resolve().parent))
30+
31+
import pytest
32+
import torch
33+
from distributed_helpers import DistributedConfig, create_small_mixtral_config, get_dummy_batch
34+
from modeling_mixtral_te import NVMixtralForCausalLM
35+
36+
37+
requires_2_gpus = pytest.mark.skipif(
38+
not torch.cuda.is_available() or torch.cuda.device_count() < 2,
39+
reason="Test requires at least 2 GPUs",
40+
)
41+
42+
43+
def _distribute_state_dict(full_state_dict: dict, model: torch.nn.Module, device: torch.device) -> dict:
44+
"""Distribute a full (EP=1) state dict to match a model's DTensor sharding.
45+
46+
After calling ``set_ep_groups``, expert weight parameters become DTensors with
47+
``Shard(0)`` placement. This function uses ``distribute_tensor`` to automatically
48+
shard full expert weights according to those annotations, avoiding manual slicing.
49+
50+
Args:
51+
full_state_dict: Complete state dict from an EP=1 model (plain tensors).
52+
model: Target EP model whose expert parameters are already DTensors.
53+
device: Device to move source tensors to before distributing.
54+
"""
55+
from torch.distributed.tensor import DTensor, distribute_tensor
56+
57+
distributed_state: dict = {}
58+
# model.state_dict() filters _extra_state keys via the NVMixtralPreTrainedModel
59+
# override, so use nn.Module.state_dict to get the unfiltered dict that includes
60+
# TransformerEngine _extra_state entries required by load_state_dict(strict=True).
61+
for key, value in torch.nn.Module.state_dict(model).items():
62+
if key.endswith("_extra_state"):
63+
distributed_state[key] = value
64+
elif key not in full_state_dict:
65+
continue
66+
elif isinstance(value, DTensor):
67+
distributed_state[key] = distribute_tensor(
68+
full_state_dict[key].to(device),
69+
value.device_mesh,
70+
list(value.placements),
71+
)
72+
else:
73+
distributed_state[key] = full_state_dict[key]
74+
return distributed_state
75+
76+
77+
def _train_step(model, batch):
78+
"""Run a single forward + backward + optimizer step.
79+
80+
Returns:
81+
Tuple of (loss value, dict of gradient norms, dict of weight change norms).
82+
"""
83+
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
84+
85+
# Snapshot weights before step
86+
pre_weights = {n: p.detach().clone() for n, p in model.named_parameters()}
87+
88+
optimizer.zero_grad()
89+
outputs = model(**batch)
90+
loss = outputs.loss
91+
loss.backward()
92+
93+
grad_norms = {}
94+
for name, param in model.named_parameters():
95+
if param.grad is not None:
96+
g = param.grad
97+
if hasattr(g, "full_tensor"):
98+
g = g.full_tensor()
99+
grad_norms[name] = g.detach().float().norm().item()
100+
101+
optimizer.step()
102+
103+
# Measure weight changes
104+
weight_changes = {}
105+
for name, param in model.named_parameters():
106+
pre = pre_weights[name]
107+
cur = param.detach()
108+
if hasattr(pre, "full_tensor"):
109+
pre = pre.full_tensor()
110+
if hasattr(cur, "full_tensor"):
111+
cur = cur.full_tensor()
112+
weight_changes[name] = (cur.float() - pre.float()).norm().item()
113+
114+
return loss.detach().item(), grad_norms, weight_changes
115+
116+
117+
# ---------------------------------------------------------------------------
118+
# Pytest entry points — launch torchrun subprocesses
119+
# ---------------------------------------------------------------------------
120+
121+
122+
def _run_torchrun(test_fn_name: str, port: int, nproc: int = 2):
123+
"""Run a named worker function via torchrun."""
124+
recipe_dir = str(Path(__file__).resolve().parent.parent)
125+
script = str(Path(__file__).resolve())
126+
cmd = [
127+
"torchrun",
128+
f"--nproc_per_node={nproc}",
129+
"--rdzv-backend=c10d",
130+
f"--rdzv-endpoint=localhost:{port}",
131+
script,
132+
test_fn_name,
133+
]
134+
result = subprocess.run(
135+
cmd,
136+
check=False,
137+
text=True,
138+
cwd=recipe_dir,
139+
stdout=subprocess.PIPE,
140+
stderr=subprocess.PIPE,
141+
timeout=300,
142+
)
143+
if result.returncode != 0:
144+
print(f"STDOUT:\n{result.stdout}")
145+
print(f"STDERR:\n{result.stderr}")
146+
pytest.fail(f"{test_fn_name} failed with exit code {result.returncode}")
147+
148+
149+
@requires_2_gpus
150+
def test_fsdp2_ep1(free_tcp_port):
151+
"""Test FSDP=2, EP=1: data-parallel training with all experts on each rank."""
152+
_run_torchrun("fsdp2_ep1", free_tcp_port, nproc=2)
153+
154+
155+
@requires_2_gpus
156+
def test_fsdp1_ep2(free_tcp_port):
157+
"""Test FSDP=1, EP=2: expert-parallel training without data parallelism."""
158+
_run_torchrun("fsdp1_ep2", free_tcp_port, nproc=2)
159+
160+
161+
# ---------------------------------------------------------------------------
162+
# Distributed workers executed via torchrun
163+
# ---------------------------------------------------------------------------
164+
165+
166+
def _worker_fsdp2_ep1():
167+
"""FSDP=2, EP=1: weights sharded by FSDP, all experts on each rank.
168+
169+
Uses a 2D device mesh (dp=2, ep=1) so that DTensor multi-dimensional
170+
placement logic is exercised even though the EP dimension is trivial.
171+
172+
1. Init distributed, create 2D device mesh with ep=1.
173+
2. Create model with EP=1, set EP groups on the trivial EP sub-mesh.
174+
3. Wrap with FSDP2 on the DP sub-mesh.
175+
4. Run one training step, verify loss/gradients are finite and weights update.
176+
"""
177+
from torch.distributed.device_mesh import init_device_mesh
178+
from torch.distributed.fsdp import fully_shard
179+
180+
dist_config = DistributedConfig()
181+
device = torch.device(f"cuda:{dist_config.local_rank}")
182+
torch.cuda.set_device(device)
183+
torch.distributed.init_process_group(backend="nccl", device_id=device)
184+
185+
ep_size = 1
186+
dp_size = dist_config.world_size
187+
device_mesh = init_device_mesh("cuda", mesh_shape=(dp_size, ep_size), mesh_dim_names=("dp", "ep"))
188+
189+
config = create_small_mixtral_config(expert_parallel_size=ep_size)
190+
torch.manual_seed(0)
191+
model = NVMixtralForCausalLM(config).to(dtype=torch.bfloat16, device=device)
192+
193+
# EP setup with trivial (size-1) EP sub-mesh
194+
ep_mesh = device_mesh["ep"]
195+
ep_group = ep_mesh.get_group()
196+
model.model.set_ep_groups(ep_group, ep_mesh)
197+
198+
# FSDP2 wrapping on DP sub-mesh
199+
for layer in model.model.layers:
200+
fully_shard(layer, mesh=device_mesh["dp"])
201+
fully_shard(model, mesh=device_mesh["dp"])
202+
203+
model.train()
204+
batch = get_dummy_batch(config.vocab_size, device=str(device))
205+
206+
loss_val, grad_norms, weight_changes = _train_step(model, batch)
207+
208+
assert torch.isfinite(torch.tensor(loss_val)), f"Loss is not finite: {loss_val}"
209+
assert len(grad_norms) > 0, "No gradients computed"
210+
for name, gnorm in grad_norms.items():
211+
assert torch.isfinite(torch.tensor(gnorm)), f"Gradient for {name} is not finite: {gnorm}"
212+
assert any(wc > 0 for wc in weight_changes.values()), "No weights updated after optimizer step"
213+
214+
torch.distributed.destroy_process_group()
215+
216+
217+
def _worker_fsdp1_ep2():
218+
"""FSDP=1, EP=2: experts sharded across ranks, trivial data parallelism.
219+
220+
Uses a 2D device mesh (dp=1, ep=2) so that DTensor multi-dimensional
221+
placement logic is exercised even though the DP dimension is trivial.
222+
223+
1. Init distributed, create 2D device mesh with dp=1.
224+
2. Create full EP=1 model for reference weights.
225+
3. Create EP=2 model, set EP groups (DTensor annotations), load via distribute_tensor.
226+
4. Wrap with FSDP2 on the trivial DP sub-mesh.
227+
5. Run one training step, verify loss/gradients are finite and weights update.
228+
"""
229+
from torch.distributed.device_mesh import init_device_mesh
230+
from torch.distributed.fsdp import fully_shard
231+
232+
dist_config = DistributedConfig()
233+
device = torch.device(f"cuda:{dist_config.local_rank}")
234+
torch.cuda.set_device(device)
235+
torch.distributed.init_process_group(backend="nccl", device_id=device)
236+
237+
ep_size = dist_config.world_size
238+
dp_size = 1
239+
device_mesh = init_device_mesh("cuda", mesh_shape=(dp_size, ep_size), mesh_dim_names=("dp", "ep"))
240+
241+
ep_mesh = device_mesh["ep"]
242+
ep_group = ep_mesh.get_group()
243+
244+
# Get reference weights from a full EP=1 model
245+
config_full = create_small_mixtral_config(expert_parallel_size=1)
246+
torch.manual_seed(0)
247+
full_model = NVMixtralForCausalLM(config_full).to(dtype=torch.bfloat16, device="cpu")
248+
full_state_dict = {k: v.clone() for k, v in full_model.state_dict().items()}
249+
del full_model
250+
251+
# Create EP=2 model, set EP groups to create DTensor annotations, then load weights
252+
config_ep = create_small_mixtral_config(expert_parallel_size=ep_size)
253+
torch.manual_seed(0)
254+
model = NVMixtralForCausalLM(config_ep).to(dtype=torch.bfloat16, device=device)
255+
256+
# EP setup on EP sub-mesh first (creates DTensor annotations on expert weights)
257+
model.model.set_ep_groups(ep_group, ep_mesh)
258+
259+
# Load EP=1 weights — distribute_tensor handles expert sharding automatically
260+
distributed_state = _distribute_state_dict(full_state_dict, model, device)
261+
model.load_state_dict(distributed_state, strict=True)
262+
263+
# FSDP2 wrapping on trivial (size-1) DP sub-mesh
264+
for layer in model.model.layers:
265+
fully_shard(layer, mesh=device_mesh["dp"])
266+
fully_shard(model, mesh=device_mesh["dp"])
267+
268+
model.train()
269+
batch = get_dummy_batch(config_ep.vocab_size, device=str(device))
270+
271+
loss_val, grad_norms, weight_changes = _train_step(model, batch)
272+
273+
assert torch.isfinite(torch.tensor(loss_val)), f"Loss is not finite: {loss_val}"
274+
assert len(grad_norms) > 0, "No gradients computed"
275+
for name, gnorm in grad_norms.items():
276+
assert torch.isfinite(torch.tensor(gnorm)), f"Gradient for {name} is not finite: {gnorm}"
277+
assert any(wc > 0 for wc in weight_changes.values()), "No weights updated after optimizer step"
278+
279+
torch.distributed.destroy_process_group()
280+
281+
282+
if __name__ == "__main__":
283+
test_name = sys.argv[1]
284+
285+
workers = {
286+
"fsdp2_ep1": _worker_fsdp2_ep1,
287+
"fsdp1_ep2": _worker_fsdp1_ep2,
288+
}
289+
workers[test_name]()

0 commit comments

Comments
 (0)