Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,14 @@ __device__ __forceinline__ int compute_target_rank_id(int expert_id, int base, i
return remainder + (expert_id - split) / base;
}

// Report whether `rank` is flagged in the active-rank bitmask `mask`.
//
// The mask is a kRankMaskWords-wide array of uint64 words laid out little-endian by
// word: mask[0] carries ranks 0..63, mask[1] carries ranks 64..127, and so on. Inside
// a word, rank r occupies bit (r mod 64).
//
// mask : pointer to the first word of the bitmask.
// rank : rank index to query. No range validation is performed here, so the word
//        index derived from it must be valid for `mask`.
// Returns true when the corresponding bit is 1, false otherwise.
__device__ __forceinline__ bool is_rank_active(uint64_t const* mask, int rank)
{
    int const word_index = rank >> 6;  // rank / 64 selects the word
    int const bit_index = rank & 63;   // rank % 64 selects the bit inside that word
    uint64_t const word = mask[word_index];
    return ((word >> bit_index) & 1ULL) != 0;
}

// ============================================================================
// Helper Functions for Vectorized Memory Operations
// ============================================================================
Expand Down Expand Up @@ -416,7 +424,7 @@ __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [
int* smem_topk_target_ranks = smem;
int* smem_topk_send_indices = smem + TOP_K;

uint64_t already_copied = 0;
uint64_t already_copied[kRankMaskWords] = {};
// Precompute the ceil/floor partition parameters once per thread, outside the
// per-token TOP_K loop. The fast path (remainder == 0) then collapses to a single
// integer divide per call, matching the pre-PR uniform-partition cost exactly.
Expand All @@ -432,7 +440,15 @@ __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [
// Supports the non-divisible case where num_experts % ep_size != 0.
int target_rank = compute_target_rank_id(expert_id, ep_base, ep_remainder);

if (already_copied & (1ULL << target_rank))
// Skip duplicates AND dead ranks: both produce the same -1 sentinel that combine
// checks via topk_send_indices[k] < 0. A token whose only target is dead is dropped
// from this collective; higher-layer logic (EPLB redistribution) is responsible
// for re-routing such tokens on subsequent iterations.
int const mask_word = target_rank >> 6;
uint64_t const mask_bit = 1ULL << (target_rank & 63);
bool const target_already_copied = already_copied[mask_word] & mask_bit;
bool const target_dead = !is_rank_active(ptrs.active_rank_mask, target_rank);
if (target_already_copied || target_dead)
{
if (thread_idx == 0)
{
Expand All @@ -457,7 +473,7 @@ __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [
smem_topk_target_ranks[k] = target_rank;
smem_topk_send_indices[k] = dst_token_idx;
}
already_copied |= 1ULL << target_rank;
already_copied[mask_word] |= mask_bit;
}
// Sync before dispatching data
ThreadingPolicy::sync();
Expand Down Expand Up @@ -511,20 +527,26 @@ __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [

if (is_last_token)
{
// Store send_counters to recv_counters
// Store send_counters to recv_counters.
// Skip masked target ranks: their symmetric memory may be inaccessible.
#pragma unroll 1 // No unroll as one iter is typically enough
for (int target_rank = lane_id; target_rank < ep_size; target_rank += warpSize)
{
if (!is_rank_active(ptrs.active_rank_mask, target_rank))
continue;
int send_count = ptrs.send_counters[target_rank];
ptrs.recv_counters[target_rank][rank_id] = send_count;
}

if constexpr (ENABLE_EPLB)
{
// Write local stats into peer buffers before the release fence below.
// Skip masked target ranks for the same reason as above.
#pragma unroll 1
for (int target_rank = 0; target_rank < ep_size; ++target_rank)
{
if (!is_rank_active(ptrs.active_rank_mask, target_rank))
continue;
int* target_stats = ptrs.eplb_gathered_stats[target_rank];
for (int expert_id = lane_id; expert_id < eplb_stats_num_experts; expert_id += warpSize)
{
Expand All @@ -543,9 +565,13 @@ __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [
#else
asm volatile("fence.acq_rel.sys;");
#endif
// Signal completion to all active peers; skip dead ranks (their symmetric memory
// is unreachable).
#pragma unroll 1 // No unroll as one iter is typically enough
for (int target_rank = lane_id; target_rank < ep_size; target_rank += warpSize)
{
if (!is_rank_active(ptrs.active_rank_mask, target_rank))
continue;
uint32_t* flag_addr = &ptrs.completion_flags[target_rank][rank_id];
asm volatile("st.relaxed.sys.u32 [%0], %1;" ::"l"(flag_addr), "r"(expected_value));

Expand All @@ -555,9 +581,13 @@ __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [
#endif
}

// Wait for all active peers to signal; skip dead ranks (otherwise we would
// spin forever — this is the bug the rank-mask is here to prevent).
#pragma unroll 1 // No unroll
for (int peer_rank = lane_id; peer_rank < ep_size; peer_rank += warpSize)
{
if (!is_rank_active(ptrs.active_rank_mask, peer_rank))
continue;
bool flag_set = false;
auto s = clock64();
do
Expand Down Expand Up @@ -603,8 +633,13 @@ void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params)
// Validate parameters
TLLM_CHECK(params.top_k > 0 && params.top_k <= kMaxTopK);
TLLM_CHECK(params.ep_size > 0 && params.ep_size <= kMaxRanks);
TLLM_CHECK(params.ep_rank >= 0 && params.ep_rank < params.ep_size);
TLLM_CHECK(params.local_num_tokens >= 0);
TLLM_CHECK(params.num_payloads > 0 && params.num_payloads <= kMaxPayloads);
// The local rank must always be marked active in its own view of the mask;
// otherwise the kernel itself would be running on a "dead" rank.
TLLM_CHECK_WITH_INFO((params.active_rank_mask[params.ep_rank >> 6] >> (params.ep_rank & 63)) & 1ULL,
"active_rank_mask must mark the local ep_rank (%d) as active", params.ep_rank);

// Prepare kernel pointers struct
DispatchKernelPointers kernel_ptrs = {};
Expand Down Expand Up @@ -642,6 +677,12 @@ void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params)
kernel_ptrs.topk_send_indices = params.topk_send_indices;
kernel_ptrs.eplb_local_stats = params.eplb_local_stats;

// Copy active-rank bitmask into the kernel pointers struct
for (int w = 0; w < kRankMaskWords; ++w)
{
kernel_ptrs.active_rank_mask[w] = params.active_rank_mask[w];
}

int const kBlockSize = tensorrt_llm::common::getEnvMoeA2ADispatchBlockSize();

// One block per token: grid_size == local_num_tokens. If 0, launch a single block to
Expand Down Expand Up @@ -700,7 +741,7 @@ __device__ void vectorized_combine_impl(T* dst_typed_base, int size_per_token, i
{
int target_rank = ptrs.topk_target_ranks[local_token_idx * TOP_K + k];
int dst_idx = ptrs.topk_send_indices[local_token_idx * TOP_K + k];
if (dst_idx < 0)
if (dst_idx < 0 || !is_rank_active(ptrs.active_rank_mask, target_rank))
{
acc[k].fill(0.0f);
continue;
Expand All @@ -725,8 +766,12 @@ __device__ void vectorized_combine_impl(T* dst_typed_base, int size_per_token, i
#pragma unroll
for (int k = 0; k < TOP_K; ++k)
{
if (ptrs.topk_send_indices[local_token_idx * TOP_K + k] < 0)
int target_rank = ptrs.topk_target_ranks[local_token_idx * TOP_K + k];
int dst_idx = ptrs.topk_send_indices[local_token_idx * TOP_K + k];
if (dst_idx < 0 || !is_rank_active(ptrs.active_rank_mask, target_rank))
{
continue; // acc[k] already holds 0.0f from fill() above
}
#pragma unroll
for (int j = elems_per_vec - 1; j >= 0; --j)
acc[k][j] = static_cast<float>(reinterpret_cast<InT const*>(&acc[k])[j]);
Expand Down Expand Up @@ -1153,9 +1198,13 @@ __global__ void moeA2ACombineKernel(

if (blockIdx.x == 0)
{
// Signal readiness to all active peers; skip dead ranks (their symmetric memory
// is unreachable).
#pragma unroll 1 // No unroll
for (int peer_rank = lane_id; peer_rank < ep_size; peer_rank += warpSize)
{
if (!is_rank_active(ptrs.active_rank_mask, peer_rank))
continue;
uint32_t* flag_addr = &ptrs.completion_flags[peer_rank][rank_id];
asm volatile("st.relaxed.sys.u32 [%0], %1;" ::"l"(flag_addr), "r"(expected_value));
#if ENABLE_DEBUG_PRINT
Expand All @@ -1165,9 +1214,13 @@ __global__ void moeA2ACombineKernel(
}
}

// Wait for all active peers to signal; skip dead ranks (otherwise we would spin
// forever — this is the bug the rank-mask is here to prevent).
#pragma unroll 1 // No unroll
for (int peer_rank = lane_id; peer_rank < ep_size; peer_rank += warpSize)
{
if (!is_rank_active(ptrs.active_rank_mask, peer_rank))
continue;
bool flag_set = false;
auto s = clock64();
do
Expand Down Expand Up @@ -1271,8 +1324,13 @@ void moe_a2a_combine_launch(MoeA2ACombineParams const& params)
// Validate parameters
TLLM_CHECK(params.top_k > 0 && params.top_k <= kMaxTopK);
TLLM_CHECK(params.ep_size > 0 && params.ep_size <= kMaxRanks);
TLLM_CHECK(params.ep_rank >= 0 && params.ep_rank < params.ep_size);
TLLM_CHECK(params.local_num_tokens >= 0);
TLLM_CHECK(params.elements_per_token > 0);
// The local rank must always be marked active in its own view of the mask;
// otherwise the kernel itself would be running on a "dead" rank.
TLLM_CHECK_WITH_INFO((params.active_rank_mask[params.ep_rank >> 6] >> (params.ep_rank & 63)) & 1ULL,
"active_rank_mask must mark the local ep_rank (%d) as active", params.ep_rank);

// Configure kernel launch (one block per token).
int const kBlockSize = tensorrt_llm::common::getEnvMoeA2ACombineBlockSize();
Expand Down Expand Up @@ -1306,6 +1364,12 @@ void moe_a2a_combine_launch(MoeA2ACombineParams const& params)
kernel_ptrs.topk_target_ranks = params.topk_target_ranks;
kernel_ptrs.topk_send_indices = params.topk_send_indices;

// Copy active-rank bitmask into the kernel pointers struct
for (int w = 0; w < kRankMaskWords; ++w)
{
kernel_ptrs.active_rank_mask[w] = params.active_rank_mask[w];
}

// stride_per_token: byte distance between tokens in the recv buffer.
// FP8 external payload: EPT × 1 (compact FP8 layout)
// FP8 in-place / non-FP8: EPT × sizeof(PayloadT) (payload-dtype stride)
Expand Down
30 changes: 27 additions & 3 deletions cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,12 @@ namespace kernels::moe_comm
{

// Configuration constants
static constexpr int kMaxTopK = 22; // Maximum top-k experts per token
static constexpr int kMaxPayloads = 4; // Maximum number of different payload types
static constexpr int kMaxRanks = 64; // Maximum supported EP size
// Upper bound on the number of top-k experts selected per token.
static constexpr int kMaxTopK = 22;
// Upper bound on the number of different payload types.
static constexpr int kMaxPayloads = 4;
// Upper bound on the supported EP size (covers NVL72 with headroom).
static constexpr int kMaxRanks = 128;
// Number of uint64 words that make up the active-rank bitmask. Each word carries
// 64 rank bits, so kRankMaskWords * 64 has to be at least kMaxRanks; the
// static_assert below enforces that at compile time.
static constexpr int kRankMaskWords = 2;
static_assert(kMaxRanks <= kRankMaskWords * 64, "active_rank_mask too small for kMaxRanks");

// Describes a single payload type to be communicated
struct PayloadDescriptor
Expand Down Expand Up @@ -65,6 +68,12 @@ struct DispatchKernelPointers
// Optional: Statistics for EPLB
int const* eplb_local_stats; // [eplb_stats_num_experts]
int* eplb_gathered_stats[kMaxRanks]; // [ep_size, eplb_stats_num_experts] per rank

// Active-rank bitmask: bit i set => rank i is alive and participates in this collective.
// Word 0 covers ranks 0..63; word 1 covers ranks 64..127. Tokens routed to a masked
// rank are dropped (topk_*[k] = -1); flag writes/waits to/from masked peers are skipped.
// The local rank's own bit must always be set; this is checked at launch time.
uint64_t active_rank_mask[kRankMaskWords];
};

// Combine kernel pointers - non-const output in src_data_ptrs[0], const recv buffers
Expand All @@ -82,6 +91,11 @@ struct CombineKernelPointers
// Top-K compact routing info per local token (size: [local_num_tokens, top_k])
int const* topk_target_ranks; // target rank per k, -1 for duplicates
int const* topk_send_indices; // dst index per k, -1 for duplicates

// Active-rank bitmask: see DispatchKernelPointers::active_rank_mask. Combine skips flag
// writes/waits to/from masked peers and also skips per-token accumulation for ranks that
// become inactive between dispatch and combine.
uint64_t active_rank_mask[kRankMaskWords];
Comment thread
chienchunhung marked this conversation as resolved.
};

// Dispatch phase parameters
Expand Down Expand Up @@ -125,6 +139,11 @@ struct MoeA2ADispatchParams
int const* eplb_local_stats; // [eplb_stats_num_experts]
int* eplb_gathered_stats[kMaxRanks]; // [ep_size, eplb_stats_num_experts] per rank

// Active-rank bitmask: see DispatchKernelPointers::active_rank_mask. The launch function
// copies these words into the kernel pointers struct. Defaults to all-ones for
// backwards-compatible "no masking" behavior.
uint64_t active_rank_mask[kRankMaskWords] = {~uint64_t{0}, ~uint64_t{0}};

// CUDA stream
cudaStream_t stream;
};
Expand Down Expand Up @@ -170,6 +189,11 @@ struct MoeA2ACombineParams
// rank has signaled the target rank
void const* recv_buffers[kMaxRanks]; // Per-rank receive buffers (only for single payload)

// Active-rank bitmask: see DispatchKernelPointers::active_rank_mask. The launch function
// copies these words into the kernel pointers struct. Defaults to all-ones for
// backwards-compatible "no masking" behavior.
uint64_t active_rank_mask[kRankMaskWords] = {~uint64_t{0}, ~uint64_t{0}};

// CUDA stream
cudaStream_t stream;
};
Expand Down
Loading
Loading