11/*
2- * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2+ * Copyright (c) 2025-2026 , NVIDIA CORPORATION. All rights reserved.
33 *
44 * Licensed under the Apache License, Version 2.0 (the "License");
55 * you may not use this file except in compliance with the License.
@@ -159,18 +159,6 @@ using tensorrt_llm::common::launchWithPdlWhenEnabled;
159159 } \
160160 }
161161
162- #define SWITCH_POLICY (one_block_per_token, POLICY , ...) \
163- if (one_block_per_token) \
164- { \
165- using POLICY = BlockPolicy; \
166- __VA_ARGS__ \
167- } \
168- else \
169- { \
170- using POLICY = WarpPolicy; \
171- __VA_ARGS__ \
172- }
173-
174162#if DISABLE_TIMEOUT
175163#define check_timeout (s ) false
176164#else
@@ -201,29 +189,6 @@ __device__ int compute_target_rank_id(int expert_id, int num_experts_per_rank)
201189// Helper Functions for Vectorized Memory Operations
202190// ============================================================================
203191
204- struct WarpPolicy
205- {
206- __device__ static int stride ()
207- {
208- return warpSize ;
209- }
210-
211- __device__ static int offset ()
212- {
213- return (threadIdx .x % warpSize );
214- }
215-
216- __device__ static int token_idx ()
217- {
218- return (blockIdx .x * blockDim .x + threadIdx .x ) / warpSize ;
219- }
220-
221- __device__ static void sync ()
222- {
223- __syncwarp ();
224- }
225- };
226-
227192struct BlockPolicy
228193{
229194 __device__ static int stride ()
@@ -421,22 +386,10 @@ __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [
421386 if (local_token_idx >= local_num_tokens)
422387 return ;
423388
424- // Prepare per-policy shared-memory tiles for this token
389+ // One block per token: a single shared-memory tile is reused by the entire CTA.
425390 extern __shared__ int smem[];
426- int * smem_topk_target_ranks;
427- int * smem_topk_send_indices;
428- int warps_per_block = blockDim .x / warpSize ;
429- if constexpr (std::is_same<ThreadingPolicy, WarpPolicy>::value)
430- {
431- int lane_id = threadIdx .x / warpSize ;
432- smem_topk_target_ranks = smem + lane_id * TOP_K ;
433- smem_topk_send_indices = smem + warps_per_block * TOP_K + lane_id * TOP_K ;
434- }
435- else
436- {
437- smem_topk_target_ranks = smem;
438- smem_topk_send_indices = smem + TOP_K ;
439- }
391+ int * smem_topk_target_ranks = smem;
392+ int * smem_topk_send_indices = smem + TOP_K ;
440393
441394 uint64_t already_copied = 0 ;
442395 int num_experts_per_rank = num_experts / ep_size;
@@ -660,44 +613,21 @@ void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params)
660613 kernel_ptrs.eplb_local_stats = params.eplb_local_stats ;
661614
662615 int const kBlockSize = tensorrt_llm::common::getEnvMoeA2ADispatchBlockSize ();
663- constexpr int kWarpSize = 32 ;
664- int const kWarpsPerBlock = kBlockSize / kWarpSize ;
665616
666- // Configure kernel launch
667- if (params.one_block_per_token )
617+ // One block per token: grid_size == local_num_tokens. If 0, launch a single block to
618+ // keep the synchronization path alive.
619+ int grid_size = params.local_num_tokens ;
620+ if (grid_size == 0 )
668621 {
669- int grid_size = params.local_num_tokens ;
670- // If local_num_tokens is 0, we still need to launch a minimal kernel to participate in the synchronization.
671- if (grid_size == 0 )
672- {
673- grid_size = 1 ;
674- }
675- int shared_bytes = 2 * params.top_k * (int ) sizeof (int );
676- SWITCH_BOOL (params.enable_eplb , EPLB_STATS , SWITCH_TOP_K (params.top_k , TOP_K , {
677- auto kernel_fn = moeA2ADispatchKernel<BlockPolicy, TOP_K , EPLB_STATS >;
678- launchWithPdlWhenEnabled (" moeA2ADispatchKernel" , kernel_fn, grid_size, kBlockSize , shared_bytes,
679- params.stream , params.token_selected_experts , kernel_ptrs, params.num_payloads ,
680- params.max_tokens_per_rank , params.local_num_tokens , params.ep_rank , params.ep_size , params.num_experts ,
681- params.eplb_stats_num_experts );
682- }))
683- }
684- else
685- {
686- int grid_size = ceilDiv (params.local_num_tokens , kWarpsPerBlock );
687- // If local_num_tokens is 0, we still need to launch a minimal kernel to participate in the synchronization.
688- if (grid_size == 0 )
689- {
690- grid_size = 1 ;
691- }
692- int shared_bytes = 2 * kWarpsPerBlock * params.top_k * (int ) sizeof (int );
693- SWITCH_BOOL (params.enable_eplb , EPLB_STATS , SWITCH_TOP_K (params.top_k , TOP_K , {
694- auto kernel_fn = moeA2ADispatchKernel<WarpPolicy, TOP_K , EPLB_STATS >;
695- launchWithPdlWhenEnabled (" moeA2ADispatchKernel" , kernel_fn, grid_size, kBlockSize , shared_bytes,
696- params.stream , params.token_selected_experts , kernel_ptrs, params.num_payloads ,
697- params.max_tokens_per_rank , params.local_num_tokens , params.ep_rank , params.ep_size , params.num_experts ,
698- params.eplb_stats_num_experts );
699- }))
622+ grid_size = 1 ;
700623 }
624+ int shared_bytes = 2 * params.top_k * (int ) sizeof (int );
625+ SWITCH_BOOL (params.enable_eplb , EPLB_STATS , SWITCH_TOP_K (params.top_k , TOP_K , {
626+ auto kernel_fn = moeA2ADispatchKernel<BlockPolicy, TOP_K , EPLB_STATS >;
627+ launchWithPdlWhenEnabled (" moeA2ADispatchKernel" , kernel_fn, grid_size, kBlockSize , shared_bytes, params.stream ,
628+ params.token_selected_experts , kernel_ptrs, params.num_payloads , params.max_tokens_per_rank ,
629+ params.local_num_tokens , params.ep_rank , params.ep_size , params.num_experts , params.eplb_stats_num_experts );
630+ }))
701631}
702632
703633// ============================================================================
@@ -1272,17 +1202,14 @@ __global__ void moeA2ACombineKernel(
12721202void moe_a2a_prepare_combine_launch (MoeA2ACombineParams const & params)
12731203{
12741204 constexpr int kBlockSize = 256 ;
1275- constexpr int kWarpsPerBlock = kBlockSize / 32 ; // 8 warps per block
12761205
12771206 // FP8 in-place (payload_in_workspace=true, prepare_payload==nullptr): each CTA writes
12781207 // FP8 at the BF16-stride position, so CTAs never race — all tokens must be processed.
12791208 // Copy path with null payload is a no-op; 1 block suffices for the flag increment only.
12801209 int global_token_num = (params.use_low_precision || params.prepare_payload != nullptr )
12811210 ? params.ep_size * params.max_tokens_per_rank
12821211 : 1 ;
1283- int grid_size_warp = ceilDiv (global_token_num, kWarpsPerBlock );
1284- int grid_size_block = global_token_num; // one block per token
1285- int grid = params.one_block_per_token ? grid_size_block : grid_size_warp;
1212+ int grid = global_token_num; // one block per token
12861213
12871214 uint8_t * recv_buffer_bytes = static_cast <uint8_t *>(const_cast <void *>(params.recv_buffers [params.ep_rank ]));
12881215 void const * payload = params.prepare_payload ;
@@ -1297,8 +1224,7 @@ void moe_a2a_prepare_combine_launch(MoeA2ACombineParams const& params)
12971224 int const stride_per_token = low_precision_staged
12981225 ? params.elements_per_token
12991226 : params.elements_per_token * static_cast <int >(sizeof (SrcT));
1300- auto kernel_fn = params.one_block_per_token ? moeA2APrepareCombineKernel<BlockPolicy, LOW_PRECISION , SrcT>
1301- : moeA2APrepareCombineKernel<WarpPolicy, LOW_PRECISION , SrcT>;
1227+ auto kernel_fn = moeA2APrepareCombineKernel<BlockPolicy, LOW_PRECISION , SrcT>;
13021228 launchWithPdlWhenEnabled (" moeA2APrepareCombineKernel" , kernel_fn, grid, kBlockSize , 0 , params.stream ,
13031229 recv_buffer_bytes, payload, params.elements_per_token , params.ep_size , params.max_tokens_per_rank ,
13041230 params.flag_val , params.recv_counters , stride_per_token);
@@ -1318,19 +1244,13 @@ void moe_a2a_combine_launch(MoeA2ACombineParams const& params)
13181244 TLLM_CHECK (params.local_num_tokens >= 0 );
13191245 TLLM_CHECK (params.elements_per_token > 0 );
13201246
1321- // Configure kernel launch
1247+ // Configure kernel launch (one block per token).
13221248 int const kBlockSize = tensorrt_llm::common::getEnvMoeA2ACombineBlockSize ();
1323- int const kWarpsPerBlock = kBlockSize / 32 ; // warpSize
1324- int grid_size_warp = ceilDiv (params.local_num_tokens , kWarpsPerBlock );
1325- int grid_size_block = params.local_num_tokens ;
1249+ int grid = params.local_num_tokens ;
13261250 // If local_num_tokens is 0, we still need to launch a minimal kernel to participate in the synchronization.
1327- if (grid_size_warp == 0 )
1251+ if (grid == 0 )
13281252 {
1329- grid_size_warp = 1 ;
1330- }
1331- if (grid_size_block == 0 )
1332- {
1333- grid_size_block = 1 ;
1253+ grid = 1 ;
13341254 }
13351255
13361256 // Prepare kernel pointers struct for combine
@@ -1356,8 +1276,6 @@ void moe_a2a_combine_launch(MoeA2ACombineParams const& params)
13561276 kernel_ptrs.topk_target_ranks = params.topk_target_ranks ;
13571277 kernel_ptrs.topk_send_indices = params.topk_send_indices ;
13581278
1359- int grid = params.one_block_per_token ? grid_size_block : grid_size_warp;
1360-
13611279 // stride_per_token: byte distance between tokens in the recv buffer.
13621280 // FP8 external payload: EPT × 1 (compact FP8 layout)
13631281 // FP8 in-place / non-FP8: EPT × sizeof(PayloadT) (payload-dtype stride)
@@ -1374,13 +1292,11 @@ void moe_a2a_combine_launch(MoeA2ACombineParams const& params)
13741292
13751293 // Launch appropriate kernel with compact macros
13761294 SWITCH_DTYPE (effective_dtype, TKernelType, {
1377- SWITCH_POLICY (params.one_block_per_token , Policy, {
1378- SWITCH_TOP_K (params.top_k , TOP_K , {
1379- auto kernel_fn = moeA2ACombineKernel<TKernelType, Policy, TOP_K >;
1380- launchWithPdlWhenEnabled (" moeA2ACombineKernel" , kernel_fn, grid, kBlockSize , 0 , params.stream ,
1381- kernel_ptrs, params.max_tokens_per_rank , params.elements_per_token , params.local_num_tokens ,
1382- params.ep_rank , params.ep_size , stride_per_token);
1383- });
1295+ SWITCH_TOP_K (params.top_k , TOP_K , {
1296+ auto kernel_fn = moeA2ACombineKernel<TKernelType, BlockPolicy, TOP_K >;
1297+ launchWithPdlWhenEnabled (" moeA2ACombineKernel" , kernel_fn, grid, kBlockSize , 0 , params.stream , kernel_ptrs,
1298+ params.max_tokens_per_rank , params.elements_per_token , params.local_num_tokens , params.ep_rank ,
1299+ params.ep_size , stride_per_token);
13841300 });
13851301 });
13861302}
0 commit comments