Skip to content

Commit e776204

Browse files
committed
Address active rank mask review comments
Signed-off-by: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com>
1 parent ea9466e commit e776204

4 files changed

Lines changed: 97 additions & 92 deletions

File tree

cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -424,7 +424,7 @@ __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [
424424
int* smem_topk_target_ranks = smem;
425425
int* smem_topk_send_indices = smem + TOP_K;
426426

427-
uint64_t already_copied = 0;
427+
uint64_t already_copied[kRankMaskWords] = {};
428428
// Precompute the ceil/floor partition parameters once per thread, outside the
429429
// per-token TOP_K loop. The fast path (remainder == 0) then collapses to a single
430430
// integer divide per call, matching the pre-PR uniform-partition cost exactly.
@@ -444,8 +444,11 @@ __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [
444444
// checks via topk_send_indices[k] < 0. A token whose only target is dead is dropped
445445
// from this collective; higher-layer logic (EPLB redistribution) is responsible
446446
// for re-routing such tokens on subsequent iterations.
447+
int const mask_word = target_rank >> 6;
448+
uint64_t const mask_bit = 1ULL << (target_rank & 63);
449+
bool const target_already_copied = already_copied[mask_word] & mask_bit;
447450
bool const target_dead = !is_rank_active(ptrs.active_rank_mask, target_rank);
448-
if ((already_copied & (1ULL << target_rank)) || target_dead)
451+
if (target_already_copied || target_dead)
449452
{
450453
if (thread_idx == 0)
451454
{
@@ -470,7 +473,7 @@ __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [
470473
smem_topk_target_ranks[k] = target_rank;
471474
smem_topk_send_indices[k] = dst_token_idx;
472475
}
473-
already_copied |= 1ULL << target_rank;
476+
already_copied[mask_word] |= mask_bit;
474477
}
475478
// Sync before dispatching data
476479
ThreadingPolicy::sync();
@@ -630,6 +633,7 @@ void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params)
630633
// Validate parameters
631634
TLLM_CHECK(params.top_k > 0 && params.top_k <= kMaxTopK);
632635
TLLM_CHECK(params.ep_size > 0 && params.ep_size <= kMaxRanks);
636+
TLLM_CHECK(params.ep_rank >= 0 && params.ep_rank < params.ep_size);
633637
TLLM_CHECK(params.local_num_tokens >= 0);
634638
TLLM_CHECK(params.num_payloads > 0 && params.num_payloads <= kMaxPayloads);
635639
// The local rank must always be marked active in its own view of the mask;
@@ -1316,6 +1320,7 @@ void moe_a2a_combine_launch(MoeA2ACombineParams const& params)
13161320
// Validate parameters
13171321
TLLM_CHECK(params.top_k > 0 && params.top_k <= kMaxTopK);
13181322
TLLM_CHECK(params.ep_size > 0 && params.ep_size <= kMaxRanks);
1323+
TLLM_CHECK(params.ep_rank >= 0 && params.ep_rank < params.ep_size);
13191324
TLLM_CHECK(params.local_num_tokens >= 0);
13201325
TLLM_CHECK(params.elements_per_token > 0);
13211326
// The local rank must always be marked active in its own view of the mask;

cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ struct MoeA2ADispatchParams
142142
// Active-rank bitmask: see DispatchKernelPointers::active_rank_mask. The launch function
143143
// copies these words into the kernel pointers struct. Defaults to all-ones for
144144
// backwards-compatible "no masking" behavior.
145-
uint64_t active_rank_mask[kRankMaskWords];
145+
uint64_t active_rank_mask[kRankMaskWords] = {~uint64_t{0}, ~uint64_t{0}};
146146

147147
// CUDA stream
148148
cudaStream_t stream;
@@ -192,7 +192,7 @@ struct MoeA2ACombineParams
192192
// Active-rank bitmask: see DispatchKernelPointers::active_rank_mask. The launch function
193193
// copies these words into the kernel pointers struct. Defaults to all-ones for
194194
// backwards-compatible "no masking" behavior.
195-
uint64_t active_rank_mask[kRankMaskWords];
195+
uint64_t active_rank_mask[kRankMaskWords] = {~uint64_t{0}, ~uint64_t{0}};
196196

197197
// CUDA stream
198198
cudaStream_t stream;

cpp/tensorrt_llm/thop/moeAlltoAllOp.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,9 @@ inline void resolveActiveRankMask(torch::optional<torch::Tensor> const& maskTens
5252
uint64_t (&out)[tensorrt_llm::kernels::moe_comm::kRankMaskWords])
5353
{
5454
using tensorrt_llm::kernels::moe_comm::kRankMaskWords;
55+
using tensorrt_llm::kernels::moe_comm::kMaxRanks;
56+
TORCH_CHECK(
57+
epRank >= 0 && epRank < kMaxRanks, "epRank must be in the range [0, ", kMaxRanks, ") for active_rank_mask");
5558
if (!maskTensor.has_value() || !maskTensor.value().defined())
5659
{
5760
for (int w = 0; w < kRankMaskWords; ++w)
@@ -151,11 +154,14 @@ MoeA2ADataOffsets calculateOffsets(int epSize, int maxNumTokens, int eplbStatsNu
151154
torch::Tensor moeA2AInitializeOp(torch::Tensor const& workspace, int64_t epRank, int64_t epSize, int64_t maxNumTokens,
152155
torch::optional<int64_t> eplbStatsNumExperts)
153156
{
157+
using tensorrt_llm::kernels::moe_comm::kMaxRanks;
158+
154159
// Validate inputs
155160
CHECK_TH_CUDA(workspace);
156161
CHECK_TYPE(workspace, torch::kUInt8);
157162
TORCH_CHECK(workspace.dim() == 2, "workspace must be a 2D tensor of shape [epSize, sizePerRank]");
158163
TORCH_CHECK(workspace.size(0) == epSize, "workspace first dimension must equal epSize");
164+
TORCH_CHECK(epSize > 0 && epSize <= kMaxRanks, "epSize must be in the range (0, ", kMaxRanks, "]");
159165
TORCH_CHECK(epRank >= 0 && epRank < epSize, "epRank must be in the range [0, epSize)");
160166

161167
// Initialize workspace to zero
@@ -223,6 +229,7 @@ std::tuple<std::vector<torch::Tensor>, int64_t, torch::Tensor> moeA2ADispatchOp(
223229
using tensorrt_llm::kernels::moe_comm::moe_a2a_dispatch_launch;
224230
using tensorrt_llm::kernels::moe_comm::kMaxTopK;
225231
using tensorrt_llm::kernels::moe_comm::kMaxPayloads;
232+
using tensorrt_llm::kernels::moe_comm::kMaxRanks;
226233

227234
// Validate inputs
228235
CHECK_INPUT(tokenSelectedExperts, torch::kInt32);
@@ -238,6 +245,7 @@ std::tuple<std::vector<torch::Tensor>, int64_t, torch::Tensor> moeA2ADispatchOp(
238245

239246
int64_t localNumTokens = tokenSelectedExperts.size(0);
240247
TORCH_CHECK(runtimeMaxTokensPerRank > 0, "runtimeMaxTokensPerRank must be positive");
248+
TORCH_CHECK(epSize > 0 && epSize <= kMaxRanks, "epSize must be in the range (0, ", kMaxRanks, "]");
241249
TORCH_CHECK(epRank >= 0 && epRank < epSize, "epRank must be in the range [0, epSize)");
242250
TORCH_CHECK(topK > 0 && topK <= kMaxTopK, "topK must be in the range (0, kMaxTopK]");
243251
TORCH_CHECK(!inputPayloads.empty(), "inputPayloads must not be empty");
@@ -458,6 +466,7 @@ torch::Tensor moeA2ACombineOp(torch::Tensor const& payload, int64_t localNumToke
458466
using tensorrt_llm::kernels::moe_comm::MoeA2ACombineParams;
459467
using tensorrt_llm::kernels::moe_comm::moe_a2a_combine_launch;
460468
using tensorrt_llm::kernels::moe_comm::kMaxTopK;
469+
using tensorrt_llm::kernels::moe_comm::kMaxRanks;
461470

462471
// Validate inputs
463472
CHECK_TH_CUDA(payload);
@@ -471,6 +480,7 @@ torch::Tensor moeA2ACombineOp(torch::Tensor const& payload, int64_t localNumToke
471480
TORCH_CHECK(reinterpret_cast<uintptr_t>(payload.data_ptr()) % 16 == 0, "payload must be 16-byte aligned");
472481
int64_t elementsPerToken = payload.size(2);
473482
TORCH_CHECK(elementsPerToken > 0, "elementsPerToken must be positive");
483+
TORCH_CHECK(epSize > 0 && epSize <= kMaxRanks, "epSize must be in the range (0, ", kMaxRanks, "]");
474484
TORCH_CHECK(epRank >= 0 && epRank < epSize, "epRank must be in the range [0, epSize)");
475485
TORCH_CHECK(topK > 0 && topK <= kMaxTopK, "topK must be in the range (0, kMaxTopK]");
476486

tests/unittest/_torch/multi_gpu/test_moe_a2a_rank_mask.py

Lines changed: 77 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,9 @@
3232

3333
import pickle
3434
import sys
35-
import traceback
3635

3736
import cloudpickle
37+
import pynvml
3838
import pytest
3939
import torch
4040
from mpi4py import MPI
@@ -63,6 +63,16 @@ def setup_test():
6363
tllm.logger.set_level("error")
6464

6565

66+
def _skip_if_mnnvl_unsupported() -> None:
67+
try:
68+
MnnvlMemory.initialize()
69+
supports_mnnvl = MnnvlMemory.supports_mnnvl()
70+
except (RuntimeError, pynvml.NVMLError) as exc:
71+
pytest.skip(f"MNNVL not supported on this system: {exc}")
72+
if not supports_mnnvl:
73+
pytest.skip("MNNVL not supported on this system")
74+
75+
6676
def _ep_mask_words(ep_size: int, dead_ranks: set[int]) -> torch.Tensor:
6777
"""Build the uint64[EP_MASK_NUM_WORDS] CPU tensor expected by the C++ op."""
6878
mask_int = ((1 << ep_size) - 1) & ~sum(1 << r for r in dead_ranks)
@@ -162,41 +172,35 @@ def _worker_all_active_matches_no_mask(
162172
):
163173
rank = tllm.mpi_rank()
164174
torch.cuda.set_device(rank)
165-
try:
166-
mapping = Mapping(rank=rank, tp_size=ep_size, moe_ep_size=ep_size, world_size=ep_size)
167-
moe_a2a = MoeAlltoAll(
168-
mapping=mapping,
169-
max_num_tokens=local_num_tokens,
170-
top_k=top_k,
171-
num_slots=num_experts,
172-
workspace_size_per_rank=workspace_size_per_rank,
173-
)
175+
mapping = Mapping(rank=rank, tp_size=ep_size, moe_ep_size=ep_size, world_size=ep_size)
176+
moe_a2a = MoeAlltoAll(
177+
mapping=mapping,
178+
max_num_tokens=local_num_tokens,
179+
top_k=top_k,
180+
num_slots=num_experts,
181+
workspace_size_per_rank=workspace_size_per_rank,
182+
)
174183

175-
# Same RNG seed across both runs => identical inputs.
176-
torch.manual_seed(0xA2A + rank)
177-
token_selected_experts = _generate_token_selected_experts(
178-
local_num_tokens, num_experts, top_k
179-
)
180-
payload = _make_payload(local_num_tokens, hidden_size, rank)
184+
# Same RNG seed across both runs => identical inputs.
185+
torch.manual_seed(0xA2A + rank)
186+
token_selected_experts = _generate_token_selected_experts(local_num_tokens, num_experts, top_k)
187+
payload = _make_payload(local_num_tokens, hidden_size, rank)
181188

182-
out_no_mask, topk_no_mask = _run_dispatch_combine(
183-
moe_a2a, token_selected_experts, payload, local_num_tokens, active_rank_mask=None
184-
)
185-
out_all_active, topk_all_active = _run_dispatch_combine(
186-
moe_a2a,
187-
token_selected_experts,
188-
payload,
189-
local_num_tokens,
190-
active_rank_mask=_ep_mask_words(ep_size, dead_ranks=set()),
191-
)
189+
out_no_mask, topk_no_mask = _run_dispatch_combine(
190+
moe_a2a, token_selected_experts, payload, local_num_tokens, active_rank_mask=None
191+
)
192+
out_all_active, topk_all_active = _run_dispatch_combine(
193+
moe_a2a,
194+
token_selected_experts,
195+
payload,
196+
local_num_tokens,
197+
active_rank_mask=_ep_mask_words(ep_size, dead_ranks=set()),
198+
)
192199

193-
return (
194-
torch.equal(out_no_mask, out_all_active),
195-
torch.equal(topk_no_mask, topk_all_active),
196-
)
197-
except Exception:
198-
traceback.print_exc()
199-
raise
200+
return (
201+
torch.equal(out_no_mask, out_all_active),
202+
torch.equal(topk_no_mask, topk_all_active),
203+
)
200204

201205

202206
# ---------------------------------------------------------------------------
@@ -215,53 +219,47 @@ def _worker_one_rank_masked(
215219
):
216220
rank = tllm.mpi_rank()
217221
torch.cuda.set_device(rank)
218-
try:
219-
mapping = Mapping(rank=rank, tp_size=ep_size, moe_ep_size=ep_size, world_size=ep_size)
220-
# Every rank participates in workspace init (it has MPI barriers internally).
221-
moe_a2a = MoeAlltoAll(
222-
mapping=mapping,
223-
max_num_tokens=local_num_tokens,
224-
top_k=top_k,
225-
num_slots=num_experts,
226-
workspace_size_per_rank=workspace_size_per_rank,
227-
)
222+
mapping = Mapping(rank=rank, tp_size=ep_size, moe_ep_size=ep_size, world_size=ep_size)
223+
# Every rank participates in workspace init (it has MPI barriers internally).
224+
moe_a2a = MoeAlltoAll(
225+
mapping=mapping,
226+
max_num_tokens=local_num_tokens,
227+
top_k=top_k,
228+
num_slots=num_experts,
229+
workspace_size_per_rank=workspace_size_per_rank,
230+
)
228231

229-
if rank == dead_rank:
230-
# Simulate a dead rank: do not call dispatch/combine. Wait at a final
231-
# barrier so the surviving ranks have someone to synchronize with at
232-
# the end of the test. (The kernel itself never observes us because
233-
# the surviving ranks pass a mask with our bit cleared.)
234-
MPI.COMM_WORLD.barrier()
235-
return ("dead", None, None, None)
236-
237-
torch.manual_seed(0xA2A + rank)
238-
token_selected_experts = _generate_token_selected_experts(
239-
local_num_tokens, num_experts, top_k
240-
)
241-
payload = _make_payload(local_num_tokens, hidden_size, rank)
232+
if rank == dead_rank:
233+
# Simulate a dead rank: do not call dispatch/combine. Wait at a final
234+
# barrier so the surviving ranks have someone to synchronize with at
235+
# the end of the test. (The kernel itself never observes us because
236+
# the surviving ranks pass a mask with our bit cleared.)
237+
MPI.COMM_WORLD.barrier()
238+
return ("dead", None, None, None)
242239

243-
# Build mask with dead_rank's bit cleared.
244-
mask = _ep_mask_words(ep_size, dead_ranks={dead_rank})
240+
torch.manual_seed(0xA2A + rank)
241+
token_selected_experts = _generate_token_selected_experts(local_num_tokens, num_experts, top_k)
242+
payload = _make_payload(local_num_tokens, hidden_size, rank)
245243

246-
# Compute the per-token target ranks the way the kernel does so we can
247-
# cross-check the workspace afterwards.
248-
num_experts_per_rank = num_experts // ep_size
249-
expected_target_ranks = (token_selected_experts // num_experts_per_rank).cpu()
244+
# Build mask with dead_rank's bit cleared.
245+
mask = _ep_mask_words(ep_size, dead_ranks={dead_rank})
250246

251-
combined, topk_target_ranks = _run_dispatch_combine(
252-
moe_a2a, token_selected_experts, payload, local_num_tokens, active_rank_mask=mask
253-
)
247+
# Compute the per-token target ranks the way the kernel does so we can
248+
# cross-check the workspace afterwards.
249+
num_experts_per_rank = num_experts // ep_size
250+
expected_target_ranks = (token_selected_experts // num_experts_per_rank).cpu()
254251

255-
MPI.COMM_WORLD.barrier()
256-
return (
257-
"alive",
258-
combined,
259-
topk_target_ranks,
260-
expected_target_ranks,
261-
)
262-
except Exception:
263-
traceback.print_exc()
264-
raise
252+
combined, topk_target_ranks = _run_dispatch_combine(
253+
moe_a2a, token_selected_experts, payload, local_num_tokens, active_rank_mask=mask
254+
)
255+
256+
MPI.COMM_WORLD.barrier()
257+
return (
258+
"alive",
259+
combined,
260+
topk_target_ranks,
261+
expected_target_ranks,
262+
)
265263

266264

267265
# ---------------------------------------------------------------------------
@@ -280,11 +278,7 @@ def _worker_one_rank_masked(
280278
)
281279
def test_all_active_mask_matches_no_mask(mpi_pool_executor, local_num_tokens, top_k):
282280
"""An all-ones active_rank_mask must produce identical output to omitting it."""
283-
try:
284-
MnnvlMemory.initialize()
285-
assert MnnvlMemory.supports_mnnvl()
286-
except Exception:
287-
pytest.skip("MNNVL not supported on this system")
281+
_skip_if_mnnvl_unsupported()
288282

289283
ep_size = mpi_pool_executor.num_workers
290284
if ep_size > torch.cuda.device_count():
@@ -300,7 +294,7 @@ def test_all_active_mask_matches_no_mask(mpi_pool_executor, local_num_tokens, to
300294
results = list(
301295
mpi_pool_executor.map(
302296
_worker_all_active_matches_no_mask,
303-
*zip(*[args] * ep_size),
297+
*zip(*[args] * ep_size, strict=True),
304298
)
305299
)
306300

@@ -329,11 +323,7 @@ def test_one_rank_masked_completes(mpi_pool_executor, dead_rank, local_num_token
329323
* Slots whose expert mapped to a surviving rank are unchanged from what
330324
the contiguous-partition routing rule predicts.
331325
"""
332-
try:
333-
MnnvlMemory.initialize()
334-
assert MnnvlMemory.supports_mnnvl()
335-
except Exception:
336-
pytest.skip("MNNVL not supported on this system")
326+
_skip_if_mnnvl_unsupported()
337327

338328
ep_size = mpi_pool_executor.num_workers
339329
if ep_size > torch.cuda.device_count():
@@ -358,7 +348,7 @@ def test_one_rank_masked_completes(mpi_pool_executor, dead_rank, local_num_token
358348
results = list(
359349
mpi_pool_executor.map(
360350
_worker_one_rank_masked,
361-
*zip(*[args] * ep_size),
351+
*zip(*[args] * ep_size, strict=True),
362352
)
363353
)
364354

0 commit comments

Comments
 (0)