3232
3333import pickle
3434import sys
35- import traceback
3635
3736import cloudpickle
37+ import pynvml
3838import pytest
3939import torch
4040from mpi4py import MPI
@@ -63,6 +63,16 @@ def setup_test():
6363 tllm .logger .set_level ("error" )
6464
6565
66+ def _skip_if_mnnvl_unsupported () -> None :
67+ try :
68+ MnnvlMemory .initialize ()
69+ supports_mnnvl = MnnvlMemory .supports_mnnvl ()
70+ except (RuntimeError , pynvml .NVMLError ) as exc :
71+ pytest .skip (f"MNNVL not supported on this system: { exc } " )
72+ if not supports_mnnvl :
73+ pytest .skip ("MNNVL not supported on this system" )
74+
75+
6676def _ep_mask_words (ep_size : int , dead_ranks : set [int ]) -> torch .Tensor :
6777 """Build the uint64[EP_MASK_NUM_WORDS] CPU tensor expected by the C++ op."""
6878 mask_int = ((1 << ep_size ) - 1 ) & ~ sum (1 << r for r in dead_ranks )
@@ -162,41 +172,35 @@ def _worker_all_active_matches_no_mask(
162172):
163173 rank = tllm .mpi_rank ()
164174 torch .cuda .set_device (rank )
165- try :
166- mapping = Mapping (rank = rank , tp_size = ep_size , moe_ep_size = ep_size , world_size = ep_size )
167- moe_a2a = MoeAlltoAll (
168- mapping = mapping ,
169- max_num_tokens = local_num_tokens ,
170- top_k = top_k ,
171- num_slots = num_experts ,
172- workspace_size_per_rank = workspace_size_per_rank ,
173- )
175+ mapping = Mapping (rank = rank , tp_size = ep_size , moe_ep_size = ep_size , world_size = ep_size )
176+ moe_a2a = MoeAlltoAll (
177+ mapping = mapping ,
178+ max_num_tokens = local_num_tokens ,
179+ top_k = top_k ,
180+ num_slots = num_experts ,
181+ workspace_size_per_rank = workspace_size_per_rank ,
182+ )
174183
175- # Same RNG seed across both runs => identical inputs.
176- torch .manual_seed (0xA2A + rank )
177- token_selected_experts = _generate_token_selected_experts (
178- local_num_tokens , num_experts , top_k
179- )
180- payload = _make_payload (local_num_tokens , hidden_size , rank )
184+ # Same RNG seed across both runs => identical inputs.
185+ torch .manual_seed (0xA2A + rank )
186+ token_selected_experts = _generate_token_selected_experts (local_num_tokens , num_experts , top_k )
187+ payload = _make_payload (local_num_tokens , hidden_size , rank )
181188
182- out_no_mask , topk_no_mask = _run_dispatch_combine (
183- moe_a2a , token_selected_experts , payload , local_num_tokens , active_rank_mask = None
184- )
185- out_all_active , topk_all_active = _run_dispatch_combine (
186- moe_a2a ,
187- token_selected_experts ,
188- payload ,
189- local_num_tokens ,
190- active_rank_mask = _ep_mask_words (ep_size , dead_ranks = set ()),
191- )
189+ out_no_mask , topk_no_mask = _run_dispatch_combine (
190+ moe_a2a , token_selected_experts , payload , local_num_tokens , active_rank_mask = None
191+ )
192+ out_all_active , topk_all_active = _run_dispatch_combine (
193+ moe_a2a ,
194+ token_selected_experts ,
195+ payload ,
196+ local_num_tokens ,
197+ active_rank_mask = _ep_mask_words (ep_size , dead_ranks = set ()),
198+ )
192199
193- return (
194- torch .equal (out_no_mask , out_all_active ),
195- torch .equal (topk_no_mask , topk_all_active ),
196- )
197- except Exception :
198- traceback .print_exc ()
199- raise
200+ return (
201+ torch .equal (out_no_mask , out_all_active ),
202+ torch .equal (topk_no_mask , topk_all_active ),
203+ )
200204
201205
202206# ---------------------------------------------------------------------------
@@ -215,53 +219,47 @@ def _worker_one_rank_masked(
215219):
216220 rank = tllm .mpi_rank ()
217221 torch .cuda .set_device (rank )
218- try :
219- mapping = Mapping (rank = rank , tp_size = ep_size , moe_ep_size = ep_size , world_size = ep_size )
220- # Every rank participates in workspace init (it has MPI barriers internally).
221- moe_a2a = MoeAlltoAll (
222- mapping = mapping ,
223- max_num_tokens = local_num_tokens ,
224- top_k = top_k ,
225- num_slots = num_experts ,
226- workspace_size_per_rank = workspace_size_per_rank ,
227- )
222+ mapping = Mapping (rank = rank , tp_size = ep_size , moe_ep_size = ep_size , world_size = ep_size )
223+ # Every rank participates in workspace init (it has MPI barriers internally).
224+ moe_a2a = MoeAlltoAll (
225+ mapping = mapping ,
226+ max_num_tokens = local_num_tokens ,
227+ top_k = top_k ,
228+ num_slots = num_experts ,
229+ workspace_size_per_rank = workspace_size_per_rank ,
230+ )
228231
229- if rank == dead_rank :
230- # Simulate a dead rank: do not call dispatch/combine. Wait at a final
231- # barrier so the surviving ranks have someone to synchronize with at
232- # the end of the test. (The kernel itself never observes us because
233- # the surviving ranks pass a mask with our bit cleared.)
234- MPI .COMM_WORLD .barrier ()
235- return ("dead" , None , None , None )
236-
237- torch .manual_seed (0xA2A + rank )
238- token_selected_experts = _generate_token_selected_experts (
239- local_num_tokens , num_experts , top_k
240- )
241- payload = _make_payload (local_num_tokens , hidden_size , rank )
232+ if rank == dead_rank :
233+ # Simulate a dead rank: do not call dispatch/combine. Wait at a final
234+ # barrier so the surviving ranks have someone to synchronize with at
235+ # the end of the test. (The kernel itself never observes us because
236+ # the surviving ranks pass a mask with our bit cleared.)
237+ MPI .COMM_WORLD .barrier ()
238+ return ("dead" , None , None , None )
242239
243- # Build mask with dead_rank's bit cleared.
244- mask = _ep_mask_words (ep_size , dead_ranks = {dead_rank })
240+ torch .manual_seed (0xA2A + rank )
241+ token_selected_experts = _generate_token_selected_experts (local_num_tokens , num_experts , top_k )
242+ payload = _make_payload (local_num_tokens , hidden_size , rank )
245243
246- # Compute the per-token target ranks the way the kernel does so we can
247- # cross-check the workspace afterwards.
248- num_experts_per_rank = num_experts // ep_size
249- expected_target_ranks = (token_selected_experts // num_experts_per_rank ).cpu ()
244+ # Build mask with dead_rank's bit cleared.
245+ mask = _ep_mask_words (ep_size , dead_ranks = {dead_rank })
250246
251- combined , topk_target_ranks = _run_dispatch_combine (
252- moe_a2a , token_selected_experts , payload , local_num_tokens , active_rank_mask = mask
253- )
247+ # Compute the per-token target ranks the way the kernel does so we can
248+ # cross-check the workspace afterwards.
249+ num_experts_per_rank = num_experts // ep_size
250+ expected_target_ranks = (token_selected_experts // num_experts_per_rank ).cpu ()
254251
255- MPI .COMM_WORLD .barrier ()
256- return (
257- "alive" ,
258- combined ,
259- topk_target_ranks ,
260- expected_target_ranks ,
261- )
262- except Exception :
263- traceback .print_exc ()
264- raise
252+ combined , topk_target_ranks = _run_dispatch_combine (
253+ moe_a2a , token_selected_experts , payload , local_num_tokens , active_rank_mask = mask
254+ )
255+
256+ MPI .COMM_WORLD .barrier ()
257+ return (
258+ "alive" ,
259+ combined ,
260+ topk_target_ranks ,
261+ expected_target_ranks ,
262+ )
265263
266264
267265# ---------------------------------------------------------------------------
@@ -280,11 +278,7 @@ def _worker_one_rank_masked(
280278)
281279def test_all_active_mask_matches_no_mask (mpi_pool_executor , local_num_tokens , top_k ):
282280 """An all-ones active_rank_mask must produce identical output to omitting it."""
283- try :
284- MnnvlMemory .initialize ()
285- assert MnnvlMemory .supports_mnnvl ()
286- except Exception :
287- pytest .skip ("MNNVL not supported on this system" )
281+ _skip_if_mnnvl_unsupported ()
288282
289283 ep_size = mpi_pool_executor .num_workers
290284 if ep_size > torch .cuda .device_count ():
@@ -300,7 +294,7 @@ def test_all_active_mask_matches_no_mask(mpi_pool_executor, local_num_tokens, to
300294 results = list (
301295 mpi_pool_executor .map (
302296 _worker_all_active_matches_no_mask ,
303- * zip (* [args ] * ep_size ),
297+ * zip (* [args ] * ep_size , strict = True ),
304298 )
305299 )
306300
@@ -329,11 +323,7 @@ def test_one_rank_masked_completes(mpi_pool_executor, dead_rank, local_num_token
329323 * Slots whose expert mapped to a surviving rank are unchanged from what
330324 the contiguous-partition routing rule predicts.
331325 """
332- try :
333- MnnvlMemory .initialize ()
334- assert MnnvlMemory .supports_mnnvl ()
335- except Exception :
336- pytest .skip ("MNNVL not supported on this system" )
326+ _skip_if_mnnvl_unsupported ()
337327
338328 ep_size = mpi_pool_executor .num_workers
339329 if ep_size > torch .cuda .device_count ():
@@ -358,7 +348,7 @@ def test_one_rank_masked_completes(mpi_pool_executor, dead_rank, local_num_token
358348 results = list (
359349 mpi_pool_executor .map (
360350 _worker_one_rank_masked ,
361- * zip (* [args ] * ep_size ),
351+ * zip (* [args ] * ep_size , strict = True ),
362352 )
363353 )
364354
0 commit comments