Skip to content

Commit ea9466e

Browse files
committed
[None][feat] WideEP FT: add active_rank_mask to NVLink AlltoAll kernels
Eliminates the infinite-spin AlltoAll hang that turns a single GPU failure in a Wide-EP group into a 5-minute HangDetector fire + full restart. The dispatch and combine kernels now take a uint64[2] bitmask of currently-alive EP ranks; dead ranks are skipped on every completion-flag write/wait, peer recv_counter store, EPLB stats write, and per-token routing decision (dead-targeted slots collapse to the same -1 sentinel combine already uses for duplicates). The mask is optional on both torch ops; omitting it (or passing all-ones) produces bit-identical output to the pre-change kernel. kMaxRanks is bumped 64 -> 128 to cover NVL72 with headroom; kRankMaskWords = 2 names the kernel ABI explicitly. Tests cover (a) all-ones mask matches no-mask bit-for-bit, and (b) one rank masked dead -> surviving ranks complete dispatch+combine without hang, dead-targeted topk slots dropped, in tests/unittest/_torch/multi_gpu/test_moe_a2a_rank_mask.py. Signed-off-by: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com>
1 parent eddaa3a commit ea9466e

6 files changed

Lines changed: 543 additions & 10 deletions

File tree

cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,14 @@ __device__ __forceinline__ int compute_target_rank_id(int expert_id, int base, i
210210
return remainder + (expert_id - split) / base;
211211
}
212212

213+
// Report whether `rank` is flagged as alive in the active-rank bitmask.
// The mask is an array of uint64 words in little-endian word order: bit (rank & 63) of
// word (rank >> 6) belongs to `rank`, i.e. word 0 holds ranks 0..63, word 1 holds
// ranks 64..127, and so on.
// No bounds check is performed; the caller must pass a rank covered by `mask`.
__device__ __forceinline__ bool is_rank_active(uint64_t const* mask, int rank)
{
    int const word_idx = rank >> 6; // which 64-bit word holds this rank's bit
    int const bit_idx = rank & 63;  // position of the bit inside that word
    uint64_t const word = mask[word_idx];
    return ((word >> bit_idx) & 1ULL) != 0ULL;
}
220+
213221
// ============================================================================
214222
// Helper Functions for Vectorized Memory Operations
215223
// ============================================================================
@@ -432,7 +440,12 @@ __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [
432440
// Supports the non-divisible case where num_experts % ep_size != 0.
433441
int target_rank = compute_target_rank_id(expert_id, ep_base, ep_remainder);
434442

435-
if (already_copied & (1ULL << target_rank))
443+
// Skip duplicates AND dead ranks: both produce the same -1 sentinel that combine
444+
// checks via topk_send_indices[k] < 0. A token whose only target is dead is dropped
445+
// from this collective; higher-layer logic (EPLB redistribution) is responsible
446+
// for re-routing such tokens on subsequent iterations.
447+
bool const target_dead = !is_rank_active(ptrs.active_rank_mask, target_rank);
448+
if ((already_copied & (1ULL << target_rank)) || target_dead)
436449
{
437450
if (thread_idx == 0)
438451
{
@@ -511,20 +524,26 @@ __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [
511524

512525
if (is_last_token)
513526
{
514-
// Store send_counters to recv_counters
527+
// Store send_counters to recv_counters.
528+
// Skip masked target ranks: their symmetric memory may be inaccessible.
515529
#pragma unroll 1 // No unroll as one iter is typically enough
516530
for (int target_rank = lane_id; target_rank < ep_size; target_rank += warpSize)
517531
{
532+
if (!is_rank_active(ptrs.active_rank_mask, target_rank))
533+
continue;
518534
int send_count = ptrs.send_counters[target_rank];
519535
ptrs.recv_counters[target_rank][rank_id] = send_count;
520536
}
521537

522538
if constexpr (ENABLE_EPLB)
523539
{
524540
// Write local stats into peer buffers before the release fence below.
541+
// Skip masked target ranks for the same reason as above.
525542
#pragma unroll 1
526543
for (int target_rank = 0; target_rank < ep_size; ++target_rank)
527544
{
545+
if (!is_rank_active(ptrs.active_rank_mask, target_rank))
546+
continue;
528547
int* target_stats = ptrs.eplb_gathered_stats[target_rank];
529548
for (int expert_id = lane_id; expert_id < eplb_stats_num_experts; expert_id += warpSize)
530549
{
@@ -543,9 +562,13 @@ __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [
543562
#else
544563
asm volatile("fence.acq_rel.sys;");
545564
#endif
565+
// Signal completion to all active peers; skip dead ranks (their symmetric memory
566+
// is unreachable).
546567
#pragma unroll 1 // No unroll as one iter is typically enough
547568
for (int target_rank = lane_id; target_rank < ep_size; target_rank += warpSize)
548569
{
570+
if (!is_rank_active(ptrs.active_rank_mask, target_rank))
571+
continue;
549572
uint32_t* flag_addr = &ptrs.completion_flags[target_rank][rank_id];
550573
asm volatile("st.relaxed.sys.u32 [%0], %1;" ::"l"(flag_addr), "r"(expected_value));
551574

@@ -555,9 +578,13 @@ __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [
555578
#endif
556579
}
557580

581+
// Wait for all active peers to signal; skip dead ranks (otherwise we would
582+
// spin forever — this is the bug the rank-mask is here to prevent).
558583
#pragma unroll 1 // No unroll
559584
for (int peer_rank = lane_id; peer_rank < ep_size; peer_rank += warpSize)
560585
{
586+
if (!is_rank_active(ptrs.active_rank_mask, peer_rank))
587+
continue;
561588
bool flag_set = false;
562589
auto s = clock64();
563590
do
@@ -605,6 +632,10 @@ void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params)
605632
TLLM_CHECK(params.ep_size > 0 && params.ep_size <= kMaxRanks);
606633
TLLM_CHECK(params.local_num_tokens >= 0);
607634
TLLM_CHECK(params.num_payloads > 0 && params.num_payloads <= kMaxPayloads);
635+
// The local rank must always be marked active in its own view of the mask;
636+
// otherwise the kernel itself would be running on a "dead" rank.
637+
TLLM_CHECK_WITH_INFO((params.active_rank_mask[params.ep_rank >> 6] >> (params.ep_rank & 63)) & 1ULL,
638+
"active_rank_mask must mark the local ep_rank (%d) as active", params.ep_rank);
608639

609640
// Prepare kernel pointers struct
610641
DispatchKernelPointers kernel_ptrs = {};
@@ -642,6 +673,12 @@ void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params)
642673
kernel_ptrs.topk_send_indices = params.topk_send_indices;
643674
kernel_ptrs.eplb_local_stats = params.eplb_local_stats;
644675

676+
// Copy active-rank bitmask into the kernel pointers struct
677+
for (int w = 0; w < kRankMaskWords; ++w)
678+
{
679+
kernel_ptrs.active_rank_mask[w] = params.active_rank_mask[w];
680+
}
681+
645682
int const kBlockSize = tensorrt_llm::common::getEnvMoeA2ADispatchBlockSize();
646683

647684
// One block per token: grid_size == local_num_tokens. If 0, launch a single block to
@@ -1153,9 +1190,13 @@ __global__ void moeA2ACombineKernel(
11531190

11541191
if (blockIdx.x == 0)
11551192
{
1193+
// Signal readiness to all active peers; skip dead ranks (their symmetric memory
1194+
// is unreachable).
11561195
#pragma unroll 1 // No unroll
11571196
for (int peer_rank = lane_id; peer_rank < ep_size; peer_rank += warpSize)
11581197
{
1198+
if (!is_rank_active(ptrs.active_rank_mask, peer_rank))
1199+
continue;
11591200
uint32_t* flag_addr = &ptrs.completion_flags[peer_rank][rank_id];
11601201
asm volatile("st.relaxed.sys.u32 [%0], %1;" ::"l"(flag_addr), "r"(expected_value));
11611202
#if ENABLE_DEBUG_PRINT
@@ -1165,9 +1206,13 @@ __global__ void moeA2ACombineKernel(
11651206
}
11661207
}
11671208

1209+
// Wait for all active peers to signal; skip dead ranks (otherwise we would spin
1210+
// forever — this is the bug the rank-mask is here to prevent).
11681211
#pragma unroll 1 // No unroll
11691212
for (int peer_rank = lane_id; peer_rank < ep_size; peer_rank += warpSize)
11701213
{
1214+
if (!is_rank_active(ptrs.active_rank_mask, peer_rank))
1215+
continue;
11711216
bool flag_set = false;
11721217
auto s = clock64();
11731218
do
@@ -1273,6 +1318,10 @@ void moe_a2a_combine_launch(MoeA2ACombineParams const& params)
12731318
TLLM_CHECK(params.ep_size > 0 && params.ep_size <= kMaxRanks);
12741319
TLLM_CHECK(params.local_num_tokens >= 0);
12751320
TLLM_CHECK(params.elements_per_token > 0);
1321+
// The local rank must always be marked active in its own view of the mask;
1322+
// otherwise the kernel itself would be running on a "dead" rank.
1323+
TLLM_CHECK_WITH_INFO((params.active_rank_mask[params.ep_rank >> 6] >> (params.ep_rank & 63)) & 1ULL,
1324+
"active_rank_mask must mark the local ep_rank (%d) as active", params.ep_rank);
12761325

12771326
// Configure kernel launch (one block per token).
12781327
int const kBlockSize = tensorrt_llm::common::getEnvMoeA2ACombineBlockSize();
@@ -1306,6 +1355,12 @@ void moe_a2a_combine_launch(MoeA2ACombineParams const& params)
13061355
kernel_ptrs.topk_target_ranks = params.topk_target_ranks;
13071356
kernel_ptrs.topk_send_indices = params.topk_send_indices;
13081357

1358+
// Copy active-rank bitmask into the kernel pointers struct
1359+
for (int w = 0; w < kRankMaskWords; ++w)
1360+
{
1361+
kernel_ptrs.active_rank_mask[w] = params.active_rank_mask[w];
1362+
}
1363+
13091364
// stride_per_token: byte distance between tokens in the recv buffer.
13101365
// FP8 external payload: EPT × 1 (compact FP8 layout)
13111366
// FP8 in-place / non-FP8: EPT × sizeof(PayloadT) (payload-dtype stride)

cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,12 @@ namespace kernels::moe_comm
2626
{
2727

2828
// Configuration constants
29-
static constexpr int kMaxTopK = 22; // Maximum top-k experts per token
30-
static constexpr int kMaxPayloads = 4; // Maximum number of different payload types
31-
static constexpr int kMaxRanks = 64; // Maximum supported EP size
29+
static constexpr int kMaxTopK = 22; // Maximum top-k experts per token
30+
static constexpr int kMaxPayloads = 4; // Maximum number of different payload types
31+
static constexpr int kMaxRanks = 128; // Maximum supported EP size (covers NVL72 with headroom)
32+
static constexpr int kRankMaskWords = 2; // uint64 words to hold the active-rank bitmask
33+
// (kRankMaskWords * 64 must be >= kMaxRanks)
34+
static_assert(kRankMaskWords * 64 >= kMaxRanks, "active_rank_mask too small for kMaxRanks");
3235

3336
// Describes a single payload type to be communicated
3437
struct PayloadDescriptor
@@ -65,6 +68,12 @@ struct DispatchKernelPointers
6568
// Optional: Statistics for EPLB
6669
int const* eplb_local_stats; // [eplb_stats_num_experts]
6770
int* eplb_gathered_stats[kMaxRanks]; // [ep_size, eplb_stats_num_experts] per rank
71+
72+
// Active-rank bitmask: bit i set => rank i is alive and participates in this collective.
73+
// Word 0 covers ranks 0..63; word 1 covers ranks 64..127. Tokens routed to a masked
74+
// rank are dropped (topk_*[k] = -1); flag writes/waits to/from masked peers are skipped.
75+
// The local rank's own bit must always be set; this is checked at launch time.
76+
uint64_t active_rank_mask[kRankMaskWords];
6877
};
6978

7079
// Combine kernel pointers - non-const output in src_data_ptrs[0], const recv buffers
@@ -82,6 +91,11 @@ struct CombineKernelPointers
8291
// Top-K compact routing info per local token (size: [local_num_tokens, top_k])
8392
int const* topk_target_ranks; // target rank per k, -1 for duplicates
8493
int const* topk_send_indices; // dst index per k, -1 for duplicates
94+
95+
// Active-rank bitmask: see DispatchKernelPointers::active_rank_mask. Combine skips flag
96+
// writes/waits to/from masked peers; per-token accumulation uses topk_send_indices[k] < 0
97+
// (set by dispatch) to skip dead-targeted slots, so no explicit mask check is needed there.
98+
uint64_t active_rank_mask[kRankMaskWords];
8599
};
86100

87101
// Dispatch phase parameters
@@ -125,6 +139,11 @@ struct MoeA2ADispatchParams
125139
int const* eplb_local_stats; // [eplb_stats_num_experts]
126140
int* eplb_gathered_stats[kMaxRanks]; // [ep_size, eplb_stats_num_experts] per rank
127141

142+
// Active-rank bitmask: see DispatchKernelPointers::active_rank_mask. The launch function
143+
// copies these words into the kernel pointers struct. This field has no default initializer:
144+
// callers must fill it (all-ones means "no masking"), and the local rank's bit must be set.
145+
uint64_t active_rank_mask[kRankMaskWords];
146+
128147
// CUDA stream
129148
cudaStream_t stream;
130149
};
@@ -170,6 +189,11 @@ struct MoeA2ACombineParams
170189
// rank has signaled the target rank
171190
void const* recv_buffers[kMaxRanks]; // Per-rank receive buffers (only for single payload)
172191

192+
// Active-rank bitmask: see DispatchKernelPointers::active_rank_mask. The launch function
193+
// copies these words into the kernel pointers struct. This field has no default initializer:
194+
// callers must fill it (all-ones means "no masking"), and the local rank's bit must be set.
195+
uint64_t active_rank_mask[kRankMaskWords];
196+
173197
// CUDA stream
174198
cudaStream_t stream;
175199
};

cpp/tensorrt_llm/thop/moeAlltoAllOp.cpp

Lines changed: 49 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,40 @@ inline size_t alignOffset(size_t offset, size_t alignment)
4242
return (offset + alignment - 1) & ~(alignment - 1);
4343
}
4444

45+
// Resolve an optional rank-mask tensor into a fixed-width uint64 array.
46+
// If the caller did not provide a mask, default to "all ranks active" (all bits set), which
47+
// reproduces the pre-fault-tolerance behavior bit-for-bit.
48+
//
49+
// On failure (wrong dtype / device / shape), throws via TORCH_CHECK so the error surfaces
50+
// at the Python op boundary rather than the kernel launch.
51+
inline void resolveActiveRankMask(torch::optional<torch::Tensor> const& maskTensor, int64_t epRank,
52+
uint64_t (&out)[tensorrt_llm::kernels::moe_comm::kRankMaskWords])
53+
{
54+
using tensorrt_llm::kernels::moe_comm::kRankMaskWords;
55+
if (!maskTensor.has_value() || !maskTensor.value().defined())
56+
{
57+
for (int w = 0; w < kRankMaskWords; ++w)
58+
{
59+
out[w] = ~uint64_t{0};
60+
}
61+
return;
62+
}
63+
torch::Tensor const& t = maskTensor.value();
64+
TORCH_CHECK(t.is_cpu(), "active_rank_mask must be a CPU tensor");
65+
TORCH_CHECK(t.scalar_type() == torch::kUInt64, "active_rank_mask must have dtype uint64");
66+
TORCH_CHECK(t.dim() == 1, "active_rank_mask must be a 1D tensor");
67+
TORCH_CHECK(t.numel() == kRankMaskWords, "active_rank_mask must have exactly ", kRankMaskWords, " uint64 elements");
68+
TORCH_CHECK(t.is_contiguous(), "active_rank_mask must be contiguous");
69+
auto const* src = static_cast<uint64_t const*>(t.const_data_ptr());
70+
for (int w = 0; w < kRankMaskWords; ++w)
71+
{
72+
out[w] = src[w];
73+
}
74+
// Local rank's bit must be set; otherwise the kernel would be running on a "dead" rank.
75+
TORCH_CHECK((out[epRank >> 6] >> (epRank & 63)) & 1ULL, "active_rank_mask must mark the local ep_rank (", epRank,
76+
") as active");
77+
}
78+
4579
// Calculate auxiliary data offsets
4680
MoeA2ADataOffsets calculateOffsets(int epSize, int maxNumTokens, int eplbStatsNumExperts)
4781
{
@@ -181,7 +215,8 @@ torch::Tensor moeA2AInitializeOp(torch::Tensor const& workspace, int64_t epRank,
181215
std::tuple<std::vector<torch::Tensor>, int64_t, torch::Tensor> moeA2ADispatchOp(
182216
torch::Tensor const& tokenSelectedExperts, std::vector<torch::Tensor> const& inputPayloads,
183217
torch::Tensor const& workspace, torch::Tensor const& metainfo, int64_t runtimeMaxTokensPerRank, int64_t epRank,
184-
int64_t epSize, int64_t topK, int64_t numExperts, torch::optional<torch::Tensor> eplbLocalStats)
218+
int64_t epSize, int64_t topK, int64_t numExperts, torch::optional<torch::Tensor> eplbLocalStats,
219+
torch::optional<torch::Tensor> activeRankMask)
185220
{
186221
using tensorrt_llm::kernels::moe_comm::PayloadDescriptor;
187222
using tensorrt_llm::kernels::moe_comm::MoeA2ADispatchParams;
@@ -360,6 +395,10 @@ std::tuple<std::vector<torch::Tensor>, int64_t, torch::Tensor> moeA2ADispatchOp(
360395
params.eplb_local_stats = nullptr;
361396
}
362397

398+
// Resolve the optional active-rank mask. Default (no mask) = all bits set, which
399+
// exactly reproduces the pre-fault-tolerance kernel behavior.
400+
resolveActiveRankMask(activeRankMask, epRank, params.active_rank_mask);
401+
363402
params.stream = at::cuda::getCurrentCUDAStream();
364403

365404
// Prepare for dispatch (zero counters/indices and increment flag_val)
@@ -413,7 +452,8 @@ std::tuple<std::vector<torch::Tensor>, int64_t, torch::Tensor> moeA2ADispatchOp(
413452
// In both cases, the combine kernel reads from the workspace at 'combinePayloadOffset'.
414453
torch::Tensor moeA2ACombineOp(torch::Tensor const& payload, int64_t localNumTokens, torch::Tensor const& workspace,
415454
torch::Tensor const& metainfo, int64_t runtimeMaxTokensPerRank, int64_t epRank, int64_t epSize, int64_t topK,
416-
int64_t combinePayloadOffset, bool payloadInWorkspace, bool useLowPrecision = false)
455+
int64_t combinePayloadOffset, bool payloadInWorkspace, bool useLowPrecision = false,
456+
torch::optional<torch::Tensor> activeRankMask = torch::nullopt)
417457
{
418458
using tensorrt_llm::kernels::moe_comm::MoeA2ACombineParams;
419459
using tensorrt_llm::kernels::moe_comm::moe_a2a_combine_launch;
@@ -520,6 +560,9 @@ torch::Tensor moeA2ACombineOp(torch::Tensor const& payload, int64_t localNumToke
520560
params.recv_buffers[target_rank] = target_workspace_ptr + combinePayloadOffset;
521561
}
522562

563+
// Resolve the optional active-rank mask. Default (no mask) = all bits set.
564+
resolveActiveRankMask(activeRankMask, epRank, params.active_rank_mask);
565+
523566
params.stream = at::cuda::getCurrentCUDAStream();
524567

525568
moe_a2a_prepare_combine_launch(params);
@@ -613,12 +656,14 @@ TORCH_LIBRARY_FRAGMENT(trtllm, module)
613656
"moe_a2a_dispatch(Tensor token_selected_experts, Tensor[] input_payloads, "
614657
"Tensor(a!->*) workspace, Tensor metainfo, int runtime_max_tokens_per_rank, "
615658
"int ep_rank, int ep_size, int top_k, int num_experts, "
616-
"Tensor? eplb_local_stats=None) -> (Tensor(a!)[], int, Tensor(a!))");
659+
"Tensor? eplb_local_stats=None, "
660+
"Tensor? active_rank_mask=None) -> (Tensor(a!)[], int, Tensor(a!))");
617661
module.def(
618662
"moe_a2a_combine(Tensor(a) payload, int local_num_tokens,"
619663
"Tensor(a!) workspace, Tensor metainfo, int runtime_max_tokens_per_rank, "
620664
"int ep_rank, int ep_size, int top_k, int combine_payload_offset, "
621-
"bool payload_in_workspace, bool use_low_precision=False) -> Tensor");
665+
"bool payload_in_workspace, bool use_low_precision=False, "
666+
"Tensor? active_rank_mask=None) -> Tensor");
622667
module.def(
623668
"moe_a2a_initialize(Tensor(a!) workspace, int ep_rank, int ep_size, int max_num_tokens_per_rank, "
624669
"int? eplb_stats_num_experts=None) -> Tensor");

tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -448,6 +448,7 @@ def _(
448448
top_k: int,
449449
num_experts: int,
450450
eplb_local_stats: Optional[torch.Tensor] = None,
451+
active_rank_mask: Optional[torch.Tensor] = None,
451452
) -> Tuple[List[torch.Tensor], int, torch.Tensor]:
452453
recv_tensors: List[torch.Tensor] = []
453454
for payload in input_payloads:
@@ -478,6 +479,7 @@ def _(
478479
combine_payload_offset: int,
479480
payload_in_workspace: bool,
480481
use_low_precision: bool = False,
482+
active_rank_mask: Optional[torch.Tensor] = None,
481483
) -> torch.Tensor:
482484
return payload.new_empty((local_num_tokens, payload.shape[2]))
483485

tensorrt_llm/_torch/modules/fused_moe/communication/nvlink_one_sided.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ class NVLinkOneSided(Communication):
5151
"""
5252

5353
# Constants from C++ (must match moeAlltoAllKernels.h)
54-
MAX_RANKS = 64
54+
MAX_RANKS = 128
5555
MAX_TOP_K = 8
5656
MAX_PAYLOADS = 8
5757

0 commit comments

Comments
 (0)