diff --git a/requirements.txt b/requirements.txt index 99fe8fcb8cbe..afe70060e446 100644 --- a/requirements.txt +++ b/requirements.txt @@ -26,6 +26,8 @@ tensorrt~=10.15.1 torch>=2.10.0,<=2.11.0a0 torchvision nvidia-modelopt[torch]~=0.37.0 +# NcclEP uses nccl4py's nccl.ep package without changing the NCCL wheel constraint. +nccl4py>=0.3 # https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-26-02.html#rel-26-02 uses 2.29.2 # torch 2.10.0+cu130 depends on nvidia-nccl-cu13==2.28.9 nvidia-nccl-cu13>=2.28.9,<=2.29.2 diff --git a/tensorrt_llm/_torch/modules/fused_moe/communication/__init__.py b/tensorrt_llm/_torch/modules/fused_moe/communication/__init__.py index 0d44ecd2df1e..9858693c5f37 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/communication/__init__.py +++ b/tensorrt_llm/_torch/modules/fused_moe/communication/__init__.py @@ -34,6 +34,7 @@ from .communication_factory import CommunicationFactory from .deep_ep import DeepEP from .deep_ep_low_latency import DeepEPLowLatency +from .nccl_ep import NcclEP from .nvlink_one_sided import NVLinkOneSided from .nvlink_two_sided import NVLinkTwoSided @@ -46,6 +47,7 @@ "NVLinkOneSided", "DeepEP", "DeepEPLowLatency", + "NcclEP", # Factory "CommunicationFactory", ] diff --git a/tensorrt_llm/_torch/modules/fused_moe/communication/communication_factory.py b/tensorrt_llm/_torch/modules/fused_moe/communication/communication_factory.py index a4ec2ceefe44..40c2668258e9 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/communication/communication_factory.py +++ b/tensorrt_llm/_torch/modules/fused_moe/communication/communication_factory.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -32,6 +32,7 @@ from .base import Communication from .deep_ep import DeepEP from .deep_ep_low_latency import DeepEPLowLatency +from .nccl_ep import NcclEP from .nvlink_one_sided import NVLinkOneSided from .nvlink_two_sided import NVLinkTwoSided from .nvlink_two_sided_flashinfer import NVLinkTwoSidedFlashinfer @@ -67,6 +68,7 @@ def create_strategy( 2. Auto-selection (tries in order): - NVLinkOneSided (highest priority for throughput) - NVLinkTwoSided (high priority for latency) + - NcclEP (if nccl-ep is available) - DeepEP (if enabled via TRTLLM_CAN_USE_DEEP_EP) - DeepEPLowLatency (if enabled via TRTLLM_CAN_USE_DEEP_EP) - AllGather + ReduceScatter (fallback, always works) @@ -129,7 +131,7 @@ def create_strategy( ) # Auto-selection: Try strategies in priority order using try-catch - # Priority: NVLinkOneSided > NVLinkTwoSided > DeepEP > DeepEPLowLatency > AllGather + # Priority: NVLinkOneSided > NVLinkTwoSided > NcclEP > DeepEP > DeepEPLowLatency > AllGather try: enable_eplb = model_config.moe_load_balancer is not None @@ -181,6 +183,26 @@ def create_strategy( except Exception as e: logger.info(f"NVLinkTwoSided not available: {e}") + # Try NCCL EP (rank-major LL). Falls through to DeepEP/AllGather if + # prerequisites are not met or libnccl_ep.so is not available. + nccl_ep_unavailable_reason = CommunicationFactory._get_nccl_ep_unavailable_reason(act_dtype) + if nccl_ep_unavailable_reason is None: + try: + strategy = NcclEP( + mapping, + num_slots, + hidden_size, + max_num_tokens, + moe_max_num_tokens, + top_k=top_k, + ) + logger.info("Selected communication strategy: NcclEP") + return strategy + except RuntimeError as e: + logger.debug(f"NcclEP not available: {e}") + else: + logger.debug(f"NcclEP not available: {nccl_ep_unavailable_reason}") + # Try DeepEP (if enabled and weight dtype is bfloat16) if os.environ.get("TRTLLM_CAN_USE_DEEP_EP", "1") == "1" and act_dtype == torch.bfloat16: try: @@ -318,7 +340,29 @@ def _create_forced_method( use_low_precision_combine, moe_max_num_tokens, ) + elif method == "NCCL_EP": + nccl_ep_unavailable_reason = CommunicationFactory._get_nccl_ep_unavailable_reason( + act_dtype + ) + if nccl_ep_unavailable_reason is not None: + raise ValueError(nccl_ep_unavailable_reason) + return NcclEP( + mapping, + num_slots, + hidden_size, + max_num_tokens, + moe_max_num_tokens, + top_k=top_k, + ) elif method == "ALLGATHER": return AllGatherReduceScatter(mapping) else: raise ValueError(f"Unknown communication method: {method}") + + @staticmethod + def _get_nccl_ep_unavailable_reason( + act_dtype: torch.dtype, + ) -> Optional[str]: + if act_dtype != torch.bfloat16: + return f"NcclEP requires act_dtype=torch.bfloat16, got {act_dtype}." + return None diff --git a/tensorrt_llm/_torch/modules/fused_moe/communication/nccl_ep.py b/tensorrt_llm/_torch/modules/fused_moe/communication/nccl_ep.py new file mode 100644 index 000000000000..a03f845e4f29 --- /dev/null +++ b/tensorrt_llm/_torch/modules/fused_moe/communication/nccl_ep.py @@ -0,0 +1,370 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""NCCL EP (Expert Parallelism) Communication Strategy for MoE -- LL rank-major. + +Targets the ``nccl.ep`` Python package shipped in the nccl4py wheel (built +against an NCCL master tree containing ``contrib/nccl_ep``). The dispatch +returns rank-major LL outputs: + + * ``recv_x`` : 3D ``[ep_size, max_tokens_per_rank, hidden]`` bf16, + reshaped to 2D for the downstream MoE pipeline. + * ``recv_topk_idx`` : 2D ``[..., top_k]`` int32 with real expert IDs (-1 for invalid rows) + * ``recv_topk_weights`` : 2D ``[..., top_k]`` float32 (the original router weights) + +This matches NVLinkOneSided's contract directly, so NO +``_modify_output_to_adapt_fused_moe`` adapter is needed. The MoE backend's +``fused_moe`` runs top_k experts per row, applies the weights, and produces one +reduced output per row. ``handle.combine`` then sums per-source-rank +contributions back to the home rank. + +Persistent handle: ``Group.create_handle`` is called ONCE (first dispatch); +subsequent dispatches call ``handle.update(topk_idx, ...)`` to rebind routing. +CUDA-graph capture is supported once the handle exists. +""" + +from typing import List, Optional, Tuple + +import torch + +from tensorrt_llm.logger import logger +from tensorrt_llm.mapping import Mapping + +from .base import Communication + +_NCCL_RUNTIME_ERRORS = (RuntimeError, OSError) + + +class NcclEP(Communication): + """NCCL EP Low-Latency rank-major communication strategy for MoE expert parallelism.""" + + def __init__( + self, + mapping: Mapping, + num_slots: int, + hidden_size: int, + max_num_tokens: int = 1024, + moe_max_num_tokens: Optional[int] = None, + top_k: int = 8, + use_fp8: bool = False, + ): + super().__init__(mapping) + + from tensorrt_llm._torch.modules.fused_moe.nccl_ep_utils import is_nccl_ep_installed + + if not is_nccl_ep_installed(): + raise RuntimeError("nccl-ep is not installed.") + + self.num_slots = num_slots + self.num_experts = num_slots + self.hidden_size = hidden_size + self.num_local_experts = num_slots // self.ep_size + self.max_top_k = top_k + self.use_fp8 = use_fp8 + + self.max_tokens_per_rank = ( + max_num_tokens + if moe_max_num_tokens is None + else min(max_num_tokens, moe_max_num_tokens) + ) + self.max_recv_tokens = self.ep_size * self.max_tokens_per_rank + + # Singleton NCCL EP context: owns the EP group, RDMA buffers, and + # persistent OUTPUT Tensor descriptors. Allocate it lazily on first + # dispatch because full-model construction runs under MetaInitMode, + # which redirects torch.empty to the meta device even when a CUDA + # device is passed explicitly. + self._ctx = None + + # Persistent dispatch handle. Created on first dispatch via + # group.create_handle; reused thereafter via handle.update so + # subsequent dispatches are CUDA-graph-safe. + self._handle = None # nccl.ep.Handle | None + self._dispatch_state: dict = {} + + @staticmethod + def is_platform_supported() -> bool: + from tensorrt_llm._torch.modules.fused_moe.nccl_ep_utils import is_nccl_ep_installed + + return is_nccl_ep_installed() + + def is_workload_feasible(self, all_rank_num_tokens: List[int], num_chunks: int) -> bool: + if num_chunks > 1: + return False + if max(all_rank_num_tokens) > self.max_tokens_per_rank: + return False + return True + + def supports_post_quant_dispatch(self) -> bool: + # FP8 path: NCCL EP internally quantizes bf16 -> fp8 during dispatch. + return self.use_fp8 + + def _get_context(self): + if self._ctx is None: + if torch.cuda.is_current_stream_capturing(): + raise RuntimeError( + "NcclEP context must be initialized before CUDA graph capture. " + "Run an eager warmup forward before enabling or capturing CUDA graphs." + ) + from nccl.ep import Layout + + from tensorrt_llm._torch.modules.fused_moe.nccl_ep_utils import get_nccl_ep_context + + self._ctx = get_nccl_ep_context( + self.mapping, + self.num_experts, + self.max_tokens_per_rank, + self.hidden_size, + self.max_top_k, + self.use_fp8, + Layout.RANK_MAJOR, + ) + return self._ctx + + def _setup_handle(self, ctx, topk_nd, stream): + """Ensure self._handle exists; rebind topk via handle.update on subsequent calls.""" + if self._handle is None: + if torch.cuda.is_current_stream_capturing(): + raise RuntimeError( + "NcclEP dispatch handle must be initialized before CUDA graph capture. " + "Run an eager warmup forward before enabling or capturing CUDA graphs." + ) + self._handle = ctx.ep_group.create_handle( + ctx.layout, + topk_nd, + stream=stream, + ) + else: + self._handle.update(topk_nd, stream=stream) + return self._handle + + # ------------------------------------------------------------------ + # Dispatch -- rank-major LL + # ------------------------------------------------------------------ + + def dispatch( + self, + hidden_states: torch.Tensor, + hidden_states_sf: Optional[torch.Tensor], + token_selected_slots: torch.Tensor, + token_final_scales: Optional[torch.Tensor], + all_rank_num_tokens: List[int], + use_dp_padding: Optional[bool] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, Optional[torch.Tensor]]: + """Dispatch tokens via NCCL EP LL rank-major. + + Returns rank-major-shaped tensors directly: + (recv_hs [N, H], recv_sf [N, H/128] or None, recv_slots [N, top_k] int32, + recv_scales [N, top_k] float32) + + where N = ep_size * max_tokens_per_rank. Rows beyond + ``recv_rank_counter[r]`` for source rank r have recv_slots = -1 + (sentinel), naturally skipped by the MoE backend. + """ + from nccl.ep import DispatchConfig, DispatchInputs, DispatchOutputs, LayoutInfo, Tensor + + ctx = self._get_context() + + all_rank_max_num_tokens = max(all_rank_num_tokens) + if all_rank_max_num_tokens > self.max_tokens_per_rank: + raise ValueError( + f"all_rank_max_num_tokens={all_rank_max_num_tokens} > " + f"max_tokens_per_rank={self.max_tokens_per_rank}" + ) + + num_tokens = hidden_states.shape[0] + top_k = token_selected_slots.shape[1] + if top_k > self.max_top_k: + raise ValueError(f"top_k={top_k} exceeds configured max_top_k={self.max_top_k}") + if token_final_scales is None: + raise RuntimeError( + "NcclEP rank-major dispatch requires token_final_scales " + "(router weights) -- it is an INPUT to handle.dispatch." + ) + + stream = ctx.get_stream() + + # TODO(NCCL): topk_weights still requires float32; once bf16/native + # weights are accepted upstream, drop this conversion too. + weights_f32 = ( + token_final_scales + if token_final_scales.dtype == torch.float32 + else token_final_scales.to(torch.float32) + ) + hidden_states_c = hidden_states.contiguous() + weights_f32_c = weights_f32.contiguous() + + input_tokens_nd = Tensor(hidden_states_c) + input_topk_weights_nd = Tensor(weights_f32_c) + + # Mark padding rows with the -1 sentinel so fused_moe skips them. + # The dispatch kernel only writes recv_topk_idx for slots that + # received tokens; rows beyond `recv_rank_counter[r]` keep stale + # data from prior dispatches. recv_rank_counter is written fresh + # by the dispatch kernel (low_latency.cu:877) so it does not need + # pre-zeroing, and recv_topk_weights on -1 rows is don't-care + # (fused_moe ignores the weight when the expert id is -1). + ctx.recv_topk_idx_buf.fill_(-1) + + outputs = DispatchOutputs( + tokens=ctx.output_tokens_nd, + topk_weights=ctx.recv_topk_weights_nd, + topk_idx=ctx.recv_topk_idx_nd, + scales=ctx.scales_nd if self.use_fp8 else None, + ) + layout_info = LayoutInfo(src_rank_counters=ctx.recv_rank_counter_nd) + # If the linked nccl-ep supports it, ask the kernel to emit GLOBAL + # expert ids directly. The high-level LayoutInfo dataclass does not + # surface the field on older wheels, so set it on the underlying + # cybind struct via _lowpp -- on builds without the field, the + # ctx-side probe sets _expert_id_kind_global to None and we leave + # the descriptor untouched (defaults to AUTO == LOCAL). + if ctx._expert_id_kind_global is not None: + layout_info._lowpp.recv_topk_idx_kind = ctx._expert_id_kind_global + + topk_idx_dev = token_selected_slots.to(ctx.topk_idx_dtype).contiguous() + topk_nd = Tensor(topk_idx_dev) + handle = self._setup_handle(ctx, topk_nd, stream) + inputs = DispatchInputs( + tokens=input_tokens_nd, + topk_weights=input_topk_weights_nd, + ) + handle.dispatch( + inputs, + outputs, + layout_info=layout_info, + config=DispatchConfig(round_scales=0), + stream=stream, + ) + + # The handle internally references topk_nd; keep both the Tensor + # descriptor and its backing torch tensor alive until combine completes. + self._dispatch_state = { + "num_tokens": num_tokens, + "topk_nd": topk_nd, + "topk_idx_dev": topk_idx_dev, + } + + # Match NVLinkOneSided's contract: token_selected_slots in + # [0, num_experts) for valid rows, -1 for invalid. When the kernel + # writes GLOBAL ids directly (opt-in detected at ctx init), the + # buffer is already in the right space and we pass it through. + # Otherwise the kernel writes LOCAL ids in [0, num_local_experts) + # and we add ep_rank * num_local_experts to restore the global + # numbering downstream consumers expect. + # The dispatch buffer is 3D [ep_size, max_tokens_per_rank, max_top_k] + # per the LL rank-major contract; flatten to 2D for downstream. + recv_topk_idx_flat = ctx.recv_topk_idx_buf.view(self.max_recv_tokens, self.max_top_k) + if ctx.kernel_writes_global_ids: + recv_slots_global = recv_topk_idx_flat + else: + recv_slots_global = torch.where( + recv_topk_idx_flat >= 0, + recv_topk_idx_flat + self.ep_rank * self.num_local_experts, + recv_topk_idx_flat, + ) + + # Output buffers are 3D [ep_size, max_tokens_per_rank, ...] per the + # LL rank-major contract; downstream MoE pipeline expects 2D -- + # flatten via view. + return ( + ctx.output_tokens_buf.view(self.max_recv_tokens, self.hidden_size), + ctx.scales_buf if self.use_fp8 else None, + recv_slots_global, + ctx.recv_topk_weights_buf.view(self.max_recv_tokens, self.max_top_k), + ) + + # ------------------------------------------------------------------ + # Combine -- rank-major LL + # ------------------------------------------------------------------ + + def combine( + self, + final_hidden_states: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + """Combine MoE-reduced rank-major output back to the home rank. + + Input: [max_recv_tokens, hidden] -- already weighted per-row by fused_moe. + Output: [num_tokens, hidden] -- combined to original token order. + """ + from nccl.ep import CombineInputs, CombineOutputs, Tensor + + ctx = self._ctx + if ctx is None: + raise RuntimeError("NcclEP.combine called before dispatch.") + state = self._dispatch_state + stream = ctx.get_stream() + + num_tokens = state["num_tokens"] + + # Combine input for LL rank-major must be 3D + # [ep_size, max_tokens_per_rank, hidden] -- reshape if caller passed + # 2D [max_recv, H] or a per-expert [E, max_recv, H] layout. + if final_hidden_states.dim() == 3 and final_hidden_states.shape[0] != self.ep_size: + final_hidden_states = final_hidden_states.reshape(-1, self.hidden_size) + if final_hidden_states.dim() == 2: + if final_hidden_states.shape[0] != self.max_recv_tokens: + raise ValueError( + f"combine input rows={final_hidden_states.shape[0]} " + f"expected={self.max_recv_tokens}" + ) + final_hidden_states = final_hidden_states.view( + self.ep_size, + self.max_tokens_per_rank, + self.hidden_size, + ) + + combine_input_c = final_hidden_states.contiguous() + combine_output = torch.empty( + num_tokens, + self.hidden_size, + dtype=torch.bfloat16, + device=combine_input_c.device, + ) + + combine_input_nd = Tensor(combine_input_c) + combine_output_nd = Tensor(combine_output) + + # Rank-major combine: no layout_info, no config required (send_only=0 + # is the default; defaults round-trip fine). + self._handle.combine( + CombineInputs(tokens=combine_input_nd), + CombineOutputs(tokens=combine_output_nd), + stream=stream, + ) + + self._dispatch_state = {} + return combine_output + + def destroy(self): + """Release per-instance NCCL EP resources (handle). + + NcclEpContext is shared across instances and released through a + refcounted cache. + """ + if self._handle is not None: + try: + self._handle.destroy() + except _NCCL_RUNTIME_ERRORS as e: + logger.warning(f"Handle.destroy error during destroy: {e}") + self._handle = None + + from tensorrt_llm._torch.modules.fused_moe.nccl_ep_utils import release_nccl_ep_context + + if self._ctx is not None: + release_nccl_ep_context(self._ctx) + self._ctx = None + self._dispatch_state = {} diff --git a/tensorrt_llm/_torch/modules/fused_moe/interface.py b/tensorrt_llm/_torch/modules/fused_moe/interface.py index 6b477eb882a7..9aafe56ef9cd 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/interface.py +++ b/tensorrt_llm/_torch/modules/fused_moe/interface.py @@ -103,6 +103,8 @@ class AlltoallMethodType(IntEnum): DeepEP = 3 # DeepEP low latency: CUDA Graphs are supported, IBGDA is required DeepEPLowLatency = 4 + # NCCL EP: Low-latency expert parallelism via NCCL EP library + NcclEP = 5 class MoESchedulerKind(Enum): diff --git a/tensorrt_llm/_torch/modules/fused_moe/moe_scheduler.py b/tensorrt_llm/_torch/modules/fused_moe/moe_scheduler.py index 5acfb2d5d66b..4cbabab973c9 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/moe_scheduler.py +++ b/tensorrt_llm/_torch/modules/fused_moe/moe_scheduler.py @@ -51,7 +51,7 @@ from tensorrt_llm._torch.utils import EventType, Fp4QuantizedTensor from tensorrt_llm.tools.layer_wise_benchmarks import get_calibrator -from .communication import DeepEP, DeepEPLowLatency, NVLinkOneSided, NVLinkTwoSided +from .communication import DeepEP, DeepEPLowLatency, NcclEP, NVLinkOneSided, NVLinkTwoSided from .communication.nvlink_two_sided_flashinfer import NVLinkTwoSidedFlashinfer from .fused_moe_cute_dsl import CuteDslFusedMoE from .fused_moe_cutlass import CutlassFusedMoE @@ -379,10 +379,9 @@ def _forward_chunk_impl( "Current workaround for apply_router_weight_on_input does not support fp8 input" ) x = x * token_final_scales.to(x.dtype) - # DeepEP variants need a non-None token_final_scales tensor - # (they don't tolerate None), so feed all-ones; other strategies - # accept None and skip the multiply. - if isinstance(moe.comm, (DeepEP, DeepEPLowLatency)): + # These strategies need non-None token_final_scales, so feed + # all-ones after folding the real weights into x. + if isinstance(moe.comm, (DeepEP, DeepEPLowLatency, NcclEP)): token_final_scales = torch.ones_like(token_final_scales) else: token_final_scales = None diff --git a/tensorrt_llm/_torch/modules/fused_moe/nccl_ep_utils.py b/tensorrt_llm/_torch/modules/fused_moe/nccl_ep_utils.py new file mode 100644 index 000000000000..f2e3143de822 --- /dev/null +++ b/tensorrt_llm/_torch/modules/fused_moe/nccl_ep_utils.py @@ -0,0 +1,439 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""NCCL EP utilities backed by the nccl4py wheel's ``nccl.ep`` package. + +Owns the long-lived NCCL EP resources (communicator, group, persistent receive +NDTensors) for the MoE NcclEP communication strategy. Per-step dispatch handles +are created in ``communication/nccl_ep.py``. ``use_fp8`` gates allocation of +the persistent FP8 scales receive buffer. +""" + +from typing import Optional + +import torch + +from tensorrt_llm._utils import TensorWrapper, convert_to_torch_tensor +from tensorrt_llm.logger import logger +from tensorrt_llm.mapping import Mapping + +_MIN_NCCL_RUNTIME_VERSION = "2.30.4" +_MIN_NCCL_EP_INT32_TOPK_VERSION = "0.2" +_NCCL_RUNTIME_ERRORS = (RuntimeError, OSError) +_NCCL_AVAILABILITY_ERRORS = (ImportError,) + _NCCL_RUNTIME_ERRORS + +_nccl_ep_installed: Optional[bool] = None + + +def is_nccl_ep_installed() -> bool: + """Return True iff ``nccl.ep`` is usable. + + Requires that ``nccl.ep`` imports cleanly AND the loaded ``libnccl.so`` + runtime version is >= 2.30.4. + """ + global _nccl_ep_installed + if _nccl_ep_installed is not None: + return _nccl_ep_installed + try: + import nccl + from packaging.version import Version + + runtime = nccl.get_version().nccl.version + if runtime < Version(_MIN_NCCL_RUNTIME_VERSION): + logger.info( + f"NCCL EP disabled: libnccl runtime {runtime} " + f"< required {_MIN_NCCL_RUNTIME_VERSION}" + ) + _nccl_ep_installed = False + return False + import nccl.ep # noqa: F401 + + _nccl_ep_installed = True + except _NCCL_AVAILABILITY_ERRORS as e: + logger.info(f"NCCL EP disabled: nccl.ep is not usable ({e!r})") + _nccl_ep_installed = False + return _nccl_ep_installed + + +def _nccl_ep_supports_int32_topk_idx() -> bool: + """Return True when the loaded libnccl_ep supports int32 input topk_idx.""" + try: + import nccl + from packaging.version import Version + + nccl_ep_info = nccl.get_version().nccl_ep + nccl_ep_version = nccl_ep_info.version if nccl_ep_info is not None else None + except (ImportError, AttributeError, RuntimeError, OSError) as e: + logger.info( + f"NCCL EP int32 topk_idx disabled: could not determine libnccl_ep version ({e!r})" + ) + return False + + if nccl_ep_version is None: + logger.info("NCCL EP int32 topk_idx disabled: libnccl_ep version is not available") + return False + + if nccl_ep_version < Version(_MIN_NCCL_EP_INT32_TOPK_VERSION): + logger.info( + f"NCCL EP int32 topk_idx disabled: libnccl_ep {nccl_ep_version} " + f"< required {_MIN_NCCL_EP_INT32_TOPK_VERSION}" + ) + return False + + return True + + +# Singleton EP context keyed by (ep_size, ep_rank, max_tokens, num_experts, +# hidden, max_top_k, use_fp8, layout). +_ep_group_cache: dict = {} +_ep_group_refcounts: dict = {} + + +class NcclEpContext: + """Long-lived NCCL EP group + receive buffers, shared across NcclEP instances. + + Owns the :class:`nccl.ep.Group`, the source :class:`nccl.core.Communicator`, + and the rank-major LL persistent receive buffers (tokens, top-k idx / weights, + per-source-rank counter, optional FP8 scales) wrapped as + :class:`nccl.ep.Tensor` descriptors. + + Per-step routing handles (``Handle``) are created in ``NcclEP``, not here. + """ + + def __init__( + self, + mapping: Mapping, + num_experts: int, + max_tokens_per_rank: int, + hidden_size: int, + max_top_k: int, + use_fp8: bool = False, + layout: Optional[int] = None, + ): + import nccl.core as nccl_core + from nccl.ep import Algorithm, Group, GroupConfig, Layout, Tensor + + from tensorrt_llm._utils import mpi_comm + + self.mapping = mapping + self.ep_size = mapping.moe_ep_size + self.ep_rank = mapping.moe_ep_rank + self.num_experts = num_experts + self.num_local_experts = num_experts // self.ep_size + self.max_tokens_per_rank = max_tokens_per_rank + self.max_top_k = max_top_k + self.hidden_size = hidden_size + self.use_fp8 = use_fp8 + self.layout = Layout.RANK_MAJOR if layout is None else Layout(layout) + self.max_recv_tokens = self.ep_size * max_tokens_per_rank + + # topk_idx dtype passed to the EP runtime. NCCL-EP < 0.2 asserts + # int64 in ncclEpUpdateHandle; 0.2+ supports TRT-LLM's native int32 + # routing ids and avoids the per-iter widening conversion. + self.topk_idx_dtype = torch.int32 if _nccl_ep_supports_int32_topk_idx() else torch.int64 + + # Auto-detect whether the linked libnccl_ep.so supports a + # configurable recv_topk_idx kind on LayoutInfo. When the field + # is present we set it to GLOBAL and skip the post-dispatch + # local->global rewrite; otherwise the kernel writes LOCAL ids + # unconditionally (older nccl-ep builds) and the dispatch + # wrapper applies torch.where to restore the global contract + # NVLinkOneSided also advertises. + try: + from nccl.bindings.nccl_ep import ExpertIdKind as _ExpertIdKind + from nccl.bindings.nccl_ep import LayoutInfo as _LowLayoutInfo + + self.kernel_writes_global_ids = hasattr(_LowLayoutInfo(), "recv_topk_idx_kind") + self._expert_id_kind_global = ( + int(_ExpertIdKind.GLOBAL) if self.kernel_writes_global_ids else None + ) + except (ImportError, AttributeError): + self.kernel_writes_global_ids = False + self._expert_id_kind_global = None + + # Capability probe for the opportunistic zero-copy dispatch path. + # When the Pythonic GroupConfig facade exposes `zero_copy` (i.e., + # the wheel was built against a libnccl_ep.so that has the field + # in ncclEpGroupConfig_t), we allocate a VMM-backed, + # window-registered dispatch output buffer; the LL dispatch + # opportunistically picks zero-copy when recv_x->win_hdl is set + # (nvlink-only + rank-major + !fp8). The config flag itself stays + # AUTO/OFF -- strict zero_copy=ON requires combine inputs to be + # windowed too, which would force a caller-side interface change + # (the MLP output is caller-owned). The C-side strict-ON check + # remains in the library for future use. + self.zerocopy_enabled = "zero_copy" in getattr(GroupConfig, "__dataclass_fields__", {}) + + # MPI sub-communicator scoped to the EP group. Mirrors the + # DeepEPLowLatency pattern (see deep_ep_utils.py:104): split + # MPI_COMM_WORLD by pp_rank so each pipeline stage gets its own EP + # comm, keyed by moe_ep_rank. Avoids the wheel's + # nccl.ep.get_nccl_comm_from_group() helper which requires + # torch.distributed.init_process_group() -- the test infrastructure + # (mpi_pool_executor) and microbenchmarks use MPI4PY only. + self._ep_mpi_comm = mpi_comm().Split(mapping.pp_rank, mapping.moe_ep_rank) + ep_world_rank = self._ep_mpi_comm.Get_rank() + ep_world_size = self._ep_mpi_comm.Get_size() + unique_id = nccl_core.get_unique_id() if ep_world_rank == 0 else None + unique_id = self._ep_mpi_comm.bcast(unique_id, root=0) + self.comm = nccl_core.Communicator.init( + nranks=ep_world_size, + rank=ep_world_rank, + unique_id=unique_id, + ) + + cfg = GroupConfig( + algorithm=Algorithm.LOW_LATENCY, + num_experts=num_experts, + max_dispatch_tokens_per_rank=max_tokens_per_rank, + max_recv_tokens_per_rank=self.max_recv_tokens, + max_token_bytes=hidden_size * 2, # bfloat16 + ) + self.ep_group = Group.create(self.comm, cfg) + + logger.info( + f"NCCL EP group created: ep_size={self.ep_size}, " + f"num_experts={num_experts}, max_tokens_per_rank={max_tokens_per_rank}, " + f"hidden_size={hidden_size}, max_top_k={max_top_k}, " + f"layout={self.layout.name}" + ) + + device_id = torch.cuda.current_device() + device = torch.device("cuda", device_id) + + # Dispatch output tokens: 3D [ep_size, max_tokens_per_rank, hidden] + # for LL rank-major. When zerocopy is enabled the buffer must be + # VMM-backed (cuMemMap) so ncclCommWindowRegister's internal + # cuMemGetAddressRange call succeeds -- torch's caching + # allocator returns plain cudaMalloc memory which fails that + # check with CUDA_ERROR_INVALID_VALUE. Allocate via + # nccl.core.mem_alloc (VMM-backed) then build a zero-copy torch + # view over the raw pointer via the TRT-LLM CAI wrapper. + token_shape = (self.ep_size, max_tokens_per_rank, hidden_size) + token_nbytes = self.ep_size * max_tokens_per_rank * hidden_size * 2 + self._output_tokens_alloc = None + self._recv_x_window = None + if self.zerocopy_enabled: + self._output_tokens_alloc = nccl_core.mem_alloc( + token_nbytes, + device=device_id, + ) + self.output_tokens_buf = convert_to_torch_tensor( + TensorWrapper( + int(self._output_tokens_alloc.handle), + dtype=torch.bfloat16, + shape=token_shape, + ) + ) + self._recv_x_window = self.comm.register_window( + self._output_tokens_alloc, + ) + else: + self.output_tokens_buf = torch.empty( + *token_shape, + dtype=torch.bfloat16, + device=device, + ) + # Received topk indices: int32 [ep_size, max_tokens_per_rank, max_top_k] + # for the LL rank-major dispatch contract. -1 marks invalid rows. + # Downstream consumers want 2D [max_recv, max_top_k]; flatten via view. + self.recv_topk_idx_buf = torch.empty( + self.ep_size, + max_tokens_per_rank, + max_top_k, + dtype=torch.int32, + device=device, + ) + # Received topk weights: float32 [ep_size, max_tokens_per_rank, max_top_k] + self.recv_topk_weights_buf = torch.empty( + self.ep_size, + max_tokens_per_rank, + max_top_k, + dtype=torch.float32, + device=device, + ) + # Per-source-rank received-token counter (passed via + # LayoutInfo.src_rank_counters at dispatch time). + self.recv_rank_counter_buf = torch.empty( + self.ep_size, + dtype=torch.int32, + device=device, + ) + # Optional FP8 scales. Even on rank-major LL the kernel writes a 3D + # `[num_local_experts, max_recv, hidden/128]` tensor for scales -- + # callers that want 2D should view-flatten the first dim. + self.scales_buf: Optional[torch.Tensor] = None + if use_fp8: + if hidden_size % 512 != 0: + raise ValueError(f"FP8 dispatch requires hidden % 512 == 0, got {hidden_size}") + self.scales_buf = torch.empty( + self.num_local_experts, + self.max_recv_tokens, + hidden_size // 128, + dtype=torch.float32, + device=device, + ) + + # Wrap each persistent buffer as a Tensor descriptor. Torch owns the + # storage; the descriptor only carries shape + a pointer (+ window + # handle on dispatch output when zerocopy is on, so libnccl_ep's + # opportunistic LL zero-copy path can fire). + if self.zerocopy_enabled and self._recv_x_window is not None: + self.output_tokens_nd = Tensor( + self.output_tokens_buf, + window=self._recv_x_window, + window_offset=0, + ) + else: + self.output_tokens_nd = Tensor(self.output_tokens_buf) + self.recv_topk_idx_nd = Tensor(self.recv_topk_idx_buf) + self.recv_topk_weights_nd = Tensor(self.recv_topk_weights_buf) + self.recv_rank_counter_nd = Tensor(self.recv_rank_counter_buf) + self.scales_nd: Optional[Tensor] = ( + Tensor(self.scales_buf) if self.scales_buf is not None else None + ) + + def get_stream(self) -> int: + """Current CUDA stream as a raw int handle (accepted by ``nccl.ep`` APIs).""" + return torch.cuda.current_stream().cuda_stream + + def destroy(self): + """Release EP group, NCCL comm, and MPI sub-comm in LIFO order. + + Avoids relying on Python GC ordering between the group, the comm it + was built from, and the MPI sub-comm seeding it: the group must go + first (uses the comm), then ``finalize`` + ``destroy`` on the comm + (the recommended nccl4py pattern), then ``Free`` on the MPI comm. + """ + if self.ep_group is not None: + try: + self.ep_group.destroy() + except _NCCL_RUNTIME_ERRORS as e: + logger.warning(f"NCCL EP group destroy error: {e}") + self.ep_group = None + + # Deregister windows before the comm goes away. close() is + # idempotent and local; the comm would auto-close any leftover + # windows on destroy, but explicit LIFO release matches the rest + # of this teardown path. Both dispatch-output and combine-input + # windows are registered only when zerocopy is on. + for attr in ("_combine_input_window", "_recv_x_window"): + w = getattr(self, attr, None) + if w is not None: + try: + w.close() + except _NCCL_RUNTIME_ERRORS as e: + logger.warning(f"NCCL EP window close error ({attr}): {e}") + setattr(self, attr, None) + + # Drop torch view + EP descriptor before freeing the underlying + # NCCL-allocated Buffer (CAI view doesn't refcount the source). + # close() the cuda.core.Buffer to call nccl.core.mem_free; the + # alloc is only populated when zerocopy is enabled. + self.output_tokens_nd = None + self.output_tokens_buf = None + if getattr(self, "_output_tokens_alloc", None) is not None: + try: + self._output_tokens_alloc.close() + except _NCCL_RUNTIME_ERRORS as e: + logger.warning(f"NCCL EP recv_x buffer free error: {e}") + self._output_tokens_alloc = None + + if self.comm is not None: + try: + self.comm.finalize() + self.comm.destroy() + except _NCCL_RUNTIME_ERRORS as e: + logger.warning(f"NCCL EP comm destroy error: {e}") + self.comm = None + + if self._ep_mpi_comm is not None: + from mpi4py import MPI + + try: + self._ep_mpi_comm.Free() + except MPI.Exception as e: + logger.warning(f"EP MPI sub-comm free error: {e}") + self._ep_mpi_comm = None + + +def get_nccl_ep_context( + mapping: Mapping, + num_experts: int, + max_tokens_per_rank: int, + hidden_size: int, + max_top_k: int, + use_fp8: bool = False, + layout: Optional[int] = None, +) -> NcclEpContext: + """Get or create a singleton :class:`NcclEpContext` for the given configuration.""" + from nccl.ep import Layout + + if layout is None: + layout = Layout.RANK_MAJOR + key = ( + mapping.moe_ep_size, + mapping.moe_ep_rank, + max_tokens_per_rank, + num_experts, + hidden_size, + max_top_k, + use_fp8, + int(layout), + ) + if key not in _ep_group_cache: + _ep_group_cache[key] = NcclEpContext( + mapping, + num_experts, + max_tokens_per_rank, + hidden_size, + max_top_k, + use_fp8, + layout, + ) + _ep_group_refcounts[key] = _ep_group_refcounts.get(key, 0) + 1 + return _ep_group_cache[key] + + +def release_nccl_ep_context(ctx: Optional[NcclEpContext]) -> None: + """Release one reference to a cached :class:`NcclEpContext`.""" + if ctx is None: + return + + key = next((key for key, cached_ctx in _ep_group_cache.items() if cached_ctx is ctx), None) + if key is None: + return + + refcount = _ep_group_refcounts.get(key, 0) - 1 + if refcount > 0: + _ep_group_refcounts[key] = refcount + return + + _ep_group_refcounts.pop(key, None) + cached_ctx = _ep_group_cache.pop(key) + try: + cached_ctx.destroy() + except _NCCL_RUNTIME_ERRORS as e: + logger.warning(f"Error destroying NCCL EP context: {e}") + + +def destroy_all_nccl_ep_contexts(): + """Destroy all cached NCCL EP contexts (call at process teardown).""" + for ctx in list(_ep_group_cache.values()): + try: + ctx.destroy() + except _NCCL_RUNTIME_ERRORS as e: + logger.warning(f"Error destroying NCCL EP context: {e}") + _ep_group_cache.clear() + _ep_group_refcounts.clear() diff --git a/tests/microbenchmarks/bench_moe_comm.py b/tests/microbenchmarks/bench_moe_comm.py index e9d8ba839933..d8a1ae708a9f 100644 --- a/tests/microbenchmarks/bench_moe_comm.py +++ b/tests/microbenchmarks/bench_moe_comm.py @@ -783,10 +783,12 @@ def _record_external(event: torch.cuda.Event) -> None: cupti_dispatch = detailed_stats.pop("dispatch_times_us") cupti_combine = detailed_stats.pop("combine_times_us") dispatch_times_us = [ - ct if ct is not None else et for ct, et in zip(cupti_dispatch, dispatch_times_us) + ct if ct is not None else et + for ct, et in zip(cupti_dispatch, dispatch_times_us, strict=True) ] combine_times_us = [ - ct if ct is not None else et for ct, et in zip(cupti_combine, combine_times_us) + ct if ct is not None else et + for ct, et in zip(cupti_combine, combine_times_us, strict=True) ] else: detailed_stats = {"dispatch_kernels": [], "combine_kernels": [], "other_kernels": []} @@ -825,6 +827,99 @@ def _gather_per_rank(times_us: List[float], iter_stats: bool = False) -> Dict[st return {f"rank{i}": (sum(t) / len(t) if t else 0.0) for i, t in enumerate(all_times)} +def _min_local_tokens_for_receiver_coverage(ep_size: int, top_k: int) -> int: + if top_k <= 0: + raise ValueError(f"top_k must be > 0, got {top_k}") + return (ep_size + top_k - 1) // top_k + + +def _scale_local_batch_sizes_for_receiver_coverage( + local_batch_sizes: List[int], ep_size: int, top_k: int +) -> List[int]: + min_tokens = _min_local_tokens_for_receiver_coverage(ep_size, top_k) + scaled: List[int] = [] + for local_num_tokens in local_batch_sizes: + value = max(int(local_num_tokens), min_tokens) + if not scaled or scaled[-1] != value: + scaled.append(value) + return scaled + + +def _verify_dispatch_sentinel( + backend: Communication, + *, + hidden_size: int, + top_k: int, + experts_per_rank: int, + ep_size: int, + act_dtype: torch.dtype, + device: torch.device, + local_num_tokens: Optional[int] = None, +) -> Dict[str, Any]: + """One dispatch+combine with sender-rank-tagged hidden_states. + + Each rank fills its hidden_states with the scalar ``rank + 1``. After + dispatch, each received row should be that integer cast to ``act_dtype``; + rows reading as 0 are either padding or a silently-broken peer read + (e.g. cross-rack MNNVL mapping that succeeded at construction but doesn't + actually back the peer's memory). Returns the per-rank decoded-sender + histogram for the caller to allgather and inspect. + """ + rank = mpi_rank() + min_tokens = _min_local_tokens_for_receiver_coverage(ep_size, top_k) + local_num_tokens = min_tokens if local_num_tokens is None else max(local_num_tokens, min_tokens) + all_rank_num_tokens = mpi_allgather(int(local_num_tokens)) + if not backend.is_workload_feasible(all_rank_num_tokens, num_chunks=1): + return {"rank": rank, "skipped": True} + + sentinel = float(rank + 1) + hidden_states = torch.full( + (local_num_tokens, hidden_size), + sentinel, + dtype=act_dtype, + device=device, + ) + flat_slots = torch.arange(local_num_tokens * top_k, device=device, dtype=torch.int64) + schedule = flat_slots + rank + target_rank = schedule % ep_size + local_expert = (schedule // ep_size) % experts_per_rank + token_selected_slots = ( + (target_rank * experts_per_rank + local_expert) + .view(local_num_tokens, top_k) + .to(torch.int32) + ) + token_final_scales = torch.ones( + local_num_tokens, + top_k, + dtype=torch.float32, + device=device, + ) + + backend.prepare_dispatch(token_selected_slots, all_rank_num_tokens) + recv_hs, _, _, _ = backend.dispatch( + hidden_states, + None, + token_selected_slots, + token_final_scales, + all_rank_num_tokens, + ) + # Pair dispatch with a combine so backend state mirrors the bench's + # warmup->timing call pattern (NCCL_EP especially relies on this). + shape = list(recv_hs.shape) + shape[-1] = hidden_size + moe_out = torch.zeros(tuple(shape), dtype=torch.bfloat16, device=recv_hs.device) + backend.combine(moe_out, all_rank_max_num_tokens=max(all_rank_num_tokens)) + torch.cuda.synchronize() + + first_col = recv_hs[:, 0].to(torch.float32) + decoded = first_col.round().to(torch.int64) + unique, counts = decoded.unique(return_counts=True) + histogram: Dict[int, int] = { + int(u) - 1: int(c) for u, c in zip(unique.tolist(), counts.tolist(), strict=True) + } + return {"rank": rank, "histogram": histogram} + + def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Unified MoE communication microbenchmark (MPI).") parser.add_argument( @@ -843,6 +938,7 @@ def parse_args() -> argparse.Namespace: "NVLINK_TWO_SIDED", "DEEPEP", "DEEPEPLOWLATENCY", + "NCCL_EP", ], help="Which communication backend to benchmark (default: run all backends).", ) @@ -951,6 +1047,16 @@ def parse_args() -> argparse.Namespace: action="store_true", help="Disable CUDA graph mode. By default, dispatch and combine are captured into CUDA graphs for lower CPU overhead and more accurate timing.", ) + parser.add_argument( + "--verify", + action="store_true", + help=( + "Run a single sentinel dispatch per backend before timing and print a " + "receiver/sender contribution matrix. Detects silent cross-rack " + "correctness failures where dispatch appears to succeed but produces " + "zeros or local-only data." + ), + ) parser.add_argument( "--pdl", action="store_true", @@ -1050,6 +1156,10 @@ def _run_benchmark_worker_under_current_mpi( hidden_size, top_k, num_experts_total, quant_algo = _resolve_profile_args(args) local_batch_sizes = _iter_local_batch_sizes(args) + if args.verify: + local_batch_sizes = _scale_local_batch_sizes_for_receiver_coverage( + local_batch_sizes, ep_size, top_k + ) act_dtype = torch.bfloat16 quant_config = ( QuantConfig(quant_algo=None) @@ -1098,7 +1208,14 @@ def _run_benchmark_worker_under_current_mpi( print(json.dumps(benchmark_metadata, indent=2), flush=True) backends = ( - ["ALLGATHER", "NVLINK_ONE_SIDED", "NVLINK_TWO_SIDED", "DEEPEP", "DEEPEPLOWLATENCY"] + [ + "ALLGATHER", + "NVLINK_ONE_SIDED", + "NVLINK_TWO_SIDED", + "DEEPEP", + "DEEPEPLOWLATENCY", + "NCCL_EP", + ] if args.backend is None else [args.backend] ) @@ -1164,6 +1281,58 @@ def _run_benchmark_worker_under_current_mpi( # Ensure quantization params (e.g., NVFP4 global scale) live on CUDA. moe = moe.to(device) + if args.verify: + verify_local = _verify_dispatch_sentinel( + backend, + hidden_size=hidden_size, + top_k=top_k, + experts_per_rank=experts_per_rank, + ep_size=ep_size, + act_dtype=act_dtype, + device=device, + local_num_tokens=local_batch_sizes[0], + ) + all_verify = mpi_allgather(verify_local) + # Pass criterion: every receiver must have at least one token from + # every sender [0, ep_size). The verify local_num_tokens is scaled + # so local_num_tokens * top_k covers every receiver; + # any zero-column means the recv buffer was silently dropped from + # that sender. + verify_failed = False + for entry in all_verify: + if entry.get("skipped"): + verify_failed = True + break + hist = entry.get("histogram", {}) + if any(hist.get(s, 0) == 0 for s in range(ep_size)): + verify_failed = True + break + if rank == 0: + status = "FAIL" if verify_failed else "PASS" + print( + f"=== [verify] {backend_name} {status} -- sender->receiver " + f"contribution (rows=receiver, cols=sender; -1 col = " + f"padding/unmapped) ===", + flush=True, + ) + cols = [-1, *range(ep_size)] + header = "R\\S | " + " ".join(f"{c:>5}" for c in cols) + " | total" + print(header) + for entry in sorted(all_verify, key=lambda e: e.get("rank", -1)): + r = entry.get("rank") + if entry.get("skipped"): + print(f"{r:>3} | skipped (workload not feasible at verify size)") + continue + hist = entry.get("histogram", {}) + cells = " ".join(f"{hist.get(c, 0):>5}" for c in cols) + print(f"{r:>3} | {cells} | {sum(hist.values()):>5}") + sys.stdout.flush() + if verify_failed: + _maybe_warn_rank0( + f"[bench_moe_comm] Skipping timing for {backend_name}: verify FAILED." + ) + continue + for local_num_tokens in local_batch_sizes: all_rank_num_tokens = mpi_allgather(int(local_num_tokens)) if not backend.is_workload_feasible(all_rank_num_tokens, num_chunks=1): diff --git a/tests/unittest/_torch/modules/moe/test_communication_factory.py b/tests/unittest/_torch/modules/moe/test_communication_factory.py new file mode 100644 index 000000000000..d3d0ef5517bf --- /dev/null +++ b/tests/unittest/_torch/modules/moe/test_communication_factory.py @@ -0,0 +1,254 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import sys +from types import SimpleNamespace + +import pytest +import torch + +from tensorrt_llm._torch.modules.fused_moe import nccl_ep_utils +from tensorrt_llm._torch.modules.fused_moe.communication import communication_factory +from tensorrt_llm._torch.modules.fused_moe.communication.allgather_reducescatter import ( + AllGatherReduceScatter, +) +from tensorrt_llm._torch.modules.fused_moe.communication.nccl_ep import NcclEP + + +def _make_model_config( + act_dtype: torch.dtype = torch.bfloat16, + moe_max_num_tokens: int | None = 1024, +): + mapping = SimpleNamespace( + enable_attention_dp=True, + dp_size=2, + moe_tp_size=1, + moe_ep_size=2, + moe_ep_rank=0, + ) + return SimpleNamespace( + mapping=mapping, + pretrained_config=SimpleNamespace(hidden_size=4096), + torch_dtype=act_dtype, + quant_config=None, + max_num_tokens=1024, + moe_max_num_tokens=moe_max_num_tokens, + use_cuda_graph=False, + use_low_precision_moe_combine=False, + moe_load_balancer=None, + ) + + +def _strategy_unavailable(*args, **kwargs): + raise RuntimeError("strategy unavailable") + + +def _install_failing_nccl_module(monkeypatch: pytest.MonkeyPatch, error: BaseException): + def fail_get_version(): + raise error + + monkeypatch.setattr(nccl_ep_utils, "_nccl_ep_installed", None) + monkeypatch.setitem(sys.modules, "nccl", SimpleNamespace(get_version=fail_get_version)) + monkeypatch.delitem(sys.modules, "nccl.ep", raising=False) + + +def test_nccl_ep_installed_handles_runtime_probe_failure(monkeypatch: pytest.MonkeyPatch): + _install_failing_nccl_module(monkeypatch, RuntimeError("missing libnccl_ep")) + + assert nccl_ep_utils.is_nccl_ep_installed() is False + assert nccl_ep_utils._nccl_ep_installed is False + + +class _FakeNcclEP: + def __init__( + self, + mapping, + num_slots, + hidden_size, + max_num_tokens, + moe_max_num_tokens, + top_k=8, + ): + self.mapping = mapping + self.num_slots = num_slots + self.hidden_size = hidden_size + self.max_num_tokens = max_num_tokens + self.moe_max_num_tokens = moe_max_num_tokens + self.top_k = top_k + + +@pytest.mark.parametrize( + ("act_dtype", "moe_max_num_tokens", "match"), + [ + (torch.float16, 1024, "act_dtype=torch.bfloat16"), + ], +) +def test_forced_nccl_ep_validates_preconditions( + act_dtype: torch.dtype, + moe_max_num_tokens: int | None, + match: str, +): + model_config = _make_model_config(act_dtype, moe_max_num_tokens) + + with pytest.raises(ValueError, match=match): + communication_factory.CommunicationFactory._create_forced_method( + "NCCL_EP", + model_config, + num_experts=32, + num_slots=32, + top_k=8, + expert_size_per_partition=16, + payload_in_workspace=False, + alltoall_result_do_sum=True, + use_flashinfer=False, + hidden_size=4096, + ) + + +def test_forced_nccl_ep_allows_missing_moe_max_num_tokens( + monkeypatch: pytest.MonkeyPatch, +): + model_config = _make_model_config(torch.bfloat16, None) + monkeypatch.setattr(communication_factory, "NcclEP", _FakeNcclEP) + + strategy = communication_factory.CommunicationFactory._create_forced_method( + "NCCL_EP", + model_config, + num_experts=32, + num_slots=32, + top_k=8, + expert_size_per_partition=16, + payload_in_workspace=False, + alltoall_result_do_sum=True, + use_flashinfer=False, + hidden_size=4096, + ) + + assert isinstance(strategy, _FakeNcclEP) + assert strategy.max_num_tokens == model_config.max_num_tokens + assert strategy.moe_max_num_tokens is None + + +def test_auto_selection_uses_nccl_ep_with_missing_moe_max_num_tokens( + monkeypatch: pytest.MonkeyPatch, +): + model_config = _make_model_config(torch.bfloat16, None) + + monkeypatch.setattr(communication_factory, "NVLinkOneSided", _strategy_unavailable) + monkeypatch.setattr(communication_factory, "NVLinkTwoSided", _strategy_unavailable) + monkeypatch.setenv("TRTLLM_CAN_USE_DEEP_EP", "0") + monkeypatch.setattr(communication_factory, "NcclEP", _FakeNcclEP) + + strategy = communication_factory.CommunicationFactory.create_strategy( + model_config, + num_experts=32, + num_slots=32, + top_k=8, + expert_size_per_partition=16, + hidden_size=4096, + ) + + assert isinstance(strategy, _FakeNcclEP) + assert strategy.max_num_tokens == model_config.max_num_tokens + assert strategy.moe_max_num_tokens is None + + +@pytest.mark.parametrize( + ("act_dtype", "moe_max_num_tokens"), + [ + (torch.float16, 1024), + ], +) +def test_auto_selection_skips_nccl_ep_when_preconditions_fail( + monkeypatch: pytest.MonkeyPatch, + act_dtype: torch.dtype, + moe_max_num_tokens: int | None, +): + model_config = _make_model_config(act_dtype, moe_max_num_tokens) + + monkeypatch.setattr(communication_factory, "NVLinkOneSided", _strategy_unavailable) + monkeypatch.setattr(communication_factory, "NVLinkTwoSided", _strategy_unavailable) + monkeypatch.setenv("TRTLLM_CAN_USE_DEEP_EP", "0") + + def fail_if_called(*args, **kwargs): + raise AssertionError("NcclEP should not be constructed") + + monkeypatch.setattr(communication_factory, "NcclEP", fail_if_called) + + strategy = communication_factory.CommunicationFactory.create_strategy( + model_config, + num_experts=32, + num_slots=32, + top_k=8, + expert_size_per_partition=16, + hidden_size=4096, + ) + + assert isinstance(strategy, AllGatherReduceScatter) + + +def test_auto_selection_falls_back_when_nccl_probe_runtime_fails( + monkeypatch: pytest.MonkeyPatch, +): + model_config = _make_model_config(torch.bfloat16, None) + monkeypatch.setattr(communication_factory, "NVLinkOneSided", _strategy_unavailable) + monkeypatch.setattr(communication_factory, "NVLinkTwoSided", _strategy_unavailable) + monkeypatch.setenv("TRTLLM_CAN_USE_DEEP_EP", "0") + _install_failing_nccl_module(monkeypatch, OSError("missing native NCCL EP library")) + + strategy = communication_factory.CommunicationFactory.create_strategy( + model_config, + num_experts=32, + num_slots=32, + top_k=8, + expert_size_per_partition=16, + hidden_size=4096, + ) + + assert isinstance(strategy, AllGatherReduceScatter) + + +def test_nccl_ep_context_init_rejects_cuda_graph_capture( + monkeypatch: pytest.MonkeyPatch, +): + strategy = object.__new__(NcclEP) + strategy._ctx = None + monkeypatch.setattr(torch.cuda, "is_current_stream_capturing", lambda: True) + + with pytest.raises(RuntimeError, match="context must be initialized before CUDA graph capture"): + strategy._get_context() + + +def test_nccl_ep_handle_init_rejects_cuda_graph_capture( + monkeypatch: pytest.MonkeyPatch, +): + strategy = object.__new__(NcclEP) + strategy._handle = None + monkeypatch.setattr(torch.cuda, "is_current_stream_capturing", lambda: True) + + def fail_create_handle(*args, **kwargs): + raise AssertionError("create_handle should not run during CUDA graph capture") + + ctx = SimpleNamespace( + ep_group=SimpleNamespace(create_handle=fail_create_handle), + layout=object(), + ) + + with pytest.raises( + RuntimeError, match="dispatch handle must be initialized before CUDA graph capture" + ): + strategy._setup_handle(ctx, object(), 0) diff --git a/tests/unittest/_torch/modules/moe/test_moe_comm.py b/tests/unittest/_torch/modules/moe/test_moe_comm.py index c59751f42025..12484d714e7b 100644 --- a/tests/unittest/_torch/modules/moe/test_moe_comm.py +++ b/tests/unittest/_torch/modules/moe/test_moe_comm.py @@ -67,12 +67,14 @@ ) from tensorrt_llm._torch.modules.fused_moe.communication.deep_ep import DeepEP from tensorrt_llm._torch.modules.fused_moe.communication.deep_ep_low_latency import DeepEPLowLatency +from tensorrt_llm._torch.modules.fused_moe.communication.nccl_ep import NcclEP from tensorrt_llm._torch.modules.fused_moe.communication.nvlink_one_sided import NVLinkOneSided from tensorrt_llm._torch.modules.fused_moe.communication.nvlink_two_sided import NVLinkTwoSided from tensorrt_llm._torch.modules.fused_moe.communication.nvlink_two_sided_flashinfer import ( NVLinkTwoSidedFlashinfer, ) from tensorrt_llm._torch.modules.fused_moe.deep_ep_utils import deep_ep_installed +from tensorrt_llm._torch.modules.fused_moe.nccl_ep_utils import is_nccl_ep_installed from tensorrt_llm.deep_ep.buffer import Buffer from tensorrt_llm.mapping import Mapping @@ -93,6 +95,7 @@ COMM_NVLINK_ONE_SIDED = "NVLinkOneSided" COMM_NVLINK_TWO_SIDED = "NVLinkTwoSided" COMM_NVLINK_TWO_SIDED_FLASHINFER = "NVLinkTwoSidedFlashinfer" +COMM_NCCL_EP = "NcclEP" ALL_COMM_TYPES = [ COMM_ALLGATHER_RS, @@ -101,6 +104,7 @@ COMM_NVLINK_ONE_SIDED, COMM_NVLINK_TWO_SIDED, COMM_NVLINK_TWO_SIDED_FLASHINFER, + COMM_NCCL_EP, ] # Must be in DeepEPLowLatency.SUPPORTED_HIDDEN_SIZES @@ -393,6 +397,16 @@ def create_comm_object( alltoall_result_do_sum=True, ) + elif comm_type == COMM_NCCL_EP: + return NcclEP( + mapping=mapping, + num_slots=num_slots, + hidden_size=config.hidden_size, + max_num_tokens=max_num_tokens, + moe_max_num_tokens=max_num_tokens, + top_k=config.top_k, + ) + else: raise ValueError(f"Unknown comm type: {comm_type}") @@ -442,6 +456,11 @@ def check_platform_support(comm_type: str) -> Optional[str]: if comm_type == COMM_NVLINK_TWO_SIDED_FLASHINFER: return _check_flashinfer_mnnvl_support() + if comm_type == COMM_NCCL_EP: + if not is_nccl_ep_installed(): + return "NCCL EP not available (install the nccl4py wheel)" + return None + return f"Unknown comm type: {comm_type}" @@ -492,6 +511,10 @@ def check_feasibility(comm_type: str, config: CommTestConfig) -> Optional[str]: if config.top_k > NVLinkOneSided.MAX_TOP_K: return f"NVLinkOneSided MAX_TOP_K={NVLinkOneSided.MAX_TOP_K}, got top_k={config.top_k}" + if comm_type == COMM_NCCL_EP: + if config.quant_mode != "none": + return f"NcclEP does not support quant_mode={config.quant_mode}" + if comm_type == COMM_NVLINK_TWO_SIDED_FLASHINFER: # FlashInfer alltoallv requires every 2D payload row to be 16-byte aligned. # This test dispatches both int32 slots [N, top_k] and bf16 scales @@ -1247,14 +1270,27 @@ def _build_combine_reference( ref[token_idx] += nvfp4_out[i].float() elif config.comm_type == COMM_DEEP_EP_LL: - # Path 2: DeepEPLL weighted reduction. - # DeepEPLL's dispatch returns ones as recv_scales, so simple_moe - # produces unweighted output. The combine kernel (low_latency_combine) - # internally applies real topk_weights during weighted reduction. - # Reconstruct by weighting recv_hs_bf16 per local expert with the - # real weights from original_scales on the source rank. + # Path 2: DeepEPLL weighted reduction (expert-major output). + # + # Dispatch output is [num_local_experts * ep_size * max_tokens, hidden]. + # `simple_moe` returns identity (recv_scales are ones), so each row + # passed to combine still holds the original source token's bytes. + # The combine kernel sends per-position slots (one per k in top_k) and + # applies weight[k] to the k-th slot -- duplicate experts are NOT deduped; + # each position contributes independently. + # + # Empirically confirmed on DeepEPLL (4-rank LL, k=2): for every + # token T on target_rank: + # combined[T] = (sum over k in [0, top_k) of w[T, k]) * T_hidden + # regardless of whether all top_k experts are distinct or a subset + # coalesces onto the same expert. + # + # Find the source token's hidden_states by locating any received row + # whose source is (target_rank, T); all such rows carry identical bytes. target_original_scales = all_results[target_rank]["original_scales"] + # Locate a representative received row per (target_rank, token_idx). + token_hs: dict = {} # token_idx -> bf16 row for proc_result in all_results: proc_rank = proc_result["rank"] recv_hs_bf16 = proc_result["recv_hs_bf16"] @@ -1265,12 +1301,17 @@ def _build_combine_reference( ) for i, (src_rank, token_idx) in enumerate(source_info): - if src_rank == target_rank and token_idx < num_tokens: - for k in range(config.top_k): - eid = recv_slots[i, k].item() - if slot_start <= eid < slot_end: - weight = target_original_scales[token_idx, k].float() - ref[token_idx] += recv_hs_bf16[i].float() * weight + if src_rank != target_rank or token_idx >= num_tokens: + continue + if not any(slot_start <= eid < slot_end for eid in recv_slots[i].tolist()): + continue + if token_idx not in token_hs: + token_hs[token_idx] = recv_hs_bf16[i] + + # Sum of topk weights per token applied to that token's hidden_states. + for token_idx, hs in token_hs.items(): + weight_sum = target_original_scales[token_idx].float().sum() + ref[token_idx] += hs.float() * weight_sum else: # Path 3: Default — float32 accumulation.