diff --git a/cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu b/cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu index 91cb5725fede..0f2d453c363a 100644 --- a/cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu +++ b/cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu @@ -210,6 +210,14 @@ __device__ __forceinline__ int compute_target_rank_id(int expert_id, int base, i return remainder + (expert_id - split) / base; } +// Test bit `rank` in a kRankMaskWords-wide little-endian uint64 bitmask. +// Word 0 covers ranks 0..63, word 1 covers ranks 64..127, etc. +// `rank >> 6` and `rank & 63` divide / modulo by 64. +__device__ __forceinline__ bool is_rank_active(uint64_t const* mask, int rank) +{ + return (mask[rank >> 6] >> (rank & 63)) & 1ULL; +} + // ============================================================================ // Helper Functions for Vectorized Memory Operations // ============================================================================ @@ -416,7 +424,7 @@ __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [ int* smem_topk_target_ranks = smem; int* smem_topk_send_indices = smem + TOP_K; - uint64_t already_copied = 0; + uint64_t already_copied[kRankMaskWords] = {}; // Precompute the ceil/floor partition parameters once per thread, outside the // per-token TOP_K loop. The fast path (remainder == 0) then collapses to a single // integer divide per call, matching the pre-PR uniform-partition cost exactly. @@ -432,7 +440,15 @@ __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [ // Supports the non-divisible case where num_experts % ep_size != 0. int target_rank = compute_target_rank_id(expert_id, ep_base, ep_remainder); - if (already_copied & (1ULL << target_rank)) + // Skip duplicates AND dead ranks: both produce the same -1 sentinel that combine + // checks via topk_send_indices[k] < 0. A token whose only target is dead is dropped + // from this collective; higher-layer logic (EPLB redistribution) is responsible + // for re-routing such tokens on subsequent iterations. + int const mask_word = target_rank >> 6; + uint64_t const mask_bit = 1ULL << (target_rank & 63); + bool const target_already_copied = already_copied[mask_word] & mask_bit; + bool const target_dead = !is_rank_active(ptrs.active_rank_mask, target_rank); + if (target_already_copied || target_dead) { if (thread_idx == 0) { @@ -457,7 +473,7 @@ __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [ smem_topk_target_ranks[k] = target_rank; smem_topk_send_indices[k] = dst_token_idx; } - already_copied |= 1ULL << target_rank; + already_copied[mask_word] |= mask_bit; } // Sync before dispatching data ThreadingPolicy::sync(); @@ -511,10 +527,13 @@ __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [ if (is_last_token) { -// Store send_counters to recv_counters +// Store send_counters to recv_counters. +// Skip masked target ranks: their symmetric memory may be inaccessible. #pragma unroll 1 // No unroll as one iter is typically enough for (int target_rank = lane_id; target_rank < ep_size; target_rank += warpSize) { + if (!is_rank_active(ptrs.active_rank_mask, target_rank)) + continue; int send_count = ptrs.send_counters[target_rank]; ptrs.recv_counters[target_rank][rank_id] = send_count; } @@ -522,9 +541,12 @@ __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [ if constexpr (ENABLE_EPLB) { // Write local stats into peer buffers before the release fence below. + // Skip masked target ranks for the same reason as above. #pragma unroll 1 for (int target_rank = 0; target_rank < ep_size; ++target_rank) { + if (!is_rank_active(ptrs.active_rank_mask, target_rank)) + continue; int* target_stats = ptrs.eplb_gathered_stats[target_rank]; for (int expert_id = lane_id; expert_id < eplb_stats_num_experts; expert_id += warpSize) { @@ -543,9 +565,13 @@ __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [ #else asm volatile("fence.acq_rel.sys;"); #endif + // Signal completion to all active peers; skip dead ranks (their symmetric memory + // is unreachable). #pragma unroll 1 // No unroll as one iter is typically enough for (int target_rank = lane_id; target_rank < ep_size; target_rank += warpSize) { + if (!is_rank_active(ptrs.active_rank_mask, target_rank)) + continue; uint32_t* flag_addr = &ptrs.completion_flags[target_rank][rank_id]; asm volatile("st.relaxed.sys.u32 [%0], %1;" ::"l"(flag_addr), "r"(expected_value)); @@ -555,9 +581,13 @@ __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [ #endif } + // Wait for all active peers to signal; skip dead ranks (otherwise we would + // spin forever — this is the bug the rank-mask is here to prevent). #pragma unroll 1 // No unroll for (int peer_rank = lane_id; peer_rank < ep_size; peer_rank += warpSize) { + if (!is_rank_active(ptrs.active_rank_mask, peer_rank)) + continue; bool flag_set = false; auto s = clock64(); do @@ -603,8 +633,13 @@ void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params) // Validate parameters TLLM_CHECK(params.top_k > 0 && params.top_k <= kMaxTopK); TLLM_CHECK(params.ep_size > 0 && params.ep_size <= kMaxRanks); + TLLM_CHECK(params.ep_rank >= 0 && params.ep_rank < params.ep_size); TLLM_CHECK(params.local_num_tokens >= 0); TLLM_CHECK(params.num_payloads > 0 && params.num_payloads <= kMaxPayloads); + // The local rank must always be marked active in its own view of the mask; + // otherwise the kernel itself would be running on a "dead" rank. + TLLM_CHECK_WITH_INFO((params.active_rank_mask[params.ep_rank >> 6] >> (params.ep_rank & 63)) & 1ULL, + "active_rank_mask must mark the local ep_rank (%d) as active", params.ep_rank); // Prepare kernel pointers struct DispatchKernelPointers kernel_ptrs = {}; @@ -642,6 +677,12 @@ void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params) kernel_ptrs.topk_send_indices = params.topk_send_indices; kernel_ptrs.eplb_local_stats = params.eplb_local_stats; + // Copy active-rank bitmask into the kernel pointers struct + for (int w = 0; w < kRankMaskWords; ++w) + { + kernel_ptrs.active_rank_mask[w] = params.active_rank_mask[w]; + } + int const kBlockSize = tensorrt_llm::common::getEnvMoeA2ADispatchBlockSize(); // One block per token: grid_size == local_num_tokens. If 0, launch a single block to @@ -700,7 +741,7 @@ __device__ void vectorized_combine_impl(T* dst_typed_base, int size_per_token, i { int target_rank = ptrs.topk_target_ranks[local_token_idx * TOP_K + k]; int dst_idx = ptrs.topk_send_indices[local_token_idx * TOP_K + k]; - if (dst_idx < 0) + if (dst_idx < 0 || !is_rank_active(ptrs.active_rank_mask, target_rank)) { acc[k].fill(0.0f); continue; @@ -725,8 +766,12 @@ __device__ void vectorized_combine_impl(T* dst_typed_base, int size_per_token, i #pragma unroll for (int k = 0; k < TOP_K; ++k) { - if (ptrs.topk_send_indices[local_token_idx * TOP_K + k] < 0) + int target_rank = ptrs.topk_target_ranks[local_token_idx * TOP_K + k]; + int dst_idx = ptrs.topk_send_indices[local_token_idx * TOP_K + k]; + if (dst_idx < 0 || !is_rank_active(ptrs.active_rank_mask, target_rank)) + { continue; // acc[k] already holds 0.0f from fill() above + } #pragma unroll for (int j = elems_per_vec - 1; j >= 0; --j) acc[k][j] = static_cast(reinterpret_cast(&acc[k])[j]); @@ -1153,9 +1198,13 @@ __global__ void moeA2ACombineKernel( if (blockIdx.x == 0) { + // Signal readiness to all active peers; skip dead ranks (their symmetric memory + // is unreachable). #pragma unroll 1 // No unroll for (int peer_rank = lane_id; peer_rank < ep_size; peer_rank += warpSize) { + if (!is_rank_active(ptrs.active_rank_mask, peer_rank)) + continue; uint32_t* flag_addr = &ptrs.completion_flags[peer_rank][rank_id]; asm volatile("st.relaxed.sys.u32 [%0], %1;" ::"l"(flag_addr), "r"(expected_value)); #if ENABLE_DEBUG_PRINT @@ -1165,9 +1214,13 @@ __global__ void moeA2ACombineKernel( } } + // Wait for all active peers to signal; skip dead ranks (otherwise we would spin + // forever — this is the bug the rank-mask is here to prevent). #pragma unroll 1 // No unroll for (int peer_rank = lane_id; peer_rank < ep_size; peer_rank += warpSize) { + if (!is_rank_active(ptrs.active_rank_mask, peer_rank)) + continue; bool flag_set = false; auto s = clock64(); do @@ -1271,8 +1324,13 @@ void moe_a2a_combine_launch(MoeA2ACombineParams const& params) // Validate parameters TLLM_CHECK(params.top_k > 0 && params.top_k <= kMaxTopK); TLLM_CHECK(params.ep_size > 0 && params.ep_size <= kMaxRanks); + TLLM_CHECK(params.ep_rank >= 0 && params.ep_rank < params.ep_size); TLLM_CHECK(params.local_num_tokens >= 0); TLLM_CHECK(params.elements_per_token > 0); + // The local rank must always be marked active in its own view of the mask; + // otherwise the kernel itself would be running on a "dead" rank. + TLLM_CHECK_WITH_INFO((params.active_rank_mask[params.ep_rank >> 6] >> (params.ep_rank & 63)) & 1ULL, + "active_rank_mask must mark the local ep_rank (%d) as active", params.ep_rank); // Configure kernel launch (one block per token). int const kBlockSize = tensorrt_llm::common::getEnvMoeA2ACombineBlockSize(); @@ -1306,6 +1364,12 @@ void moe_a2a_combine_launch(MoeA2ACombineParams const& params) kernel_ptrs.topk_target_ranks = params.topk_target_ranks; kernel_ptrs.topk_send_indices = params.topk_send_indices; + // Copy active-rank bitmask into the kernel pointers struct + for (int w = 0; w < kRankMaskWords; ++w) + { + kernel_ptrs.active_rank_mask[w] = params.active_rank_mask[w]; + } + // stride_per_token: byte distance between tokens in the recv buffer. // FP8 external payload: EPT × 1 (compact FP8 layout) // FP8 in-place / non-FP8: EPT × sizeof(PayloadT) (payload-dtype stride) diff --git a/cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h b/cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h index 317ff4d2240c..138ca92e71a8 100644 --- a/cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h +++ b/cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h @@ -26,9 +26,12 @@ namespace kernels::moe_comm { // Configuration constants -static constexpr int kMaxTopK = 22; // Maximum top-k experts per token -static constexpr int kMaxPayloads = 4; // Maximum number of different payload types -static constexpr int kMaxRanks = 64; // Maximum supported EP size +static constexpr int kMaxTopK = 22; // Maximum top-k experts per token +static constexpr int kMaxPayloads = 4; // Maximum number of different payload types +static constexpr int kMaxRanks = 128; // Maximum supported EP size (covers NVL72 with headroom) +static constexpr int kRankMaskWords = 2; // uint64 words to hold the active-rank bitmask + // (kRankMaskWords * 64 must be >= kMaxRanks) +static_assert(kRankMaskWords * 64 >= kMaxRanks, "active_rank_mask too small for kMaxRanks"); // Describes a single payload type to be communicated struct PayloadDescriptor @@ -65,6 +68,12 @@ struct DispatchKernelPointers // Optional: Statistics for EPLB int const* eplb_local_stats; // [eplb_stats_num_experts] int* eplb_gathered_stats[kMaxRanks]; // [ep_size, eplb_stats_num_experts] per rank + + // Active-rank bitmask: bit i set => rank i is alive and participates in this collective. + // Word 0 covers ranks 0..63; word 1 covers ranks 64..127. Tokens routed to a masked + // rank are dropped (topk_*[k] = -1); flag writes/waits to/from masked peers are skipped. + // The local rank's own bit must always be set; this is checked at launch time. + uint64_t active_rank_mask[kRankMaskWords]; }; // Combine kernel pointers - non-const output in src_data_ptrs[0], const recv buffers @@ -82,6 +91,11 @@ struct CombineKernelPointers // Top-K compact routing info per local token (size: [local_num_tokens, top_k]) int const* topk_target_ranks; // target rank per k, -1 for duplicates int const* topk_send_indices; // dst index per k, -1 for duplicates + + // Active-rank bitmask: see DispatchKernelPointers::active_rank_mask. Combine skips flag + // writes/waits to/from masked peers and also skips per-token accumulation for ranks that + // become inactive between dispatch and combine. + uint64_t active_rank_mask[kRankMaskWords]; }; // Dispatch phase parameters @@ -125,6 +139,11 @@ struct MoeA2ADispatchParams int const* eplb_local_stats; // [eplb_stats_num_experts] int* eplb_gathered_stats[kMaxRanks]; // [ep_size, eplb_stats_num_experts] per rank + // Active-rank bitmask: see DispatchKernelPointers::active_rank_mask. The launch function + // copies these words into the kernel pointers struct. Defaults to all-ones for + // backwards-compatible "no masking" behavior. + uint64_t active_rank_mask[kRankMaskWords] = {~uint64_t{0}, ~uint64_t{0}}; + // CUDA stream cudaStream_t stream; }; @@ -170,6 +189,11 @@ struct MoeA2ACombineParams // rank has signaled the target rank void const* recv_buffers[kMaxRanks]; // Per-rank receive buffers (only for single payload) + // Active-rank bitmask: see DispatchKernelPointers::active_rank_mask. The launch function + // copies these words into the kernel pointers struct. Defaults to all-ones for + // backwards-compatible "no masking" behavior. + uint64_t active_rank_mask[kRankMaskWords] = {~uint64_t{0}, ~uint64_t{0}}; + // CUDA stream cudaStream_t stream; }; diff --git a/cpp/tensorrt_llm/thop/moeAlltoAllOp.cpp b/cpp/tensorrt_llm/thop/moeAlltoAllOp.cpp index 7a767976dabd..fc45afd792bb 100644 --- a/cpp/tensorrt_llm/thop/moeAlltoAllOp.cpp +++ b/cpp/tensorrt_llm/thop/moeAlltoAllOp.cpp @@ -42,6 +42,43 @@ inline size_t alignOffset(size_t offset, size_t alignment) return (offset + alignment - 1) & ~(alignment - 1); } +// Resolve an optional rank-mask tensor into a fixed-width uint64 array. +// If the caller did not provide a mask, default to "all ranks active" (all bits set), which +// reproduces the pre-fault-tolerance behavior bit-for-bit. +// +// On failure (wrong dtype / device / shape), throws via TORCH_CHECK so the error surfaces +// at the Python op boundary rather than the kernel launch. +inline void resolveActiveRankMask(torch::optional const& maskTensor, int64_t epRank, + uint64_t (&out)[tensorrt_llm::kernels::moe_comm::kRankMaskWords]) +{ + using tensorrt_llm::kernels::moe_comm::kRankMaskWords; + using tensorrt_llm::kernels::moe_comm::kMaxRanks; + TORCH_CHECK( + epRank >= 0 && epRank < kMaxRanks, "epRank must be in the range [0, ", kMaxRanks, ") for active_rank_mask"); + if (!maskTensor.has_value() || !maskTensor.value().defined()) + { + for (int w = 0; w < kRankMaskWords; ++w) + { + out[w] = ~uint64_t{0}; + } + return; + } + torch::Tensor const& t = maskTensor.value(); + TORCH_CHECK(t.is_cpu(), "active_rank_mask must be a CPU tensor"); + TORCH_CHECK(t.scalar_type() == torch::kUInt64, "active_rank_mask must have dtype uint64"); + TORCH_CHECK(t.dim() == 1, "active_rank_mask must be a 1D tensor"); + TORCH_CHECK(t.numel() == kRankMaskWords, "active_rank_mask must have exactly ", kRankMaskWords, " uint64 elements"); + TORCH_CHECK(t.is_contiguous(), "active_rank_mask must be contiguous"); + auto const* src = static_cast(t.const_data_ptr()); + for (int w = 0; w < kRankMaskWords; ++w) + { + out[w] = src[w]; + } + // Local rank's bit must be set; otherwise the kernel would be running on a "dead" rank. + TORCH_CHECK((out[epRank >> 6] >> (epRank & 63)) & 1ULL, "active_rank_mask must mark the local ep_rank (", epRank, + ") as active"); +} + // Calculate auxiliary data offsets MoeA2ADataOffsets calculateOffsets(int epSize, int maxNumTokens, int eplbStatsNumExperts) { @@ -117,11 +154,14 @@ MoeA2ADataOffsets calculateOffsets(int epSize, int maxNumTokens, int eplbStatsNu torch::Tensor moeA2AInitializeOp(torch::Tensor const& workspace, int64_t epRank, int64_t epSize, int64_t maxNumTokens, torch::optional eplbStatsNumExperts) { + using tensorrt_llm::kernels::moe_comm::kMaxRanks; + // Validate inputs CHECK_TH_CUDA(workspace); CHECK_TYPE(workspace, torch::kUInt8); TORCH_CHECK(workspace.dim() == 2, "workspace must be a 2D tensor of shape [epSize, sizePerRank]"); TORCH_CHECK(workspace.size(0) == epSize, "workspace first dimension must equal epSize"); + TORCH_CHECK(epSize > 0 && epSize <= kMaxRanks, "epSize must be in the range (0, ", kMaxRanks, "]"); TORCH_CHECK(epRank >= 0 && epRank < epSize, "epRank must be in the range [0, epSize)"); // Initialize workspace to zero @@ -181,13 +221,15 @@ torch::Tensor moeA2AInitializeOp(torch::Tensor const& workspace, int64_t epRank, std::tuple, int64_t, torch::Tensor> moeA2ADispatchOp( torch::Tensor const& tokenSelectedExperts, std::vector const& inputPayloads, torch::Tensor const& workspace, torch::Tensor const& metainfo, int64_t runtimeMaxTokensPerRank, int64_t epRank, - int64_t epSize, int64_t topK, int64_t numExperts, torch::optional eplbLocalStats) + int64_t epSize, int64_t topK, int64_t numExperts, torch::optional eplbLocalStats, + torch::optional activeRankMask) { using tensorrt_llm::kernels::moe_comm::PayloadDescriptor; using tensorrt_llm::kernels::moe_comm::MoeA2ADispatchParams; using tensorrt_llm::kernels::moe_comm::moe_a2a_dispatch_launch; using tensorrt_llm::kernels::moe_comm::kMaxTopK; using tensorrt_llm::kernels::moe_comm::kMaxPayloads; + using tensorrt_llm::kernels::moe_comm::kMaxRanks; // Validate inputs CHECK_INPUT(tokenSelectedExperts, torch::kInt32); @@ -203,6 +245,7 @@ std::tuple, int64_t, torch::Tensor> moeA2ADispatchOp( int64_t localNumTokens = tokenSelectedExperts.size(0); TORCH_CHECK(runtimeMaxTokensPerRank > 0, "runtimeMaxTokensPerRank must be positive"); + TORCH_CHECK(epSize > 0 && epSize <= kMaxRanks, "epSize must be in the range (0, ", kMaxRanks, "]"); TORCH_CHECK(epRank >= 0 && epRank < epSize, "epRank must be in the range [0, epSize)"); TORCH_CHECK(topK > 0 && topK <= kMaxTopK, "topK must be in the range (0, kMaxTopK]"); TORCH_CHECK(!inputPayloads.empty(), "inputPayloads must not be empty"); @@ -360,6 +403,10 @@ std::tuple, int64_t, torch::Tensor> moeA2ADispatchOp( params.eplb_local_stats = nullptr; } + // Resolve the optional active-rank mask. Default (no mask) = all bits set, which + // exactly reproduces the pre-fault-tolerance kernel behavior. + resolveActiveRankMask(activeRankMask, epRank, params.active_rank_mask); + params.stream = at::cuda::getCurrentCUDAStream(); // Prepare for dispatch (zero counters/indices and increment flag_val) @@ -413,11 +460,13 @@ std::tuple, int64_t, torch::Tensor> moeA2ADispatchOp( // In both cases, the combine kernel reads from the workspace at 'combinePayloadOffset'. torch::Tensor moeA2ACombineOp(torch::Tensor const& payload, int64_t localNumTokens, torch::Tensor const& workspace, torch::Tensor const& metainfo, int64_t runtimeMaxTokensPerRank, int64_t epRank, int64_t epSize, int64_t topK, - int64_t combinePayloadOffset, bool payloadInWorkspace, bool useLowPrecision = false) + int64_t combinePayloadOffset, bool payloadInWorkspace, bool useLowPrecision = false, + torch::optional activeRankMask = torch::nullopt) { using tensorrt_llm::kernels::moe_comm::MoeA2ACombineParams; using tensorrt_llm::kernels::moe_comm::moe_a2a_combine_launch; using tensorrt_llm::kernels::moe_comm::kMaxTopK; + using tensorrt_llm::kernels::moe_comm::kMaxRanks; // Validate inputs CHECK_TH_CUDA(payload); @@ -431,6 +480,7 @@ torch::Tensor moeA2ACombineOp(torch::Tensor const& payload, int64_t localNumToke TORCH_CHECK(reinterpret_cast(payload.data_ptr()) % 16 == 0, "payload must be 16-byte aligned"); int64_t elementsPerToken = payload.size(2); TORCH_CHECK(elementsPerToken > 0, "elementsPerToken must be positive"); + TORCH_CHECK(epSize > 0 && epSize <= kMaxRanks, "epSize must be in the range (0, ", kMaxRanks, "]"); TORCH_CHECK(epRank >= 0 && epRank < epSize, "epRank must be in the range [0, epSize)"); TORCH_CHECK(topK > 0 && topK <= kMaxTopK, "topK must be in the range (0, kMaxTopK]"); @@ -520,6 +570,9 @@ torch::Tensor moeA2ACombineOp(torch::Tensor const& payload, int64_t localNumToke params.recv_buffers[target_rank] = target_workspace_ptr + combinePayloadOffset; } + // Resolve the optional active-rank mask. Default (no mask) = all bits set. + resolveActiveRankMask(activeRankMask, epRank, params.active_rank_mask); + params.stream = at::cuda::getCurrentCUDAStream(); moe_a2a_prepare_combine_launch(params); @@ -613,12 +666,14 @@ TORCH_LIBRARY_FRAGMENT(trtllm, module) "moe_a2a_dispatch(Tensor token_selected_experts, Tensor[] input_payloads, " "Tensor(a!->*) workspace, Tensor metainfo, int runtime_max_tokens_per_rank, " "int ep_rank, int ep_size, int top_k, int num_experts, " - "Tensor? eplb_local_stats=None) -> (Tensor(a!)[], int, Tensor(a!))"); + "Tensor? eplb_local_stats=None, " + "Tensor? active_rank_mask=None) -> (Tensor(a!)[], int, Tensor(a!))"); module.def( "moe_a2a_combine(Tensor(a) payload, int local_num_tokens," "Tensor(a!) workspace, Tensor metainfo, int runtime_max_tokens_per_rank, " "int ep_rank, int ep_size, int top_k, int combine_payload_offset, " - "bool payload_in_workspace, bool use_low_precision=False) -> Tensor"); + "bool payload_in_workspace, bool use_low_precision=False, " + "Tensor? active_rank_mask=None) -> Tensor"); module.def( "moe_a2a_initialize(Tensor(a!) workspace, int ep_rank, int ep_size, int max_num_tokens_per_rank, " "int? eplb_stats_num_experts=None) -> Tensor"); diff --git a/tensorrt_llm/_torch/alltoall_watchdog.py b/tensorrt_llm/_torch/alltoall_watchdog.py new file mode 100644 index 000000000000..c41a3fb8fdb3 --- /dev/null +++ b/tensorrt_llm/_torch/alltoall_watchdog.py @@ -0,0 +1,463 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Host-side watchdog for MoE AlltoAll completion flags. + +The NVLinkOneSided kernels signal each collective by writing the current +``flag_val`` into the rank-local completion flag table. A dead peer in the +silent-spin failure mode never writes its slot, so this watchdog polls the same +table from a CPU thread and reports peers whose flags do not reach the expected +generation before a bounded timeout. +""" + +from __future__ import annotations + +import threading +import time +from collections import deque +from collections.abc import Callable, Mapping, Sequence +from dataclasses import dataclass +from typing import Protocol + +import torch + +from tensorrt_llm._utils import prefer_pinned +from tensorrt_llm.logger import logger as tllm_logger + +DEFAULT_ALLTOALL_WATCHDOG_TIMEOUT_S = 5.0 +DEFAULT_ALLTOALL_WATCHDOG_POLL_INTERVAL_S = 0.1 +UNKNOWN_COMPLETION_FLAG = -(2**63) + + +class CompletionFlagReader(Protocol): + """Reads one phase's rank-local completion flag row.""" + + def read_completion_flags(self, phase: str) -> Sequence[int]: + """Return ``ep_size`` flag values for ``phase``.""" + + +class EPGroupHealthLike(Protocol): + """Subset of EPGroupHealth used by the watchdog.""" + + def get_mask(self) -> int: + """Return the active-rank bitmask.""" + + def mark_failed(self, rank: int) -> bool: + """Mark ``rank`` failed and return whether state changed.""" + + +class CompletionFlagReadTimeout(TimeoutError): + """Raised when the host watchdog cannot read completion flags in time.""" + + +@dataclass(frozen=True) +class AlltoAllWatchdogTimeout: + """Details emitted when an AlltoAll phase times out.""" + + phase: str + expected_flag: int + observed_flags: tuple[int, ...] + missing_ranks: tuple[int, ...] + marked_failed_ranks: tuple[int, ...] + elapsed_s: float + poll_timed_out: bool = False + + +@dataclass(frozen=True) +class _CollectiveWatch: + phase: str + expected_flag: int + active_mask: int + start_s: float + + +class _TorchCompletionFlagReader: + """Completion-flag reader backed by the MoE AlltoAll workspace tensor.""" + + def __init__( + self, + workspace: torch.Tensor, + ep_rank: int, + ep_size: int, + dispatch_completion_flags_offset: int, + combine_completion_flags_offset: int, + device_copy_timeout_s: float = DEFAULT_ALLTOALL_WATCHDOG_POLL_INTERVAL_S, + ) -> None: + if workspace.dim() != 2: + raise ValueError("workspace must be a 2D tensor [ep_size, size_per_rank]") + if not 0 <= ep_rank < ep_size: + raise ValueError(f"ep_rank must be in [0, {ep_size}), got {ep_rank}") + if workspace.size(0) != ep_size: + raise ValueError( + f"workspace first dimension must equal ep_size={ep_size}, got {workspace.size(0)}" + ) + self._workspace = workspace + self._ep_rank = ep_rank + self._ep_size = ep_size + self._offsets = { + "dispatch": int(dispatch_completion_flags_offset), + "combine": int(combine_completion_flags_offset), + } + self._device_copy_timeout_s = float(device_copy_timeout_s) + self._copy_stream: torch.cuda.Stream | None = None + self._host_flags: torch.Tensor | None = None + self._copy_event: torch.cuda.Event | None = None + self._retired_copies: list[tuple[torch.Tensor, torch.cuda.Event]] = [] + if workspace.device.type == "cuda": + self._copy_stream = torch.cuda.Stream(device=workspace.device) + + def _prune_retired_copies(self) -> None: + self._retired_copies = [ + (host_flags, event) for host_flags, event in self._retired_copies if not event.query() + ] + + def _read_cuda_flags(self, flags: torch.Tensor) -> tuple[int, ...]: + assert self._copy_stream is not None + self._prune_retired_copies() + + if self._host_flags is None: + self._host_flags = torch.empty( + (self._ep_size,), + dtype=torch.int32, + device="cpu", + pin_memory=prefer_pinned(), + ) + if self._copy_event is None: + self._copy_event = torch.cuda.Event(blocking=False) + host_flags = self._host_flags + event = self._copy_event + with torch.cuda.device(flags.device), torch.cuda.stream(self._copy_stream): + host_flags.copy_(flags.detach(), non_blocking=True) + event.record(self._copy_stream) + + deadline_s = time.monotonic() + self._device_copy_timeout_s + while not event.query(): + remaining_s = deadline_s - time.monotonic() + if remaining_s <= 0: + self._retired_copies.append((host_flags, event)) + self._host_flags = None + self._copy_event = None + raise CompletionFlagReadTimeout( + "timed out copying AlltoAll completion flags to host" + ) + time.sleep(min(remaining_s, 0.001)) + + return tuple(int(v) for v in host_flags.tolist()) + + def read_completion_flags(self, phase: str) -> tuple[int, ...]: + offset = self._offsets[phase] + end = offset + self._ep_size * 4 + flags = self._workspace[self._ep_rank, offset:end].view(torch.int32) + if flags.device.type == "cuda": + return self._read_cuda_flags(flags) + if flags.device.type != "cpu": + flags = flags.detach().cpu() + return tuple(int(v) for v in flags.tolist()) + + +class AlltoAllWatchdog: + """Background host thread that watches AlltoAll completion flags. + + The watchdog is intentionally opt-in. Callers queue phases with + :meth:`watch`; the thread polls them in FIFO order so a queued combine cannot + hide a still-spinning dispatch. + """ + + VALID_PHASES = frozenset({"dispatch", "combine"}) + + def __init__( + self, + *, + ep_size: int, + ep_rank: int, + completion_reader: CompletionFlagReader, + timeout_s: float = DEFAULT_ALLTOALL_WATCHDOG_TIMEOUT_S, + poll_interval_s: float = DEFAULT_ALLTOALL_WATCHDOG_POLL_INTERVAL_S, + health: EPGroupHealthLike | None = None, + on_timeout: Callable[[AlltoAllWatchdogTimeout], None] | None = None, + ) -> None: + if ep_size <= 0: + raise ValueError(f"ep_size must be > 0, got {ep_size}") + if not 0 <= ep_rank < ep_size: + raise ValueError(f"ep_rank must be in [0, {ep_size}), got {ep_rank}") + if timeout_s <= 0: + raise ValueError(f"timeout_s must be > 0, got {timeout_s}") + if poll_interval_s <= 0: + raise ValueError(f"poll_interval_s must be > 0, got {poll_interval_s}") + + self._ep_size = int(ep_size) + self._ep_rank = int(ep_rank) + self._completion_reader = completion_reader + self._timeout_s = float(timeout_s) + self._poll_interval_s = float(poll_interval_s) + self._health = health + self._on_timeout = on_timeout + + self._cv = threading.Condition() + self._queue: deque[_CollectiveWatch] = deque() + self._closed = False + self._stopping = False + self._thread: threading.Thread | None = None + self._last_error: BaseException | None = None + + @classmethod + def from_workspace( + cls, + *, + workspace: torch.Tensor, + metainfo: torch.Tensor, + metainfo_index: Mapping[str, int], + ep_rank: int, + ep_size: int, + timeout_s: float = DEFAULT_ALLTOALL_WATCHDOG_TIMEOUT_S, + poll_interval_s: float = DEFAULT_ALLTOALL_WATCHDOG_POLL_INTERVAL_S, + health: EPGroupHealthLike | None = None, + on_timeout: Callable[[AlltoAllWatchdogTimeout], None] | None = None, + ) -> "AlltoAllWatchdog": + """Build a watchdog from the MoE AlltoAll workspace and metainfo.""" + dispatch_offset = int( + metainfo[metainfo_index["DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX"]].item() + ) + combine_offset = int( + metainfo[metainfo_index["COMBINE_COMPLETION_FLAGS_OFFSET_INDEX"]].item() + ) + reader = _TorchCompletionFlagReader( + workspace=workspace, + ep_rank=ep_rank, + ep_size=ep_size, + dispatch_completion_flags_offset=dispatch_offset, + combine_completion_flags_offset=combine_offset, + device_copy_timeout_s=poll_interval_s, + ) + return cls( + ep_size=ep_size, + ep_rank=ep_rank, + completion_reader=reader, + timeout_s=timeout_s, + poll_interval_s=poll_interval_s, + health=health, + on_timeout=on_timeout, + ) + + @property + def last_error(self) -> BaseException | None: + """Return the last polling-thread error, if any.""" + with self._cv: + return self._last_error + + def start(self) -> None: + """Start the background polling thread. Idempotent.""" + with self._cv: + if self._closed: + raise RuntimeError("cannot start a stopped AlltoAllWatchdog") + if self._thread is not None and self._thread.is_alive(): + return + self._stopping = False + self._thread = threading.Thread( + target=self._run, + name=f"AlltoAllWatchdog-rank{self._ep_rank}", + daemon=True, + ) + self._thread.start() + + def stop(self, timeout_s: float | None = None) -> None: + """Stop the polling thread and wait for it to exit.""" + with self._cv: + self._closed = True + self._stopping = True + self._queue.clear() + self._cv.notify_all() + thread = self._thread + if thread is not None: + thread.join(timeout=timeout_s) + + def watch( + self, + *, + phase: str, + expected_flag: int, + active_mask: int | None = None, + ) -> None: + """Queue a just-launched AlltoAll phase for watchdog polling.""" + if phase not in self.VALID_PHASES: + raise ValueError(f"phase must be one of {sorted(self.VALID_PHASES)}, got {phase!r}") + if expected_flag < 0: + raise ValueError(f"expected_flag must be non-negative, got {expected_flag}") + if active_mask is None: + if self._health is not None: + active_mask = self._health.get_mask() + else: + active_mask = (1 << self._ep_size) - 1 + if not (active_mask >> self._ep_rank) & 1: + raise ValueError("active_mask must include the local ep_rank") + + self.start() + with self._cv: + if self._closed: + raise RuntimeError("cannot queue a stopped AlltoAllWatchdog") + self._queue.append( + _CollectiveWatch( + phase=phase, + expected_flag=int(expected_flag), + active_mask=int(active_mask), + start_s=time.monotonic(), + ) + ) + self._cv.notify_all() + + def wait_until_idle(self, timeout_s: float) -> bool: + """Wait until all queued phases complete or timeout handling clears them.""" + deadline = time.monotonic() + timeout_s + with self._cv: + while self._queue: + remaining = deadline - time.monotonic() + if remaining <= 0: + return False + self._cv.wait(timeout=remaining) + return True + + def __enter__(self) -> "AlltoAllWatchdog": + self.start() + return self + + def __exit__(self, exc_type, exc, tb) -> None: + self.stop(timeout_s=1.0) + + def _active_ranks(self, active_mask: int) -> tuple[int, ...]: + return tuple(rank for rank in range(self._ep_size) if (active_mask >> rank) & 1) + + def _phase_complete(self, watch: _CollectiveWatch, observed_flags: tuple[int, ...]) -> bool: + return all( + observed_flags[rank] == watch.expected_flag + for rank in self._active_ranks(watch.active_mask) + ) + + def _missing_ranks( + self, watch: _CollectiveWatch, observed_flags: tuple[int, ...] + ) -> tuple[int, ...]: + return tuple( + rank + for rank in self._active_ranks(watch.active_mask) + if observed_flags[rank] != watch.expected_flag + ) + + def _handle_timeout( + self, + watch: _CollectiveWatch, + observed_flags: tuple[int, ...], + *, + poll_timed_out: bool = False, + ) -> None: + elapsed_s = time.monotonic() - watch.start_s + missing_ranks = self._missing_ranks(watch, observed_flags) + marked_failed: list[int] = [] + has_known_flags = UNKNOWN_COMPLETION_FLAG not in observed_flags + if self._health is not None and (has_known_flags or not poll_timed_out): + for rank in missing_ranks: + if rank == self._ep_rank: + continue + if self._health.mark_failed(rank): + marked_failed.append(rank) + + event = AlltoAllWatchdogTimeout( + phase=watch.phase, + expected_flag=watch.expected_flag, + observed_flags=observed_flags, + missing_ranks=missing_ranks, + marked_failed_ranks=tuple(marked_failed), + elapsed_s=elapsed_s, + poll_timed_out=poll_timed_out, + ) + if poll_timed_out: + tllm_logger.error( + "AlltoAll watchdog could not read completion flags on rank %d " + "during %s before timeout %.3fs; expected flag %d, active " + "ranks %s, observed flags %s, marked ranks %s", + self._ep_rank, + watch.phase, + elapsed_s, + watch.expected_flag, + list(self._active_ranks(watch.active_mask)), + list(observed_flags), + list(marked_failed), + ) + else: + tllm_logger.warning( + "AlltoAll watchdog timeout on rank %d during %s: expected flag %d, " + "missing ranks %s, observed flags %s", + self._ep_rank, + watch.phase, + watch.expected_flag, + list(missing_ranks), + list(observed_flags), + ) + if self._on_timeout is not None: + self._on_timeout(event) + + def _run(self) -> None: + last_observed_flags = tuple(UNKNOWN_COMPLETION_FLAG for _ in range(self._ep_size)) + poll_timed_out = False + while True: + with self._cv: + while not self._queue and not self._stopping: + self._cv.wait() + if self._stopping: + return + watch = self._queue[0] + + try: + observed_flags = tuple( + int(v) for v in self._completion_reader.read_completion_flags(watch.phase) + ) + if len(observed_flags) != self._ep_size: + raise RuntimeError( + f"completion reader returned {len(observed_flags)} flags; " + f"expected ep_size={self._ep_size}" + ) + last_observed_flags = observed_flags + poll_timed_out = False + except CompletionFlagReadTimeout: + observed_flags = last_observed_flags + poll_timed_out = True + except Exception as exc: # noqa: BLE001 - keep watchdog failures visible. + with self._cv: + self._last_error = exc + self._queue.clear() + self._cv.notify_all() + tllm_logger.error("AlltoAll watchdog stopped after polling error: %s", exc) + return + + if self._phase_complete(watch, observed_flags): + with self._cv: + if self._queue and self._queue[0] is watch: + self._queue.popleft() + self._cv.notify_all() + last_observed_flags = tuple(UNKNOWN_COMPLETION_FLAG for _ in range(self._ep_size)) + poll_timed_out = False + continue + + if time.monotonic() - watch.start_s >= self._timeout_s: + self._handle_timeout(watch, observed_flags, poll_timed_out=poll_timed_out) + with self._cv: + # The GPU stream is no longer trustworthy once a collective + # times out. Drop queued follow-on phases so they do not + # produce duplicate or misleading reports. + self._queue.clear() + self._cv.notify_all() + last_observed_flags = tuple(UNKNOWN_COMPLETION_FLAG for _ in range(self._ep_size)) + poll_timed_out = False + continue + + with self._cv: + self._cv.wait(timeout=self._poll_interval_s) diff --git a/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py b/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py index 77fc6f71eeb8..3f0cd8c87efd 100644 --- a/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py @@ -466,6 +466,7 @@ def _( top_k: int, num_experts: int, eplb_local_stats: Optional[torch.Tensor] = None, + active_rank_mask: Optional[torch.Tensor] = None, ) -> Tuple[List[torch.Tensor], int, torch.Tensor]: recv_tensors: List[torch.Tensor] = [] for payload in input_payloads: @@ -496,6 +497,7 @@ def _( combine_payload_offset: int, payload_in_workspace: bool, use_low_precision: bool = False, + active_rank_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: return payload.new_empty((local_num_tokens, payload.shape[2])) diff --git a/tensorrt_llm/_torch/distributed/moe_alltoall.py b/tensorrt_llm/_torch/distributed/moe_alltoall.py index ca3a50dcfd12..b16c2a25bc6d 100644 --- a/tensorrt_llm/_torch/distributed/moe_alltoall.py +++ b/tensorrt_llm/_torch/distributed/moe_alltoall.py @@ -8,12 +8,18 @@ # ruff: noqa: E501 import os +import sys +import threading from dataclasses import dataclass -from typing import Dict, Optional +from typing import Callable, Dict, Optional import torch from tensorrt_llm._mnnvl_utils import MnnvlMemory +from tensorrt_llm._torch.alltoall_watchdog import ( + DEFAULT_ALLTOALL_WATCHDOG_POLL_INTERVAL_S, + DEFAULT_ALLTOALL_WATCHDOG_TIMEOUT_S, AlltoAllWatchdog, + AlltoAllWatchdogTimeout) from tensorrt_llm.bindings import internal as _tllm_internal from tensorrt_llm.logger import logger as tllm_logger from tensorrt_llm.mapping import Mapping @@ -126,6 +132,12 @@ def __init__( num_slots: int, workspace_size_per_rank: int, num_experts: Optional[int] = None, + ep_group_health=None, + alltoall_watchdog_timeout_s: Optional[float] = None, + alltoall_watchdog_poll_interval_s: + float = DEFAULT_ALLTOALL_WATCHDOG_POLL_INTERVAL_S, + alltoall_watchdog_on_timeout: Optional[Callable[ + [AlltoAllWatchdogTimeout], None]] = None, ): """ Initialize MoeAlltoAll with workspace allocation. @@ -138,6 +150,12 @@ def __init__( Note: The terminology is mapped to `num_experts` in this class and the kernels. num_experts: (Optional) Number of experts for EPLB stats (must be <= num_slots). DO NOT provide this parameter if EPLB is not enabled. Note: The terminology is mapped to `eplb_stats_num_experts` in this class and the kernels. + ep_group_health: Optional EPGroupHealth-compatible object. When present, its mask is passed to the + CUDA kernels and used by the watchdog. + alltoall_watchdog_timeout_s: Optional timeout for the host-side AlltoAll watchdog. If None, the + watchdog is disabled. + alltoall_watchdog_poll_interval_s: Poll interval for the watchdog thread. + alltoall_watchdog_on_timeout: Optional callback invoked when the watchdog reports suspects. """ # Check for environment variable override workspace_mb_env = os.environ.get("TRTLLM_MOE_A2A_WORKSPACE_MB") @@ -195,6 +213,8 @@ def __init__( "mnnvl_mem": mnnvl_mem, "workspace": workspace, "metainfo": metainfo, + "watchdog_flag_generation": 0, + "watchdog_flag_generation_lock": threading.Lock(), } else: assert self._WORKSPACE[ @@ -212,8 +232,106 @@ def __init__( self.mnnvl_mem = self._WORKSPACE["mnnvl_mem"] self.workspace = self._WORKSPACE["workspace"] self.metainfo = self._WORKSPACE["metainfo"] + if "watchdog_flag_generation_lock" not in self._WORKSPACE: + self._WORKSPACE["watchdog_flag_generation_lock"] = threading.Lock() + self._WORKSPACE[ + "watchdog_flag_generation"] = self._read_current_flag_val() # Internal state self._state: _A2AState = _A2AState() + self.ep_group_health = ep_group_health + self._destroyed = False + self._alltoall_watchdog: AlltoAllWatchdog | None = None + if (alltoall_watchdog_timeout_s is None + and self.ep_group_health is not None): + alltoall_watchdog_timeout_s = DEFAULT_ALLTOALL_WATCHDOG_TIMEOUT_S + if alltoall_watchdog_timeout_s is not None: + self._sync_watchdog_flag_generation() + self._alltoall_watchdog = AlltoAllWatchdog.from_workspace( + workspace=self.workspace, + metainfo=self.metainfo, + metainfo_index=self._METAINFO_INDEX, + ep_rank=self.ep_rank, + ep_size=self.ep_size, + timeout_s=alltoall_watchdog_timeout_s, + poll_interval_s=alltoall_watchdog_poll_interval_s, + health=self.ep_group_health, + on_timeout=alltoall_watchdog_on_timeout, + ) + + def destroy(self) -> None: + """Stop background watchdog resources owned by this wrapper.""" + if getattr(self, "_destroyed", False): + return + self._destroyed = True + watchdog = getattr(self, "_alltoall_watchdog", None) + if watchdog is not None: + watchdog.stop(timeout_s=1.0) + self._alltoall_watchdog = None + + def __del__(self) -> None: + if not sys.is_finalizing(): + self.destroy() + + def _read_current_flag_val(self) -> int: + flag_val_offset = self.metainfo[ + self._METAINFO_INDEX["FLAG_VAL_OFFSET_INDEX"]].item() + flag_val = self.workspace[self.ep_rank, + flag_val_offset:flag_val_offset + 4].view( + torch.int32) + if flag_val.device.type != "cpu": + flag_val = flag_val.detach().cpu() + return int(flag_val.item()) + + def _sync_watchdog_flag_generation(self) -> None: + workspace_state = self._WORKSPACE + assert workspace_state is not None + lock = workspace_state["watchdog_flag_generation_lock"] + with lock: + workspace_state["watchdog_flag_generation"] = max( + int(workspace_state["watchdog_flag_generation"]), + self._read_current_flag_val(), + ) + + def _next_watchdog_flag_generation(self) -> int: + workspace_state = self._WORKSPACE + assert workspace_state is not None + lock = workspace_state["watchdog_flag_generation_lock"] + with lock: + workspace_state["watchdog_flag_generation"] = ( + int(workspace_state["watchdog_flag_generation"]) + 1) + return int(workspace_state["watchdog_flag_generation"]) + + def _get_active_rank_mask_tensor( + self, + active_rank_mask: Optional[torch.Tensor]) -> Optional[torch.Tensor]: + if active_rank_mask is not None: + return active_rank_mask + if self.ep_group_health is None: + return None + return torch.tensor(self.ep_group_health.get_mask_words(), + dtype=torch.uint64, + device="cpu") + + def _active_mask_int( + self, active_rank_mask: Optional[torch.Tensor]) -> Optional[int]: + if active_rank_mask is not None: + mask_cpu = active_rank_mask.detach().cpu() + return sum( + int(word) << (64 * idx) + for idx, word in enumerate(mask_cpu.tolist())) + if self.ep_group_health is not None: + return self.ep_group_health.get_mask() + return None + + def _watch_collective(self, phase: str, + active_rank_mask: Optional[torch.Tensor]) -> None: + if self._alltoall_watchdog is None: + return + self._alltoall_watchdog.watch( + phase=phase, + expected_flag=self._next_watchdog_flag_generation(), + active_mask=self._active_mask_int(active_rank_mask), + ) def dispatch(self, token_selected_experts: torch.Tensor, @@ -221,7 +339,8 @@ def dispatch(self, runtime_max_tokens_per_rank: int, invalid_token_expert_id: Optional[int] = None, expert_id_payload_index: Optional[int] = None, - eplb_local_stats: Optional[torch.Tensor] = None): + eplb_local_stats: Optional[torch.Tensor] = None, + active_rank_mask: Optional[torch.Tensor] = None): """ Perform MoE all-to-all dispatch operation. @@ -232,6 +351,7 @@ def dispatch(self, invalid_token_expert_id: If not None, set the token_selected_experts of the invalid tokens to this expert id. This is used to notify the MoE to skip these tokens for GroupGEMM. expert_id_payload_index: The index of token_selected_experts in the input_payloads. Must be provided if invalid_token_expert_id is not None. eplb_local_stats: (Optional) [num_experts] tensor containing local statistics for EPLB + active_rank_mask: Optional uint64 CPU tensor overriding ep_group_health for this dispatch. Returns: recv_tensors: List of tensors received, each has shape [ep_size, max_tokens_per_rank, payload_num_elements_per_token] @@ -246,6 +366,7 @@ def dispatch(self, 0 ) == self.eplb_stats_num_experts, "eplb_local_stats size must match eplb_stats_num_experts" + active_rank_mask = self._get_active_rank_mask_tensor(active_rank_mask) recv_tensors, combine_payload_offset, eplb_gathered_stats = torch.ops.trtllm.moe_a2a_dispatch( token_selected_experts, input_payloads, @@ -257,7 +378,9 @@ def dispatch(self, self.top_k, self.num_experts, eplb_local_stats, + active_rank_mask, ) + self._watch_collective("dispatch", active_rank_mask) if eplb_gathered_stats.numel() == 0: eplb_gathered_stats = None @@ -287,6 +410,7 @@ def combine( runtime_max_tokens_per_rank: int, payload_in_workspace: bool = False, use_low_precision_combine: bool = False, + active_rank_mask: Optional[torch.Tensor] = None, ): """ Perform MoE all-to-all combine operation. @@ -296,6 +420,7 @@ def combine( runtime_max_tokens_per_rank: Maximum of the number of tokens of each DP rank's local batch. payload_in_workspace: If True, 'payload' is a view into 'workspace' at 'combine_payload_offset' and no staging copy is needed. If False, the op stages 'payload' into the workspace region before combining. use_low_precision_combine: If True, quantize the combine payload to FP8 for NVLink transfer (halves NVLink bandwidth usage, output precision is preserved). + active_rank_mask: Optional uint64 CPU tensor overriding ep_group_health for this combine. Returns: combined_output: [local_num_tokens, num_elements_per_token] tensor of combined results @@ -303,11 +428,13 @@ def combine( assert self._state.phase == "dispatched", "combine called before a successful dispatch" assert runtime_max_tokens_per_rank <= self.max_num_tokens, "runtime_max_tokens_per_rank must not exceed max_num_tokens" + active_rank_mask = self._get_active_rank_mask_tensor(active_rank_mask) output = torch.ops.trtllm.moe_a2a_combine( payload, self._state.local_num_tokens, self.workspace, self.metainfo, runtime_max_tokens_per_rank, self.ep_rank, self.ep_size, self.top_k, self._state.combine_payload_offset, - payload_in_workspace, use_low_precision_combine) + payload_in_workspace, use_low_precision_combine, active_rank_mask) + self._watch_collective("combine", active_rank_mask) # Reset state for next round self.reset_state() diff --git a/tensorrt_llm/_torch/modules/fused_moe/communication/communication_factory.py b/tensorrt_llm/_torch/modules/fused_moe/communication/communication_factory.py index a4ec2ceefe44..17a67deee219 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/communication/communication_factory.py +++ b/tensorrt_llm/_torch/modules/fused_moe/communication/communication_factory.py @@ -28,6 +28,7 @@ from tensorrt_llm._torch.model_config import ModelConfig from tensorrt_llm.logger import logger +from ..wide_ep_ft import get_wide_ep_ft_options from .allgather_reducescatter import AllGatherReduceScatter from .base import Communication from .deep_ep import DeepEP @@ -133,6 +134,9 @@ def create_strategy( try: enable_eplb = model_config.moe_load_balancer is not None + ep_group_health, watchdog_timeout_s, watchdog_poll_interval_s = get_wide_ep_ft_options( + model_config + ) strategy = NVLinkOneSided( mapping, num_slots, @@ -143,6 +147,9 @@ def create_strategy( dtype=act_dtype, num_experts=num_experts if enable_eplb else None, use_low_precision_combine=use_low_precision_combine, + ep_group_health=ep_group_health, + alltoall_watchdog_timeout_s=watchdog_timeout_s, + alltoall_watchdog_poll_interval_s=watchdog_poll_interval_s, ) logger.info("Selected communication strategy: NVLinkOneSided") return strategy @@ -285,6 +292,9 @@ def _create_forced_method( ) elif method in ["NVLINK_ONE_SIDED"]: enable_eplb = model_config.moe_load_balancer is not None + ep_group_health, watchdog_timeout_s, watchdog_poll_interval_s = get_wide_ep_ft_options( + model_config + ) return NVLinkOneSided( mapping, num_slots, @@ -295,6 +305,9 @@ def _create_forced_method( dtype=act_dtype, num_experts=num_experts if enable_eplb else None, use_low_precision_combine=use_low_precision_combine, + ep_group_health=ep_group_health, + alltoall_watchdog_timeout_s=watchdog_timeout_s, + alltoall_watchdog_poll_interval_s=watchdog_poll_interval_s, ) elif method == "DEEPEP": return DeepEP( diff --git a/tensorrt_llm/_torch/modules/fused_moe/communication/nvlink_one_sided.py b/tensorrt_llm/_torch/modules/fused_moe/communication/nvlink_one_sided.py index 3b634dd7072c..db5966615356 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/communication/nvlink_one_sided.py +++ b/tensorrt_llm/_torch/modules/fused_moe/communication/nvlink_one_sided.py @@ -25,11 +25,18 @@ """ import os -from typing import Dict, List, Optional, Tuple +import threading +from typing import Callable, Dict, List, Optional, Tuple import torch from tensorrt_llm._mnnvl_utils import MnnvlMemory +from tensorrt_llm._torch.alltoall_watchdog import ( + DEFAULT_ALLTOALL_WATCHDOG_POLL_INTERVAL_S, + DEFAULT_ALLTOALL_WATCHDOG_TIMEOUT_S, + AlltoAllWatchdog, + AlltoAllWatchdogTimeout, +) from tensorrt_llm.bindings import internal as _tllm_internal from tensorrt_llm.logger import logger as tllm_logger from tensorrt_llm.mapping import Mapping @@ -51,7 +58,7 @@ class NVLinkOneSided(Communication): """ # Constants from C++ (must match moeAlltoAllKernels.h) - MAX_RANKS = 64 + MAX_RANKS = 128 MAX_TOP_K = 8 MAX_PAYLOADS = 8 @@ -151,6 +158,10 @@ def __init__( dtype: Optional[torch.dtype] = None, num_experts: Optional[int] = None, use_low_precision_combine: bool = False, + ep_group_health=None, + alltoall_watchdog_timeout_s: Optional[float] = None, + alltoall_watchdog_poll_interval_s: float = DEFAULT_ALLTOALL_WATCHDOG_POLL_INTERVAL_S, + alltoall_watchdog_on_timeout: Optional[Callable[[AlltoAllWatchdogTimeout], None]] = None, ): """ Initialize NVLinkOneSided with workspace allocation. @@ -169,6 +180,12 @@ def __init__( use_low_precision_combine: If True, quantize the combine payload to FP8 for NVLink transfer (halves NVLink bandwidth usage, output precision is preserved). Corresponds to model_config.use_low_precision_moe_combine. + ep_group_health: Optional EPGroupHealth-compatible object. When present, its mask is passed to the + CUDA kernels and used by the watchdog. + alltoall_watchdog_timeout_s: Optional timeout for the host-side AlltoAll watchdog. If None, the + watchdog is disabled. + alltoall_watchdog_poll_interval_s: Poll interval for the watchdog thread. + alltoall_watchdog_on_timeout: Optional callback invoked when the watchdog reports suspects. """ super().__init__(mapping) @@ -271,6 +288,8 @@ def __init__( "mnnvl_mem": mnnvl_mem, "workspace": workspace, "metainfo": metainfo, + "watchdog_flag_generation": 0, + "watchdog_flag_generation_lock": threading.Lock(), } NVLinkOneSided._WORKSPACES[self._workspace_key] = workspace_state else: @@ -296,10 +315,35 @@ def __init__( NVLinkOneSided._WORKSPACE_REFCOUNTS.get(self._workspace_key, 0) + 1 ) self._destroyed = False + self._workspace_state = workspace_state self.mnnvl_mem = workspace_state["mnnvl_mem"] self.workspace = workspace_state["workspace"] self.moe_a2a_metainfo = workspace_state["metainfo"] self.max_num_tokens_per_rank = workspace_state["max_num_tokens_per_rank"] + if "watchdog_flag_generation_lock" not in workspace_state: + workspace_state["watchdog_flag_generation_lock"] = threading.Lock() + workspace_state["watchdog_flag_generation"] = self._read_current_flag_val() + self.ep_group_health = ep_group_health + self._alltoall_watchdog: AlltoAllWatchdog | None = None + if alltoall_watchdog_timeout_s is None and self.ep_group_health is not None: + alltoall_watchdog_timeout_s = DEFAULT_ALLTOALL_WATCHDOG_TIMEOUT_S + if alltoall_watchdog_timeout_s is not None: + self._sync_watchdog_flag_generation() + self._alltoall_watchdog = AlltoAllWatchdog.from_workspace( + workspace=self.workspace, + metainfo=self.moe_a2a_metainfo, + metainfo_index={ + "FLAG_VAL_OFFSET_INDEX": self.FLAG_VAL_OFFSET_INDEX, + "DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX": self.DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX, + "COMBINE_COMPLETION_FLAGS_OFFSET_INDEX": self.COMBINE_COMPLETION_FLAGS_OFFSET_INDEX, + }, + ep_rank=self.ep_rank, + ep_size=self.ep_size, + timeout_s=alltoall_watchdog_timeout_s, + poll_interval_s=alltoall_watchdog_poll_interval_s, + health=self.ep_group_health, + on_timeout=alltoall_watchdog_on_timeout, + ) # Initialize dispatch state self._dispatch_state = {"phase": "idle"} @@ -307,6 +351,57 @@ def __init__( # Invalid token expert ID (default to -1), the kernels in TRTLLM-gen is hard-code to support -1 only. self.invalid_token_expert_id: int = -1 + def _read_current_flag_val(self) -> int: + flag_val_offset = self.moe_a2a_metainfo[self.FLAG_VAL_OFFSET_INDEX].item() + flag_val = self.workspace[self.ep_rank, flag_val_offset : flag_val_offset + 4].view( + torch.int32 + ) + if flag_val.device.type != "cpu": + flag_val = flag_val.detach().cpu() + return int(flag_val.item()) + + def _sync_watchdog_flag_generation(self) -> None: + lock = self._workspace_state["watchdog_flag_generation_lock"] + with lock: + self._workspace_state["watchdog_flag_generation"] = max( + int(self._workspace_state["watchdog_flag_generation"]), + self._read_current_flag_val(), + ) + + def _next_watchdog_flag_generation(self) -> int: + lock = self._workspace_state["watchdog_flag_generation_lock"] + with lock: + self._workspace_state["watchdog_flag_generation"] = ( + int(self._workspace_state["watchdog_flag_generation"]) + 1 + ) + return int(self._workspace_state["watchdog_flag_generation"]) + + def _get_active_rank_mask_tensor( + self, active_rank_mask: Optional[torch.Tensor] + ) -> Optional[torch.Tensor]: + if active_rank_mask is not None: + return active_rank_mask + if self.ep_group_health is None: + return None + return torch.tensor(self.ep_group_health.get_mask_words(), dtype=torch.uint64, device="cpu") + + def _active_mask_int(self, active_rank_mask: Optional[torch.Tensor]) -> Optional[int]: + if active_rank_mask is not None: + mask_cpu = active_rank_mask.detach().cpu() + return sum(int(word) << (64 * idx) for idx, word in enumerate(mask_cpu.tolist())) + if self.ep_group_health is not None: + return self.ep_group_health.get_mask() + return None + + def _watch_collective(self, phase: str, active_rank_mask: Optional[torch.Tensor]) -> None: + if self._alltoall_watchdog is None: + return + self._alltoall_watchdog.watch( + phase=phase, + expected_flag=self._next_watchdog_flag_generation(), + active_mask=self._active_mask_int(active_rank_mask), + ) + @staticmethod def is_platform_supported() -> bool: """ @@ -326,6 +421,9 @@ def destroy(self): return self._destroyed = True + if self._alltoall_watchdog is not None: + self._alltoall_watchdog.stop(timeout_s=1.0) + self._alltoall_watchdog = None workspace_key = getattr(self, "_workspace_key", None) if workspace_key is None: return @@ -347,6 +445,7 @@ def destroy(self): self.mnnvl_mem = None self.workspace = None self.moe_a2a_metainfo = None + self._workspace_state = None self._dispatch_state = {"phase": "destroyed"} def is_workload_feasible(self, all_rank_num_tokens: List[int], num_chunks: int) -> bool: @@ -409,6 +508,7 @@ def dispatch( assert eplb_local_stats.size(0) == self.eplb_stats_num_experts, ( "eplb_local_stats size must match eplb_stats_num_experts" ) + active_rank_mask = self._get_active_rank_mask_tensor(kwargs.get("active_rank_mask")) recv_buffers, combine_payload_offset, eplb_gathered_stats = ( torch.ops.trtllm.moe_a2a_dispatch( @@ -422,8 +522,10 @@ def dispatch( self.top_k, self.num_experts, eplb_local_stats, + active_rank_mask, ) ) + self._watch_collective("dispatch", active_rank_mask) if eplb_gathered_stats.numel() == 0: eplb_gathered_stats = None self._dispatch_state["eplb_gathered_stats"] = eplb_gathered_stats @@ -526,6 +628,7 @@ def combine( raise ValueError( f"final_hidden_states must be 2D or 3D, got {final_hidden_states.dim()}D" ) + active_rank_mask = self._get_active_rank_mask_tensor(kwargs.get("active_rank_mask")) output = torch.ops.trtllm.moe_a2a_combine( final_hidden_states, int(local_num_tokens), @@ -538,7 +641,9 @@ def combine( int(combine_payload_offset), bool(self.payload_in_workspace), bool(self.use_low_precision_combine), + active_rank_mask, ) + self._watch_collective("combine", active_rank_mask) # Reset state for next round self.reset_state() diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py index acae5cd37ba3..64bc178e64f3 100755 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py @@ -20,6 +20,7 @@ Fp4QuantizedTensor) from .interface import AlltoallMethodType, MoE from .quantization import UnquantizedFusedMoEMethod +from .wide_ep_ft import get_wide_ep_ft_options # isort: off from .quantization import ( @@ -324,6 +325,8 @@ def __init__( dtype, self.num_experts if self.layer_load_balancer else None, ) + ep_group_health, watchdog_timeout_s, watchdog_poll_interval_s = ( + get_wide_ep_ft_options(model_config)) self.moe_a2a = MoeAlltoAll( mapping=self.mapping, @@ -333,6 +336,10 @@ def __init__( workspace_size_per_rank=workspace_size, num_experts=self.num_experts if self.layer_load_balancer else None, + ep_group_health=ep_group_health, + alltoall_watchdog_timeout_s=watchdog_timeout_s, + alltoall_watchdog_poll_interval_s= + watchdog_poll_interval_s, ) elif self.alltoall_method_type == AlltoallMethodType.DeepEP or self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency: raise NotImplementedError( diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py index 820d92a95e4a..6f59870c374b 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py @@ -36,6 +36,7 @@ from ...utils import ActivationType, AuxStreamType, Fp4QuantizedTensor from .interface import AlltoallMethodType, MoE, MoEWeightLoadingMode from .moe_op_backend import MoEOpBackend, get_op_backend +from .wide_ep_ft import get_wide_ep_ft_options # isort: off from .quantization import ( @@ -291,6 +292,8 @@ def __init__( ep_size, self.routing_method.experts_per_token, max_num_tokens, hidden_size, dtype, self.num_experts if self.layer_load_balancer else None) + ep_group_health, watchdog_timeout_s, watchdog_poll_interval_s = ( + get_wide_ep_ft_options(model_config)) self.moe_a2a = MoeAlltoAll( mapping=self.mapping, @@ -299,7 +302,11 @@ def __init__( num_slots=self.num_slots, workspace_size_per_rank=workspace_size, num_experts=self.num_experts - if self.layer_load_balancer else None) + if self.layer_load_balancer else None, + ep_group_health=ep_group_health, + alltoall_watchdog_timeout_s=watchdog_timeout_s, + alltoall_watchdog_poll_interval_s= + watchdog_poll_interval_s) elif self.alltoall_method_type == AlltoallMethodType.DeepEP or self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency: raise NotImplementedError( "DeepEP and DeepEPLowLatency are not supported for TRTLLMGenFusedMoE yet" diff --git a/tensorrt_llm/_torch/modules/fused_moe/wide_ep_ft.py b/tensorrt_llm/_torch/modules/fused_moe/wide_ep_ft.py new file mode 100644 index 000000000000..69e5c53f061e --- /dev/null +++ b/tensorrt_llm/_torch/modules/fused_moe/wide_ep_ft.py @@ -0,0 +1,82 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Shared WideEP fault-tolerance options for MoE communication paths.""" + +from __future__ import annotations + +import os +from typing import Any, Optional + +from tensorrt_llm._torch.alltoall_watchdog import ( + DEFAULT_ALLTOALL_WATCHDOG_POLL_INTERVAL_S, + DEFAULT_ALLTOALL_WATCHDOG_TIMEOUT_S, +) + +from .ep_group_health import EPGroupHealth + +_ENABLE_ENV = "TRTLLM_ENABLE_WIDE_EP_FT" +_TIMEOUT_ENV = "TRTLLM_ALLTOALL_WATCHDOG_TIMEOUT_S" +_POLL_INTERVAL_ENV = "TRTLLM_ALLTOALL_WATCHDOG_POLL_INTERVAL_S" + +_HEALTH_KEY = "wide_ep_ft_ep_group_health" +_TIMEOUT_KEY = "alltoall_watchdog_timeout_s" +_POLL_INTERVAL_KEY = "alltoall_watchdog_poll_interval_s" + + +def _env_enabled() -> bool: + return os.environ.get(_ENABLE_ENV, "0").lower() in {"1", "true", "yes", "on"} + + +def _float_option(extra_attrs: dict, key: str, env_name: str, default: float) -> float: + if key in extra_attrs: + return float(extra_attrs[key]) + if env_name in os.environ: + return float(os.environ[env_name]) + return default + + +def get_wide_ep_ft_options( + model_config: Any, +) -> tuple[Optional[EPGroupHealth], Optional[float], float]: + """Return the shared EP health object and watchdog timing for a model. + + WideEP FT remains opt-in until the integration PR wires a public model + option. Callers can either inject ``wide_ep_ft_ep_group_health`` through + ``ModelConfig.extra_attrs`` or set ``TRTLLM_ENABLE_WIDE_EP_FT=1`` to create + one process-local health object shared by all MoE communication layers. + """ + + extra_attrs = getattr(model_config, "extra_attrs", {}) + health = extra_attrs.get(_HEALTH_KEY) or extra_attrs.get("ep_group_health") + if health is None and _env_enabled(): + health = EPGroupHealth(model_config.mapping.moe_ep_size) + extra_attrs[_HEALTH_KEY] = health + + poll_interval_s = _float_option( + extra_attrs, + _POLL_INTERVAL_KEY, + _POLL_INTERVAL_ENV, + DEFAULT_ALLTOALL_WATCHDOG_POLL_INTERVAL_S, + ) + if health is None: + return None, None, poll_interval_s + + timeout_s = _float_option( + extra_attrs, + _TIMEOUT_KEY, + _TIMEOUT_ENV, + DEFAULT_ALLTOALL_WATCHDOG_TIMEOUT_S, + ) + return health, timeout_s, poll_interval_s diff --git a/tests/unittest/_torch/modules/moe/test_moe_comm.py b/tests/unittest/_torch/modules/moe/test_moe_comm.py index c59751f42025..edb127a3c355 100644 --- a/tests/unittest/_torch/modules/moe/test_moe_comm.py +++ b/tests/unittest/_torch/modules/moe/test_moe_comm.py @@ -114,6 +114,10 @@ # to avoid _WORKSPACE singleton assertion failures. NVLINK_WORKSPACE_MB = "512" +# Must match kRankMaskWords in +# cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h. +EP_MASK_NUM_WORDS = 2 + # ============================================================================ # Test Configuration @@ -192,6 +196,182 @@ def _safe_cpu(t: Optional[torch.Tensor]) -> Optional[torch.Tensor]: return t.cpu() +def _ep_mask_words(ep_size: int, dead_ranks: Set[int]) -> torch.Tensor: + """Build the uint64[EP_MASK_NUM_WORDS] CPU tensor expected by moe_a2a ops.""" + mask_int = ((1 << ep_size) - 1) & ~sum(1 << rank for rank in dead_ranks) + word_mask = (1 << 64) - 1 + words = [(mask_int >> (i * 64)) & word_mask for i in range(EP_MASK_NUM_WORDS)] + return torch.tensor(words, dtype=torch.uint64, device="cpu") + + +def _make_rank_mask_payload(local_num_tokens: int, hidden_size: int, rank: int) -> torch.Tensor: + """Make deterministic per-rank payloads for exact equality assertions.""" + base = torch.arange(local_num_tokens * hidden_size, dtype=torch.bfloat16, device="cuda").view( + local_num_tokens, hidden_size + ) + return base + (rank * 1000.0) + + +def _read_nvlink_topk_target_ranks( + comm: NVLinkOneSided, + max_num_tokens: int, + top_k: int, +) -> torch.Tensor: + """Read topk_target_ranks[max_num_tokens, top_k] from NVLinkOneSided workspace.""" + from tensorrt_llm.bindings import internal as _tllm_internal + + offset_index = int(_tllm_internal.thop.MOE_A2A_TOPK_TARGET_RANKS_OFFSET_INDEX) + offset = comm.moe_a2a_metainfo[offset_index].item() + raw = comm.workspace[ + comm.ep_rank, + offset : offset + max_num_tokens * top_k * 4, + ] + return raw.view(torch.int32).view(max_num_tokens, top_k).cpu() + + +def _read_nvlink_topk_send_indices( + comm: NVLinkOneSided, + max_num_tokens: int, + top_k: int, +) -> torch.Tensor: + """Read topk_send_indices[max_num_tokens, top_k] from NVLinkOneSided workspace.""" + from tensorrt_llm.bindings import internal as _tllm_internal + + offset_index = int(_tllm_internal.thop.MOE_A2A_TOPK_SEND_INDICES_OFFSET_INDEX) + offset = comm.moe_a2a_metainfo[offset_index].item() + raw = comm.workspace[ + comm.ep_rank, + offset : offset + max_num_tokens * top_k * 4, + ] + return raw.view(torch.int32).view(max_num_tokens, top_k).cpu() + + +def _run_nvlink_rank_mask_dispatch( + comm: NVLinkOneSided, + token_selected_experts: torch.Tensor, + payload: torch.Tensor, + runtime_max_tokens_per_rank: int, + active_rank_mask: Optional[torch.Tensor], +) -> Tuple[List[torch.Tensor], int, torch.Tensor, torch.Tensor]: + """Run raw NVLink one-sided dispatch with an optional active rank mask.""" + recv_tensors, combine_payload_offset, _ = torch.ops.trtllm.moe_a2a_dispatch( + token_selected_experts, + [payload], + comm.workspace, + comm.moe_a2a_metainfo, + runtime_max_tokens_per_rank, + comm.ep_rank, + comm.ep_size, + comm.top_k, + comm.num_experts, + None, # eplb_local_stats + active_rank_mask, + ) + + topk_target_ranks = _read_nvlink_topk_target_ranks( + comm, + runtime_max_tokens_per_rank, + comm.top_k, + ) + topk_send_indices = _read_nvlink_topk_send_indices( + comm, + runtime_max_tokens_per_rank, + comm.top_k, + ) + return recv_tensors, int(combine_payload_offset), topk_target_ranks, topk_send_indices + + +def _run_nvlink_rank_mask_combine( + comm: NVLinkOneSided, + combine_payload: torch.Tensor, + local_num_tokens: int, + runtime_max_tokens_per_rank: int, + combine_payload_offset: int, + active_rank_mask: Optional[torch.Tensor], +) -> torch.Tensor: + """Run raw NVLink one-sided combine with an optional active rank mask.""" + return torch.ops.trtllm.moe_a2a_combine( + combine_payload, + local_num_tokens, + comm.workspace, + comm.moe_a2a_metainfo, + runtime_max_tokens_per_rank, + comm.ep_rank, + comm.ep_size, + comm.top_k, + combine_payload_offset, + False, # payload_in_workspace + False, # use_low_precision + active_rank_mask, + ) + + +def _run_nvlink_rank_mask_dispatch_combine( + comm: NVLinkOneSided, + token_selected_experts: torch.Tensor, + payload: torch.Tensor, + runtime_max_tokens_per_rank: int, + active_rank_mask: Optional[torch.Tensor], +) -> Tuple[torch.Tensor, torch.Tensor]: + """Run raw NVLink one-sided dispatch/combine with an optional active rank mask.""" + recv_tensors, combine_payload_offset, topk_target_ranks, _ = _run_nvlink_rank_mask_dispatch( + comm, + token_selected_experts, + payload, + runtime_max_tokens_per_rank, + active_rank_mask, + ) + combined = _run_nvlink_rank_mask_combine( + comm, + recv_tensors[0], + token_selected_experts.size(0), + runtime_max_tokens_per_rank, + combine_payload_offset, + active_rank_mask, + ) + return combined.cpu(), topk_target_ranks + + +def _expected_nvlink_rank_mask_combine_output( + comm: NVLinkOneSided, + payload: torch.Tensor, + topk_target_ranks: torch.Tensor, + topk_send_indices: torch.Tensor, + local_num_tokens: int, + runtime_max_tokens_per_rank: int, + dead_ranks: Set[int], +) -> torch.Tensor: + """Compute combine output from dispatched workspace while skipping dead ranks.""" + from tensorrt_llm.bindings import internal as _tllm_internal + + hidden_size = payload.shape[-1] + expected = torch.zeros( + (local_num_tokens, hidden_size), + dtype=payload.dtype, + device=payload.device, + ) + payload_offset_index = int(_tllm_internal.thop.MOE_A2A_PAYLOAD_DATA_OFFSET_INDEX) + payload_offset = comm.moe_a2a_metainfo[payload_offset_index].item() + bytes_per_rank = ( + comm.ep_size * runtime_max_tokens_per_rank * hidden_size * payload.element_size() + ) + + for token_idx in range(local_num_tokens): + for k in range(comm.top_k): + target_rank = int(topk_target_ranks[token_idx, k].item()) + dst_idx = int(topk_send_indices[token_idx, k].item()) + if dst_idx < 0 or target_rank in dead_ranks: + continue + raw = comm.workspace[target_rank, payload_offset : payload_offset + bytes_per_rank] + recv_payload = raw.view(payload.dtype).view( + comm.ep_size, + runtime_max_tokens_per_rank, + hidden_size, + ) + expected[token_idx] += recv_payload[comm.ep_rank, dst_idx] + return expected.cpu() + + # ============================================================================ # Source Encoding Utilities # ============================================================================ @@ -856,6 +1036,233 @@ def _worker_full_pipeline(config: CommTestConfig) -> dict: comm.destroy() +def _make_rank_mask_config( + ep_size: int, + local_num_tokens: int, + top_k: int, +) -> CommTestConfig: + """Build the small NVLinkOneSided config used by active-rank-mask tests.""" + return CommTestConfig( + comm_type=COMM_NVLINK_ONE_SIDED, + ep_size=ep_size, + num_experts=FIXED_NUM_EXPERTS, + top_k=top_k, + hidden_size=1024, + all_num_tokens=[local_num_tokens] * ep_size, + ) + + +def _worker_rank_mask_all_active_matches_no_mask(config: CommTestConfig) -> dict: + """Check that all-active active_rank_mask is bit-identical to no mask.""" + rank = tllm.mpi_rank() + torch.cuda.set_device(rank) + + comm = None + try: + mapping = Mapping( + rank=rank, + tp_size=config.ep_size, + moe_ep_size=config.ep_size, + world_size=config.ep_size, + ) + comm = create_comm_object(config.comm_type, mapping, config) + + local_num_tokens = config.all_num_tokens[rank] + torch.manual_seed(0xA2A + rank) + token_selected_experts = torch.randint( + 0, + config.num_experts, + (local_num_tokens, config.top_k), + dtype=torch.int32, + device="cuda", + ) + payload = _make_rank_mask_payload(local_num_tokens, config.hidden_size, rank) + + out_no_mask, topk_no_mask = _run_nvlink_rank_mask_dispatch_combine( + comm, + token_selected_experts, + payload, + local_num_tokens, + active_rank_mask=None, + ) + out_all_active, topk_all_active = _run_nvlink_rank_mask_dispatch_combine( + comm, + token_selected_experts, + payload, + local_num_tokens, + active_rank_mask=_ep_mask_words(config.ep_size, dead_ranks=set()), + ) + + return { + "rank": rank, + "output_eq": torch.equal(out_no_mask, out_all_active), + "topk_eq": torch.equal(topk_no_mask, topk_all_active), + } + except Exception: + traceback.print_exc() + raise + finally: + if comm is not None and hasattr(comm, "destroy"): + comm.destroy() + + +def _expected_target_ranks( + token_selected_experts: torch.Tensor, + num_experts: int, + ep_size: int, +) -> torch.Tensor: + """Map each selected expert to its target EP rank using the kernel partition rule.""" + token_selected_experts_cpu = token_selected_experts.cpu() + expected = torch.empty_like(token_selected_experts_cpu) + for token_idx in range(token_selected_experts_cpu.shape[0]): + for k in range(token_selected_experts_cpu.shape[1]): + expert_id = int(token_selected_experts_cpu[token_idx, k].item()) + expected[token_idx, k] = _expert_id_to_rank(expert_id, num_experts, ep_size) + return expected + + +def _worker_rank_mask_one_rank_masked( + config: CommTestConfig, + dead_rank: int, +) -> dict: + """Run dispatch/combine with one EP rank omitted from active_rank_mask.""" + rank = tllm.mpi_rank() + torch.cuda.set_device(rank) + + comm = None + try: + mapping = Mapping( + rank=rank, + tp_size=config.ep_size, + moe_ep_size=config.ep_size, + world_size=config.ep_size, + ) + # All ranks must initialize the symmetric workspace before the dead rank + # stops participating in dispatch/combine. + comm = create_comm_object(config.comm_type, mapping, config) + + if rank == dead_rank: + MPI.COMM_WORLD.barrier() + return {"rank": rank, "status": "dead"} + + local_num_tokens = config.all_num_tokens[rank] + torch.manual_seed(0xA2A + rank) + token_selected_experts = torch.randint( + 0, + config.num_experts, + (local_num_tokens, config.top_k), + dtype=torch.int32, + device="cuda", + ) + payload = _make_rank_mask_payload(local_num_tokens, config.hidden_size, rank) + mask = _ep_mask_words(config.ep_size, dead_ranks={dead_rank}) + + combined, topk_target_ranks = _run_nvlink_rank_mask_dispatch_combine( + comm, + token_selected_experts, + payload, + local_num_tokens, + active_rank_mask=mask, + ) + expected_target_ranks = _expected_target_ranks( + token_selected_experts, + config.num_experts, + config.ep_size, + ) + + MPI.COMM_WORLD.barrier() + return { + "rank": rank, + "status": "alive", + "combined": combined, + "topk_target_ranks": topk_target_ranks, + "expected_target_ranks": expected_target_ranks, + } + except Exception: + traceback.print_exc() + raise + finally: + if comm is not None and hasattr(comm, "destroy"): + comm.destroy() + + +def _worker_rank_mask_inactive_before_combine( + config: CommTestConfig, + dead_rank: int, +) -> dict: + """Dispatch with all ranks active, then omit one rank from combine's active mask.""" + rank = tllm.mpi_rank() + torch.cuda.set_device(rank) + + comm = None + try: + mapping = Mapping( + rank=rank, + tp_size=config.ep_size, + moe_ep_size=config.ep_size, + world_size=config.ep_size, + ) + comm = create_comm_object(config.comm_type, mapping, config) + + local_num_tokens = config.all_num_tokens[rank] + torch.manual_seed(0xA2A + rank) + token_selected_experts = torch.randint( + 0, + config.num_experts, + (local_num_tokens, config.top_k), + dtype=torch.int32, + device="cuda", + ) + payload = _make_rank_mask_payload(local_num_tokens, config.hidden_size, rank) + + recv_tensors, combine_payload_offset, topk_target_ranks, topk_send_indices = ( + _run_nvlink_rank_mask_dispatch( + comm, + token_selected_experts, + payload, + local_num_tokens, + active_rank_mask=_ep_mask_words(config.ep_size, dead_ranks=set()), + ) + ) + + if rank == dead_rank: + MPI.COMM_WORLD.barrier() + return {"rank": rank, "status": "dead"} + + dead_ranks = {dead_rank} + expected = _expected_nvlink_rank_mask_combine_output( + comm, + payload, + topk_target_ranks, + topk_send_indices, + local_num_tokens, + local_num_tokens, + dead_ranks, + ) + combined = _run_nvlink_rank_mask_combine( + comm, + recv_tensors[0], + local_num_tokens, + local_num_tokens, + combine_payload_offset, + active_rank_mask=_ep_mask_words(config.ep_size, dead_ranks=dead_ranks), + ).cpu() + + MPI.COMM_WORLD.barrier() + return { + "rank": rank, + "status": "alive", + "combined": combined, + "expected": expected, + } + except Exception: + traceback.print_exc() + raise + finally: + if comm is not None and hasattr(comm, "destroy"): + comm.destroy() + + # ============================================================================ # Verification Functions # ============================================================================ @@ -1581,7 +1988,7 @@ def _make_postquant_test_params(): @pytest.fixture(autouse=True) -def setup_test(): +def setup_test() -> None: torch.manual_seed(0x1234) tllm.logger.set_level("error") @@ -1630,6 +2037,143 @@ def _run_full_test(mpi_pool_executor, config: CommTestConfig): verify_combine_results(all_results, config, rtol=0.02, atol=0.15) +def _skip_if_rank_mask_config_unsupported(config: CommTestConfig) -> None: + """Skip active-rank-mask tests when NVLinkOneSided cannot run locally.""" + skip_reason = check_platform_support(config.comm_type) + if skip_reason: + pytest.skip(skip_reason) + + skip_reason = check_feasibility(config.comm_type, config) + if skip_reason: + pytest.skip(skip_reason) + + if config.ep_size > torch.cuda.device_count(): + pytest.skip(f"Need {config.ep_size} GPUs but only {torch.cuda.device_count()} available") + + +def _run_rank_mask_all_active_test( + mpi_pool_executor, + local_num_tokens: int, + top_k: int, +) -> None: + ep_size = mpi_pool_executor.num_workers + config = _make_rank_mask_config(ep_size, local_num_tokens, top_k) + _skip_if_rank_mask_config_unsupported(config) + + results = list( + mpi_pool_executor.map( + _worker_rank_mask_all_active_matches_no_mask, + *zip(*[(config,)] * config.ep_size), + ) + ) + + for result in results: + rank = result["rank"] + assert result["output_eq"], ( + f"rank {rank}: combine output differs between no-mask and all-active mask" + ) + assert result["topk_eq"], ( + f"rank {rank}: topk_target_ranks differ between no-mask and all-active mask" + ) + + +def _run_rank_mask_one_rank_masked_test( + mpi_pool_executor, + dead_rank: int, + local_num_tokens: int, + top_k: int, +) -> None: + ep_size = mpi_pool_executor.num_workers + config = _make_rank_mask_config(ep_size, local_num_tokens, top_k) + _skip_if_rank_mask_config_unsupported(config) + assert 0 <= dead_rank < ep_size + + worker_args = [(config, dead_rank)] * config.ep_size + results = list( + mpi_pool_executor.map( + _worker_rank_mask_one_rank_masked, + *zip(*worker_args), + ) + ) + + saw_dead = False + for result in results: + rank = result["rank"] + if result["status"] == "dead": + assert rank == dead_rank + saw_dead = True + continue + + assert result["status"] == "alive" + combined = result["combined"] + topk_target_ranks = result["topk_target_ranks"] + expected_target_ranks = result["expected_target_ranks"] + + assert combined.shape == (local_num_tokens, config.hidden_size) + + live_topk = topk_target_ranks[:local_num_tokens] + live_expected = expected_target_ranks[:local_num_tokens] + for token_idx in range(local_num_tokens): + seen_ranks: Set[int] = set() + for k in range(top_k): + expected = int(live_expected[token_idx, k].item()) + got = int(live_topk[token_idx, k].item()) + if expected == dead_rank: + assert got == -1, ( + f"rank {rank} token {token_idx} k={k}: token routed to dead " + f"rank {dead_rank} should have been dropped (got={got})" + ) + elif expected in seen_ranks: + assert got == -1 + else: + assert got == expected, ( + f"rank {rank} token {token_idx} k={k}: target rank mismatch " + f"(expected={expected}, got={got})" + ) + seen_ranks.add(expected) + + assert saw_dead, f"dead rank {dead_rank} did not appear in results" + + +def _run_rank_mask_inactive_before_combine_test( + mpi_pool_executor, + dead_rank: int, + local_num_tokens: int, + top_k: int, +) -> None: + ep_size = mpi_pool_executor.num_workers + config = _make_rank_mask_config(ep_size, local_num_tokens, top_k) + _skip_if_rank_mask_config_unsupported(config) + assert 0 <= dead_rank < ep_size + + worker_args = [(config, dead_rank)] * config.ep_size + results = list( + mpi_pool_executor.map( + _worker_rank_mask_inactive_before_combine, + *zip(*worker_args), + ) + ) + + saw_dead = False + for result in results: + rank = result["rank"] + if result["status"] == "dead": + assert rank == dead_rank + saw_dead = True + continue + + assert result["status"] == "alive" + combined = result["combined"] + expected = result["expected"] + assert combined is not None + assert expected is not None + assert torch.equal(combined, expected), ( + f"rank {rank}: combine output included a rank masked inactive before combine" + ) + + assert saw_dead, f"dead rank {dead_rank} did not appear in results" + + # ============================================================================ # Test Class # ============================================================================ @@ -1682,3 +2226,69 @@ def test_moe_comm_postquant(self, mpi_pool_executor, config: CommTestConfig): def test_moe_comm_non_divisible_ep(self, mpi_pool_executor, config: CommTestConfig): """Verify NVLinkOneSided with non-divisible EP (num_experts % ep_size != 0).""" _run_full_test(mpi_pool_executor, config) + + @pytest.mark.threadleak(enabled=False) + @pytest.mark.parametrize( + "mpi_pool_executor,local_num_tokens,top_k", + [ + (4, 16, 2), + (4, 32, 4), + ], + indirect=["mpi_pool_executor"], + ) + def test_moe_comm_rank_mask_all_active_matches_no_mask( + self, + mpi_pool_executor, + local_num_tokens: int, + top_k: int, + ) -> None: + """Verify all-active active_rank_mask matches omitted mask for NVLinkOneSided.""" + _run_rank_mask_all_active_test(mpi_pool_executor, local_num_tokens, top_k) + + @pytest.mark.threadleak(enabled=False) + @pytest.mark.parametrize( + "mpi_pool_executor,dead_rank,local_num_tokens,top_k", + [ + (4, 2, 16, 2), + (4, 0, 16, 4), + (4, 3, 32, 4), + ], + indirect=["mpi_pool_executor"], + ) + def test_moe_comm_rank_mask_one_rank_masked_completes( + self, + mpi_pool_executor, + dead_rank: int, + local_num_tokens: int, + top_k: int, + ) -> None: + """Verify masked-dead rank is skipped by raw NVLinkOneSided moe_a2a ops.""" + _run_rank_mask_one_rank_masked_test( + mpi_pool_executor, + dead_rank, + local_num_tokens, + top_k, + ) + + @pytest.mark.threadleak(enabled=False) + @pytest.mark.parametrize( + "mpi_pool_executor,dead_rank,local_num_tokens,top_k", + [ + (4, 2, 16, 2), + ], + indirect=["mpi_pool_executor"], + ) + def test_moe_comm_rank_mask_inactive_before_combine_skips_stale_dispatch_slots( + self, + mpi_pool_executor, + dead_rank: int, + local_num_tokens: int, + top_k: int, + ) -> None: + """Verify combine skips slots from a rank masked inactive after dispatch.""" + _run_rank_mask_inactive_before_combine_test( + mpi_pool_executor, + dead_rank, + local_num_tokens, + top_k, + ) diff --git a/tests/unittest/_torch/modules/test_alltoall_watchdog.py b/tests/unittest/_torch/modules/test_alltoall_watchdog.py new file mode 100644 index 000000000000..b44140dbd21f --- /dev/null +++ b/tests/unittest/_torch/modules/test_alltoall_watchdog.py @@ -0,0 +1,350 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit tests for AlltoAllWatchdog (WideEP fault tolerance, PR 1a.4).""" + +import threading +import time +from collections.abc import Callable +from types import SimpleNamespace + +import pytest +import torch + +from tensorrt_llm._torch.alltoall_watchdog import ( + DEFAULT_ALLTOALL_WATCHDOG_POLL_INTERVAL_S, + DEFAULT_ALLTOALL_WATCHDOG_TIMEOUT_S, + UNKNOWN_COMPLETION_FLAG, + AlltoAllWatchdog, + AlltoAllWatchdogTimeout, + CompletionFlagReadTimeout, +) +from tensorrt_llm._torch.modules.fused_moe.ep_group_health import EPGroupHealth +from tensorrt_llm._torch.modules.fused_moe.wide_ep_ft import get_wide_ep_ft_options + + +class FakeCompletionFlagReader: + """Thread-safe completion flag reader for pure-Python watchdog tests.""" + + def __init__(self, ep_size: int) -> None: + self._lock = threading.Lock() + self._flags = { + "dispatch": [0 for _ in range(ep_size)], + "combine": [0 for _ in range(ep_size)], + } + + def set_flags(self, phase: str, flags: list[int]) -> None: + with self._lock: + self._flags[phase] = list(flags) + + def read_completion_flags(self, phase: str) -> tuple[int, ...]: + with self._lock: + return tuple(self._flags[phase]) + + +class TimeoutCompletionFlagReader: + def read_completion_flags(self, phase: str) -> tuple[int, ...]: + raise CompletionFlagReadTimeout("blocked") + + +class OneGoodReadThenTimeoutReader: + def __init__(self, flags: tuple[int, ...]) -> None: + self._flags = flags + self._read_count = 0 + + def read_completion_flags(self, phase: str) -> tuple[int, ...]: + self._read_count += 1 + if self._read_count == 1: + return self._flags + raise CompletionFlagReadTimeout("blocked") + + +def _wait_for(predicate: Callable[[], bool], timeout_s: float = 1.0) -> None: + deadline = time.monotonic() + timeout_s + while time.monotonic() < deadline: + if predicate(): + return + time.sleep(0.005) + raise AssertionError("condition was not reached before timeout") + + +def test_watchdog_completes_when_all_active_flags_arrive() -> None: + health = EPGroupHealth(4) + reader = FakeCompletionFlagReader(ep_size=4) + events: list[AlltoAllWatchdogTimeout] = [] + + with AlltoAllWatchdog( + ep_size=4, + ep_rank=0, + completion_reader=reader, + timeout_s=0.2, + poll_interval_s=0.005, + health=health, + on_timeout=events.append, + ) as watchdog: + watchdog.watch(phase="dispatch", expected_flag=1) + reader.set_flags("dispatch", [1, 1, 1, 1]) + assert watchdog.wait_until_idle(timeout_s=1.0) + + assert events == [] + assert health.all_active() is True + + +def test_watchdog_defaults_match_design_doc() -> None: + reader = FakeCompletionFlagReader(ep_size=1) + watchdog = AlltoAllWatchdog(ep_size=1, ep_rank=0, completion_reader=reader) + + assert watchdog._timeout_s == DEFAULT_ALLTOALL_WATCHDOG_TIMEOUT_S + assert watchdog._poll_interval_s == DEFAULT_ALLTOALL_WATCHDOG_POLL_INTERVAL_S + + +def test_watchdog_stop_is_terminal() -> None: + reader = FakeCompletionFlagReader(ep_size=1) + watchdog = AlltoAllWatchdog( + ep_size=1, + ep_rank=0, + completion_reader=reader, + timeout_s=0.2, + poll_interval_s=0.005, + ) + watchdog.start() + watchdog.stop(timeout_s=1.0) + + with pytest.raises(RuntimeError, match="stopped AlltoAllWatchdog"): + watchdog.start() + with pytest.raises(RuntimeError, match="stopped AlltoAllWatchdog"): + watchdog.watch(phase="dispatch", expected_flag=1) + + +def test_wide_ep_ft_options_create_shared_health_when_enabled( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setenv("TRTLLM_ENABLE_WIDE_EP_FT", "1") + model_config = SimpleNamespace( + extra_attrs={}, + mapping=SimpleNamespace(moe_ep_size=4), + ) + + health, timeout_s, poll_interval_s = get_wide_ep_ft_options(model_config) + health_again, timeout_again_s, poll_again_s = get_wide_ep_ft_options(model_config) + + assert isinstance(health, EPGroupHealth) + assert health_again is health + assert timeout_s == DEFAULT_ALLTOALL_WATCHDOG_TIMEOUT_S + assert timeout_again_s == timeout_s + assert poll_interval_s == DEFAULT_ALLTOALL_WATCHDOG_POLL_INTERVAL_S + assert poll_again_s == poll_interval_s + + +def test_watchdog_timeout_reports_and_marks_missing_remote_ranks() -> None: + health = EPGroupHealth(4) + reader = FakeCompletionFlagReader(ep_size=4) + reader.set_flags("dispatch", [1, 0, 1, 0]) + events: list[AlltoAllWatchdogTimeout] = [] + + with AlltoAllWatchdog( + ep_size=4, + ep_rank=0, + completion_reader=reader, + timeout_s=0.02, + poll_interval_s=0.005, + health=health, + on_timeout=events.append, + ) as watchdog: + watchdog.watch(phase="dispatch", expected_flag=1) + _wait_for(lambda: len(events) == 1) + assert watchdog.wait_until_idle(timeout_s=1.0) + + event = events[0] + assert event.phase == "dispatch" + assert event.expected_flag == 1 + assert event.observed_flags == (1, 0, 1, 0) + assert event.missing_ranks == (1, 3) + assert event.marked_failed_ranks == (1, 3) + assert health.get_failed_ranks() == frozenset({1, 3}) + + +def test_watchdog_ignores_ranks_already_failed_in_health_mask() -> None: + health = EPGroupHealth(4) + assert health.mark_failed(2) is True + reader = FakeCompletionFlagReader(ep_size=4) + reader.set_flags("dispatch", [1, 1, 0, 1]) + events: list[AlltoAllWatchdogTimeout] = [] + + with AlltoAllWatchdog( + ep_size=4, + ep_rank=0, + completion_reader=reader, + timeout_s=0.05, + poll_interval_s=0.005, + health=health, + on_timeout=events.append, + ) as watchdog: + watchdog.watch(phase="dispatch", expected_flag=1) + assert watchdog.wait_until_idle(timeout_s=1.0) + + assert events == [] + assert health.get_failed_ranks() == frozenset({2}) + + +def test_watchdog_reports_local_missing_but_does_not_mark_local_failed() -> None: + health = EPGroupHealth(4) + reader = FakeCompletionFlagReader(ep_size=4) + reader.set_flags("combine", [0, 2, 2, 2]) + events: list[AlltoAllWatchdogTimeout] = [] + + with AlltoAllWatchdog( + ep_size=4, + ep_rank=0, + completion_reader=reader, + timeout_s=0.02, + poll_interval_s=0.005, + health=health, + on_timeout=events.append, + ) as watchdog: + watchdog.watch(phase="combine", expected_flag=2) + _wait_for(lambda: len(events) == 1) + + event = events[0] + assert event.missing_ranks == (0,) + assert event.marked_failed_ranks == () + assert health.get_failed_ranks() == frozenset() + + +def test_watchdog_poll_timeout_without_snapshot_fails_closed() -> None: + health = EPGroupHealth(3) + events: list[AlltoAllWatchdogTimeout] = [] + + with AlltoAllWatchdog( + ep_size=3, + ep_rank=0, + completion_reader=TimeoutCompletionFlagReader(), + timeout_s=0.02, + poll_interval_s=0.005, + health=health, + on_timeout=events.append, + ) as watchdog: + watchdog.watch(phase="dispatch", expected_flag=1) + _wait_for(lambda: len(events) == 1) + + event = events[0] + assert event.poll_timed_out is True + assert event.observed_flags == (UNKNOWN_COMPLETION_FLAG,) * 3 + assert event.missing_ranks == (0, 1, 2) + assert event.marked_failed_ranks == () + assert health.all_active() is True + + +def test_watchdog_poll_timeout_with_prior_snapshot_marks_known_missing_rank() -> None: + health = EPGroupHealth(3) + events: list[AlltoAllWatchdogTimeout] = [] + + with AlltoAllWatchdog( + ep_size=3, + ep_rank=0, + completion_reader=OneGoodReadThenTimeoutReader((1, 0, 1)), + timeout_s=0.02, + poll_interval_s=0.005, + health=health, + on_timeout=events.append, + ) as watchdog: + watchdog.watch(phase="dispatch", expected_flag=1) + _wait_for(lambda: len(events) == 1) + + event = events[0] + assert event.poll_timed_out is True + assert event.observed_flags == (1, 0, 1) + assert event.missing_ranks == (1,) + assert event.marked_failed_ranks == (1,) + assert health.get_failed_ranks() == frozenset({1}) + + +def test_watchdog_preserves_fifo_order_and_clears_followups_after_timeout() -> None: + health = EPGroupHealth(3) + reader = FakeCompletionFlagReader(ep_size=3) + reader.set_flags("dispatch", [1, 0, 1]) + reader.set_flags("combine", [0, 0, 0]) + events: list[AlltoAllWatchdogTimeout] = [] + + with AlltoAllWatchdog( + ep_size=3, + ep_rank=0, + completion_reader=reader, + timeout_s=0.02, + poll_interval_s=0.005, + health=health, + on_timeout=events.append, + ) as watchdog: + watchdog.watch(phase="dispatch", expected_flag=1) + watchdog.watch(phase="combine", expected_flag=2) + _wait_for(lambda: len(events) == 1) + assert watchdog.wait_until_idle(timeout_s=1.0) + time.sleep(0.05) + + assert len(events) == 1 + assert events[0].phase == "dispatch" + assert events[0].missing_ranks == (1,) + assert health.get_failed_ranks() == frozenset({1}) + + +def test_watchdog_from_workspace_reads_phase_specific_offsets() -> None: + ep_size = 3 + ep_rank = 1 + workspace = torch.zeros((ep_size, 64), dtype=torch.uint8) + metainfo = torch.zeros((10,), dtype=torch.int64) + metainfo_index = { + "DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX": 4, + "COMBINE_COMPLETION_FLAGS_OFFSET_INDEX": 5, + } + metainfo[4] = 4 + metainfo[5] = 20 + workspace[ep_rank, 4:16].view(torch.int32).copy_(torch.tensor([7, 7, 7], dtype=torch.int32)) + workspace[ep_rank, 20:32].view(torch.int32).copy_(torch.tensor([0, 8, 8], dtype=torch.int32)) + health = EPGroupHealth(ep_size) + events: list[AlltoAllWatchdogTimeout] = [] + + with AlltoAllWatchdog.from_workspace( + workspace=workspace, + metainfo=metainfo, + metainfo_index=metainfo_index, + ep_rank=ep_rank, + ep_size=ep_size, + timeout_s=0.02, + poll_interval_s=0.005, + health=health, + on_timeout=events.append, + ) as watchdog: + watchdog.watch(phase="dispatch", expected_flag=7) + assert watchdog.wait_until_idle(timeout_s=1.0) + + watchdog.watch(phase="combine", expected_flag=8) + _wait_for(lambda: len(events) == 1) + + assert events[0].phase == "combine" + assert events[0].missing_ranks == (0,) + assert events[0].marked_failed_ranks == (0,) + assert health.get_failed_ranks() == frozenset({0}) + + +def test_watchdog_rejects_active_mask_without_local_rank() -> None: + reader = FakeCompletionFlagReader(ep_size=4) + with AlltoAllWatchdog( + ep_size=4, + ep_rank=2, + completion_reader=reader, + timeout_s=0.1, + poll_interval_s=0.005, + ) as watchdog: + with pytest.raises(ValueError, match="local ep_rank"): + watchdog.watch(phase="dispatch", expected_flag=1, active_mask=0b1011)