Skip to content

Commit f280bac

Browse files
trvachov and claude committed
Add Expert Parallelism configs, fused token dispatch, and 8x7B model support for Mixtral recipes
Adds EP test configs (EP1/2/4/8) for both Lingua and OpenGenome2 Mixtral recipes, fused all-to-all token dispatch modules, 8x7B model configs, and EP-aware gradient clipping in train_fsdp2.py.

Validated on 8x H100 80GB:
- 8x1B: all EP degrees (1, 2, 4, 8) pass for both datasets
- 8x7B: OG2 EP2 passes (seq_len=2048), Lingua OOMs (seq_len=4096)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent bb62f94 commit f280bac

24 files changed

Lines changed: 2201 additions & 2 deletions
Lines changed: 306 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,306 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: LicenseRef-Apache2
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
# --- BEGIN COPIED FILE NOTICE ---
17+
# This file is copied from: bionemo-recipes/models/mixtral/fused_a2a.py
18+
# Do not modify this file directly. Instead, modify the source and run:
19+
# python ci/scripts/check_copied_files.py --fix
20+
# --- END COPIED FILE NOTICE ---
21+
22+
# Portions of this code are from DeepSeek DeepEP project
23+
# Copyright (c) 2025 DeepSeek
24+
# Licensed under the MIT License - https://github.com/deepseek-ai/DeepEP/blob/main/LICENSE
25+
26+
import os
27+
28+
29+
# DeepEP is optional: HAVE_DEEP_EP records whether the import succeeded so the
# public wrappers at the bottom of this module can be disabled when it did not.
try:
    from deep_ep import Buffer
    from deep_ep.utils import EventHandle, EventOverlap

    HAVE_DEEP_EP = True
    # Value passed to Buffer.set_num_sms; override with the DEEP_EP_SM_NUMS
    # environment variable (defaults to 20 when unset).
    Buffer.set_num_sms(int(os.environ.get("DEEP_EP_SM_NUMS", "20")))
except ImportError:
    HAVE_DEEP_EP = False

import torch


# Module-wide communication buffer; created and replaced by get_buffer().
_buffer = None
# Cached outcome of the NVSHMEM probe; stays None until _is_nvshmem_available() has run.
_nvshmem_available = None
43+
44+
45+
def _is_nvshmem_available() -> bool:
    """Report whether DeepEP was compiled with NVSHMEM support.

    The answer is found by probing: a dispatch config is asked for an RDMA
    buffer size hint, and a ``RuntimeError`` from that call is taken to mean
    NVSHMEM is disabled. is_sm90_compiled() alone is not a reliable proxy,
    since SM90 can be compiled while NVSHMEM is still disabled.

    The probe runs at most once; its outcome is cached in the module-level
    ``_nvshmem_available`` and returned directly on later calls.

    Returns:
        bool: True if the probe succeeded, False if it raised RuntimeError.
    """
    global _nvshmem_available  # noqa: PLW0603

    # Fast path: a previous call already settled the question.
    if _nvshmem_available is not None:
        return _nvshmem_available

    try:
        probe_config = Buffer.get_dispatch_config(2)
        probe_config.get_rdma_buffer_size_hint(256, 2)
    except RuntimeError:
        probe_ok = False
    else:
        probe_ok = True

    _nvshmem_available = probe_ok
    return _nvshmem_available
61+
62+
63+
def get_hidden_bytes(x: torch.Tensor) -> int:
    """Return the per-row byte count used to size communication buffers.

    The width is taken from dimension 1 of ``x`` and multiplied by the
    element size, with the element size clamped to a minimum of 2 bytes.
    NOTE(review): the 2-byte floor presumably keeps buffers large enough when
    1-byte dtypes are passed in; confirm against DeepEP's requirements.

    Args:
        x (torch.Tensor): Tensor with at least two dimensions.

    Returns:
        int: ``x.size(1)`` times ``max(x.element_size(), 2)``.
    """
    hidden_dim = x.size(1)
    # Never account for fewer than 2 bytes per element.
    bytes_per_element = max(x.element_size(), 2)
    return hidden_dim * bytes_per_element
73+
74+
75+
def get_buffer(group: torch.distributed.ProcessGroup, hidden_bytes: int):
    """Return the shared all-to-all buffer, allocating a new one when needed.

    The required NVL size is the larger of the hints given by the dispatch
    and combine configs for this group size. The RDMA size is computed the
    same way, but only when the NVSHMEM probe succeeded; otherwise it stays 0.
    The module-level buffer is replaced when there is none yet, when it
    belongs to a different group, or when either of its sizes is smaller
    than required.

    Args:
        group (torch.distributed.ProcessGroup): Process group for communication
        hidden_bytes (int): Number of hidden bytes needed

    Returns:
        Buffer: Communication buffer
    """
    global _buffer  # noqa: PLW0603

    rdma_enabled = _is_nvshmem_available()
    world_size = group.size()

    required_nvl = 0
    required_rdma = 0
    for cfg in (Buffer.get_dispatch_config(world_size), Buffer.get_combine_config(world_size)):
        required_nvl = max(required_nvl, cfg.get_nvl_buffer_size_hint(hidden_bytes, world_size))
        if rdma_enabled:
            required_rdma = max(required_rdma, cfg.get_rdma_buffer_size_hint(hidden_bytes, world_size))

    # Allocate a new buffer if none exists yet, or if the cached one belongs
    # to another group or is too small for this request.
    # NOTES: the adaptive routing configuration of the network **must be off**
    needs_new_buffer = (
        _buffer is None
        or _buffer.group != group
        or _buffer.num_nvl_bytes < required_nvl
        or _buffer.num_rdma_bytes < required_rdma
    )
    if needs_new_buffer:
        _buffer = Buffer(group, required_nvl, required_rdma)
    return _buffer
106+
107+
108+
class FusedDispatch(torch.autograd.Function):
    """Autograd function for MoE token dispatch: forward dispatches, backward combines."""

    @staticmethod
    def forward(
        ctx,
        x,
        token_indices,
        token_probs,
        num_experts,
        group,
        async_finish=False,
        allocate_on_comm_stream=False,
    ):
        """Compute the dispatch layout, then dispatch tokens through the shared buffer.

        Returns:
            tuple: ``(recv_x, recv_token_indices, recv_token_probs,
            tokens_per_expert, handle)``. ``tokens_per_expert`` is a tensor
            built from the per-expert receive counts reported by the dispatch
            call, and ``handle`` is the dispatch handle, also stored on
            ``ctx`` for the backward pass.
        """
        # Only the async path needs an event for the comm calls to wait on.
        wait_event = EventOverlap(EventHandle()) if async_finish else None
        comm_buffer = get_buffer(group, get_hidden_bytes(x))
        stream_opts = {
            "async_finish": async_finish,
            "allocate_on_comm_stream": allocate_on_comm_stream,
        }

        # Calculate layout before actual dispatch
        layout = comm_buffer.get_dispatch_layout(
            token_indices,
            num_experts,
            previous_event=wait_event,
            **stream_opts,
        )
        per_rank, per_rdma_rank, per_expert, in_rank_mask, layout_event = layout

        # Do MoE dispatch
        # NOTES: the CPU will wait for GPU's signal to arrive,
        # so this is not compatible with CUDA graph
        dispatched = comm_buffer.dispatch(
            x,
            topk_idx=token_indices,
            topk_weights=token_probs,  # DeepEP only supports float32 probs
            num_tokens_per_rank=per_rank,
            num_tokens_per_rdma_rank=per_rdma_rank,
            is_token_in_rank=in_rank_mask,
            num_tokens_per_expert=per_expert,
            previous_event=layout_event,  # wait in deepep::intra/inter_dispatch
            **stream_opts,
        )
        recv_x, recv_token_indices, recv_token_probs, recv_counts, handle, done_event = dispatched

        # Make sure current stream is synchronized
        if async_finish:
            done_event.current_stream_wait()

        # Stash what backward needs to run the matching combine.
        ctx.group = group
        ctx.handle = handle
        ctx.async_finish = async_finish
        ctx.allocate_on_comm_stream = allocate_on_comm_stream

        tokens_per_expert = torch.tensor(recv_counts)
        return (recv_x, recv_token_indices, recv_token_probs, tokens_per_expert, handle)

    @staticmethod
    def backward(
        ctx,
        grad_output,
        grad_token_indices,
        grad_token_probs,
        grad_tokens_per_expert,
        grad_handle,
    ):
        """Send gradients back with a combine that reuses the forward handle.

        Only ``grad_output`` and ``grad_token_probs`` are consumed; the
        remaining incoming gradients are ignored.

        Returns:
            tuple: One entry per forward input. Gradients are produced for
            ``x`` and ``token_probs``; every other entry is ``None``.
        """
        comm_buffer = get_buffer(ctx.group, get_hidden_bytes(grad_output))
        saved_handle = ctx.handle
        wait_event = EventOverlap(EventHandle()) if ctx.async_finish else None

        grad_x, grad_probs, done_event = comm_buffer.combine(
            grad_output.contiguous(),
            saved_handle,
            topk_weights=grad_token_probs.float(),
            previous_event=wait_event,
            async_finish=ctx.async_finish,
            allocate_on_comm_stream=ctx.allocate_on_comm_stream,
        )

        # Make sure current stream is synchronized
        if ctx.async_finish:
            done_event.current_stream_wait()

        return grad_x, None, grad_probs, None, None, None, None
205+
206+
207+
class FusedCombine(torch.autograd.Function):
    """Autograd function for MoE output combine: forward combines, backward dispatches."""

    @staticmethod
    def forward(ctx, x, group, handle, async_finish=False, allocate_on_comm_stream=False):
        """Combine expert outputs through the shared buffer using ``handle``.

        Returns:
            tuple: ``(combined_x, None)``. The second element is always ``None``.
        """
        # Only the async path needs an event for the comm call to wait on.
        wait_event = EventOverlap(EventHandle()) if async_finish else None
        comm_buffer = get_buffer(group, get_hidden_bytes(x))

        combined_x, _, done_event = comm_buffer.combine(
            x,
            handle=handle,
            async_finish=async_finish,
            previous_event=wait_event,
            allocate_on_comm_stream=allocate_on_comm_stream,
        )

        # Make sure current stream is synchronized
        if async_finish:
            done_event.current_stream_wait()

        # Stash what backward needs to run the matching dispatch.
        ctx.handle = handle
        ctx.group = group
        ctx.async_finish = async_finish
        ctx.allocate_on_comm_stream = allocate_on_comm_stream
        return combined_x, None

    @staticmethod
    def backward(ctx, grad_output, previous_event=None):
        """Send gradients back with a dispatch that reuses the forward handle.

        The incoming ``previous_event`` argument is never read; a fresh event
        is created locally when ``ctx.async_finish`` is set.
        NOTE(review): since forward returns two outputs, this second slot
        presumably receives the gradient for the ``None`` output; confirm.

        Returns:
            tuple: One entry per forward input. Only ``x`` gets a gradient;
            every other entry is ``None``.
        """
        wait_event = EventOverlap(EventHandle()) if ctx.async_finish else None
        comm_buffer = get_buffer(ctx.group, get_hidden_bytes(grad_output))

        # dispatch returns six values; only the first and the last are needed here.
        grad_x, _idx, _probs, _counts, _handle, done_event = comm_buffer.dispatch(
            grad_output.contiguous(),
            handle=ctx.handle,
            previous_event=wait_event,
            async_finish=ctx.async_finish,
            allocate_on_comm_stream=ctx.allocate_on_comm_stream,
        )

        # Make sure current stream is synchronized
        if ctx.async_finish:
            done_event.current_stream_wait()

        return grad_x, None, None, None, None
252+
253+
254+
# The public wrappers exist only when deep_ep imported; otherwise both names are None.
if HAVE_DEEP_EP:

    def fused_dispatch(
        x,
        token_indices,
        token_probs,
        num_experts,
        group,
        async_finish=False,
        allocate_on_comm_stream=False,
    ):
        """Dispatch tokens to experts through FusedDispatch.

        ``x`` is made contiguous before being passed to the autograd function.

        Args:
            x: Input tensor [num_tokens, hidden_size]
            token_indices: Token routing indices [num_tokens, topk]
            token_probs: Token routing probabilities [num_tokens, topk]
            num_experts: Number of experts
            group: Process group
            async_finish: Whether to finish asynchronously
            allocate_on_comm_stream: Whether to allocate on communication stream

        Returns:
            The FusedDispatch outputs: ``(recv_x, recv_token_indices,
            recv_token_probs, tokens_per_expert, handle)``.
        """
        contiguous_x = x.contiguous()
        return FusedDispatch.apply(
            contiguous_x,
            token_indices,
            token_probs,
            num_experts,
            group,
            async_finish,
            allocate_on_comm_stream,
        )

    def fused_combine(x, group, handle, async_finish=False, allocate_on_comm_stream=False):
        """Combine expert outputs through FusedCombine.

        Args:
            x: Input tensor
            group: Process group
            handle: Communication handle
            async_finish: Whether to finish asynchronously
            allocate_on_comm_stream: Whether to allocate on communication stream

        Returns:
            The FusedCombine outputs: ``(combined_x, None)``.
        """
        combine_args = (x, group, handle, async_finish, allocate_on_comm_stream)
        return FusedCombine.apply(*combine_args)

else:
    # Callers can test these names against None to detect a missing deep_ep.
    fused_dispatch = None
    fused_combine = None

0 commit comments

Comments
 (0)