|
| 1 | +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 2 | +# SPDX-License-Identifier: LicenseRef-Apache2 |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | + |
| 16 | +# --- BEGIN COPIED FILE NOTICE --- |
| 17 | +# This file is copied from: bionemo-recipes/models/mixtral/fused_a2a.py |
| 18 | +# Do not modify this file directly. Instead, modify the source and run: |
| 19 | +# python ci/scripts/check_copied_files.py --fix |
| 20 | +# --- END COPIED FILE NOTICE --- |
| 21 | + |
| 22 | +# Portions of this code are from DeepSeek DeepEP project |
| 23 | +# Copyright (c) 2025 DeepSeek |
| 24 | +# Licensed under the MIT License - https://github.com/deepseek-ai/DeepEP/blob/main/LICENSE |
| 25 | + |
| 26 | +import os |
| 27 | + |
| 28 | + |
| 29 | +try: |
| 30 | + from deep_ep import Buffer |
| 31 | + from deep_ep.utils import EventHandle, EventOverlap |
| 32 | + |
| 33 | + HAVE_DEEP_EP = True |
| 34 | + Buffer.set_num_sms(int(os.environ.get("DEEP_EP_SM_NUMS", "20"))) |
| 35 | +except ImportError: |
| 36 | + HAVE_DEEP_EP = False |
| 37 | + |
| 38 | +import torch |
| 39 | + |
| 40 | + |
| 41 | +_buffer = None |
| 42 | +_nvshmem_available = None |
| 43 | + |
| 44 | + |
| 45 | +def _is_nvshmem_available() -> bool: |
| 46 | + """Check if DeepEP was compiled with NVSHMEM support. |
| 47 | +
|
| 48 | + Probes NVSHMEM by calling get_rdma_buffer_size_hint, since |
| 49 | + is_sm90_compiled() alone is not a reliable proxy — SM90 can |
| 50 | + be compiled while NVSHMEM is still disabled. |
| 51 | + """ |
| 52 | + global _nvshmem_available # noqa: PLW0603 |
| 53 | + if _nvshmem_available is None: |
| 54 | + try: |
| 55 | + config = Buffer.get_dispatch_config(2) |
| 56 | + config.get_rdma_buffer_size_hint(256, 2) |
| 57 | + _nvshmem_available = True |
| 58 | + except RuntimeError: |
| 59 | + _nvshmem_available = False |
| 60 | + return _nvshmem_available |
| 61 | + |
| 62 | + |
| 63 | +def get_hidden_bytes(x: torch.Tensor) -> int: |
| 64 | + """Calculate the number of hidden bytes for a tensor. |
| 65 | +
|
| 66 | + Args: |
| 67 | + x (torch.Tensor): Input tensor |
| 68 | +
|
| 69 | + Returns: |
| 70 | + int: Number of hidden bytes |
| 71 | + """ |
| 72 | + return x.size(1) * max(x.element_size(), 2) |
| 73 | + |
| 74 | + |
| 75 | +def get_buffer(group: torch.distributed.ProcessGroup, hidden_bytes: int): |
| 76 | + """Get or create a buffer for all-to-all communication. |
| 77 | +
|
| 78 | + Args: |
| 79 | + group (torch.distributed.ProcessGroup): Process group for communication |
| 80 | + hidden_bytes (int): Number of hidden bytes needed |
| 81 | +
|
| 82 | + Returns: |
| 83 | + Buffer: Communication buffer |
| 84 | + """ |
| 85 | + global _buffer # noqa: PLW0603 |
| 86 | + num_nvl_bytes, num_rdma_bytes = 0, 0 |
| 87 | + nvshmem = _is_nvshmem_available() |
| 88 | + for config in ( |
| 89 | + Buffer.get_dispatch_config(group.size()), |
| 90 | + Buffer.get_combine_config(group.size()), |
| 91 | + ): |
| 92 | + num_nvl_bytes = max(config.get_nvl_buffer_size_hint(hidden_bytes, group.size()), num_nvl_bytes) |
| 93 | + if nvshmem: |
| 94 | + num_rdma_bytes = max(config.get_rdma_buffer_size_hint(hidden_bytes, group.size()), num_rdma_bytes) |
| 95 | + |
| 96 | + # Allocate buffer if not existed or not enough buffer |
| 97 | + # NOTES: the adaptive routing configuration of the network **must be off** |
| 98 | + if ( |
| 99 | + _buffer is None |
| 100 | + or _buffer.group != group |
| 101 | + or _buffer.num_nvl_bytes < num_nvl_bytes |
| 102 | + or _buffer.num_rdma_bytes < num_rdma_bytes |
| 103 | + ): |
| 104 | + _buffer = Buffer(group, num_nvl_bytes, num_rdma_bytes) |
| 105 | + return _buffer |
| 106 | + |
| 107 | + |
| 108 | +class FusedDispatch(torch.autograd.Function): |
| 109 | + """Fused dispatch operation for MoE routing combining computation and communication.""" |
| 110 | + |
| 111 | + @staticmethod |
| 112 | + def forward( |
| 113 | + ctx, |
| 114 | + x, |
| 115 | + token_indices, |
| 116 | + token_probs, |
| 117 | + num_experts, |
| 118 | + group, |
| 119 | + async_finish=False, |
| 120 | + allocate_on_comm_stream=False, |
| 121 | + ): |
| 122 | + """Forward pass of fused dispatch.""" |
| 123 | + previous_event = None |
| 124 | + if async_finish: |
| 125 | + previous_event = EventOverlap(EventHandle()) |
| 126 | + # Calculate layout before actual dispatch |
| 127 | + buffer = get_buffer(group, get_hidden_bytes(x)) |
| 128 | + ( |
| 129 | + num_tokens_per_rank, |
| 130 | + num_tokens_per_rdma_rank, |
| 131 | + num_tokens_per_expert, |
| 132 | + is_token_in_rank, |
| 133 | + event, |
| 134 | + ) = buffer.get_dispatch_layout( |
| 135 | + token_indices, |
| 136 | + num_experts, |
| 137 | + previous_event=previous_event, |
| 138 | + async_finish=async_finish, |
| 139 | + allocate_on_comm_stream=allocate_on_comm_stream, |
| 140 | + ) |
| 141 | + |
| 142 | + # Do MoE dispatch |
| 143 | + # NOTES: the CPU will wait for GPU's signal to arrive, |
| 144 | + # so this is not compatible with CUDA graph |
| 145 | + ( |
| 146 | + recv_x, |
| 147 | + recv_token_indices, |
| 148 | + recv_token_probs, |
| 149 | + num_recv_tokens_per_expert_list, |
| 150 | + handle, |
| 151 | + after_event_overlap, |
| 152 | + ) = buffer.dispatch( |
| 153 | + x, |
| 154 | + topk_idx=token_indices, |
| 155 | + topk_weights=token_probs, # DeepEP only supports float32 probs |
| 156 | + num_tokens_per_rank=num_tokens_per_rank, |
| 157 | + num_tokens_per_rdma_rank=num_tokens_per_rdma_rank, |
| 158 | + is_token_in_rank=is_token_in_rank, |
| 159 | + num_tokens_per_expert=num_tokens_per_expert, |
| 160 | + previous_event=event, # wait in deepep::intra/inter_dispatch |
| 161 | + async_finish=async_finish, |
| 162 | + allocate_on_comm_stream=allocate_on_comm_stream, |
| 163 | + ) |
| 164 | + |
| 165 | + # Make sure current stream is synchronized |
| 166 | + if async_finish: |
| 167 | + after_event_overlap.current_stream_wait() |
| 168 | + |
| 169 | + # Save for backward |
| 170 | + ctx.group = group |
| 171 | + ctx.handle = handle |
| 172 | + ctx.async_finish = async_finish |
| 173 | + ctx.allocate_on_comm_stream = allocate_on_comm_stream |
| 174 | + tokens_per_expert = torch.tensor(num_recv_tokens_per_expert_list) |
| 175 | + |
| 176 | + return (recv_x, recv_token_indices, recv_token_probs, tokens_per_expert, handle) |
| 177 | + |
| 178 | + @staticmethod |
| 179 | + def backward( |
| 180 | + ctx, |
| 181 | + grad_output, |
| 182 | + grad_token_indices, |
| 183 | + grad_token_probs, |
| 184 | + grad_tokens_per_expert, |
| 185 | + grad_handle, |
| 186 | + ): |
| 187 | + """Backward pass of fused dispatch.""" |
| 188 | + buffer = get_buffer(ctx.group, get_hidden_bytes(grad_output)) |
| 189 | + handle = ctx.handle |
| 190 | + previous_event = None |
| 191 | + if ctx.async_finish: |
| 192 | + previous_event = EventOverlap(EventHandle()) |
| 193 | + grad_x, grad_token_probs, after_event = buffer.combine( |
| 194 | + grad_output.contiguous(), |
| 195 | + handle, |
| 196 | + topk_weights=grad_token_probs.float(), |
| 197 | + previous_event=previous_event, |
| 198 | + async_finish=ctx.async_finish, |
| 199 | + allocate_on_comm_stream=ctx.allocate_on_comm_stream, |
| 200 | + ) |
| 201 | + # Make sure current stream is synchronized |
| 202 | + if ctx.async_finish: |
| 203 | + after_event.current_stream_wait() |
| 204 | + return grad_x, None, grad_token_probs, None, None, None, None |
| 205 | + |
| 206 | + |
| 207 | +class FusedCombine(torch.autograd.Function): |
| 208 | + """Fused combine operation for MoE output combining computation and communication.""" |
| 209 | + |
| 210 | + @staticmethod |
| 211 | + def forward(ctx, x, group, handle, async_finish=False, allocate_on_comm_stream=False): |
| 212 | + """Forward pass of fused combine.""" |
| 213 | + previous_event = None |
| 214 | + if async_finish: |
| 215 | + previous_event = EventOverlap(EventHandle()) |
| 216 | + buffer = get_buffer(group, get_hidden_bytes(x)) |
| 217 | + combined_x, _, after_event = buffer.combine( |
| 218 | + x, |
| 219 | + handle=handle, |
| 220 | + async_finish=async_finish, |
| 221 | + previous_event=previous_event, |
| 222 | + allocate_on_comm_stream=allocate_on_comm_stream, |
| 223 | + ) |
| 224 | + # Make sure current stream is synchronized |
| 225 | + if async_finish: |
| 226 | + after_event.current_stream_wait() |
| 227 | + |
| 228 | + ctx.handle = handle |
| 229 | + ctx.group = group |
| 230 | + ctx.async_finish = async_finish |
| 231 | + ctx.allocate_on_comm_stream = allocate_on_comm_stream |
| 232 | + return combined_x, None |
| 233 | + |
| 234 | + @staticmethod |
| 235 | + def backward(ctx, grad_output, previous_event=None): |
| 236 | + """Backward pass of fused combine.""" |
| 237 | + previous_event = None |
| 238 | + if ctx.async_finish: |
| 239 | + previous_event = EventOverlap(EventHandle()) |
| 240 | + buffer = get_buffer(ctx.group, get_hidden_bytes(grad_output)) |
| 241 | + grad_x, _, _, _, _, after_event = buffer.dispatch( |
| 242 | + grad_output.contiguous(), |
| 243 | + handle=ctx.handle, |
| 244 | + previous_event=previous_event, |
| 245 | + async_finish=ctx.async_finish, |
| 246 | + allocate_on_comm_stream=ctx.allocate_on_comm_stream, |
| 247 | + ) |
| 248 | + # Make sure current stream is synchronized |
| 249 | + if ctx.async_finish: |
| 250 | + after_event.current_stream_wait() |
| 251 | + return grad_x, None, None, None, None |
| 252 | + |
| 253 | + |
| 254 | +if HAVE_DEEP_EP: |
| 255 | + |
| 256 | + def fused_dispatch( |
| 257 | + x, |
| 258 | + token_indices, |
| 259 | + token_probs, |
| 260 | + num_experts, |
| 261 | + group, |
| 262 | + async_finish=False, |
| 263 | + allocate_on_comm_stream=False, |
| 264 | + ): |
| 265 | + """Perform fused dispatch operation if deep_ep is available. |
| 266 | +
|
| 267 | + Args: |
| 268 | + x: Input tensor [num_tokens, hidden_size] |
| 269 | + token_indices: Token routing indices [num_tokens, topk] |
| 270 | + token_probs: Token routing probabilities [num_tokens, topk] |
| 271 | + num_experts: Number of experts |
| 272 | + group: Process group |
| 273 | + async_finish: Whether to finish asynchronously |
| 274 | + allocate_on_comm_stream: Whether to allocate on communication stream |
| 275 | +
|
| 276 | + Returns: |
| 277 | + Result of FusedDispatch |
| 278 | + """ |
| 279 | + return FusedDispatch.apply( |
| 280 | + x.contiguous(), |
| 281 | + token_indices, |
| 282 | + token_probs, |
| 283 | + num_experts, |
| 284 | + group, |
| 285 | + async_finish, |
| 286 | + allocate_on_comm_stream, |
| 287 | + ) |
| 288 | + |
| 289 | + def fused_combine(x, group, handle, async_finish=False, allocate_on_comm_stream=False): |
| 290 | + """Perform fused combine operation if deep_ep is available. |
| 291 | +
|
| 292 | + Args: |
| 293 | + x: Input tensor |
| 294 | + group: Process group |
| 295 | + handle: Communication handle |
| 296 | + async_finish: Whether to finish asynchronously |
| 297 | + allocate_on_comm_stream: Whether to allocate on communication stream |
| 298 | +
|
| 299 | + Returns: |
| 300 | + Result of FusedCombine |
| 301 | + """ |
| 302 | + return FusedCombine.apply(x, group, handle, async_finish, allocate_on_comm_stream) |
| 303 | + |
| 304 | +else: |
| 305 | + fused_dispatch = None |
| 306 | + fused_combine = None |
0 commit comments