From 0a696d5f033d6c0731dfc0af360bb76422484b85 Mon Sep 17 00:00:00 2001 From: amd-ruitang3 Date: Tue, 26 May 2026 01:53:33 -0500 Subject: [PATCH 1/3] [TRITON] moe routing support expert_map for expert parallelism --- .../fusions/fused_routing_from_topk.py | 34 ++++++++- .../triton/fusions/fused_routing_from_topk.py | 19 ++++- .../fusions/test_fused_routing_from_topk.py | 69 ++++++++++++++++++- 3 files changed, 118 insertions(+), 4 deletions(-) diff --git a/aiter/ops/triton/_triton_kernels/fusions/fused_routing_from_topk.py b/aiter/ops/triton/_triton_kernels/fusions/fused_routing_from_topk.py index 0dcd54c6a5..597b6f969c 100644 --- a/aiter/ops/triton/_triton_kernels/fusions/fused_routing_from_topk.py +++ b/aiter/ops/triton/_triton_kernels/fusions/fused_routing_from_topk.py @@ -15,6 +15,7 @@ "_fused_routing_from_topk_hist_kernel", [ "E", + "HAS_EXPERT_MAP", "BLOCK_NK", "BLOCK_E", ], @@ -31,6 +32,7 @@ _fused_routing_from_topk_place_kernel_repr = make_kernel_repr( "_fused_routing_from_topk_place_kernel", [ + "HAS_EXPERT_MAP", "BLOCK_NK", ], ) @@ -40,11 +42,14 @@ def _fused_routing_from_topk_hist_kernel( # inputs topk_ids_ptr, # [NK] int32 — flattened topk_ids + expert_map_ptr, # [N_EXPERTS_GLOBAL] int32 or identity map fallback + expert_map_numel, # runtime int — bounds for expert_map_ptr # outputs hist_ptr, # [E] int32 — tokens-per-expert histogram # shapes NK, # runtime int — actual valid item count (≤ BLOCK_NK) E: tl.constexpr, + HAS_EXPERT_MAP: tl.constexpr, BLOCK_NK: tl.constexpr, # padded to next pow2 of NK BLOCK_E: tl.constexpr, # padded to next pow2 of E (tl.histogram needs pow2) ): @@ -62,7 +67,18 @@ def _fused_routing_from_topk_hist_kernel( # Clamp the offset for masked-out lanes to 0 so the pointer arithmetic # below stays within the allocated buffers. safe_item = tl.where(item_mask, item_offs, 0) - expt = tl.load(topk_ids_ptr + safe_item, mask=item_mask, other=0).to(tl.int32) + global_expt = tl.load(topk_ids_ptr + safe_item, mask=item_mask, other=0).to(tl.int32) + if HAS_EXPERT_MAP: + map_mask = item_mask & (global_expt >= 0) & (global_expt < expert_map_numel) + safe_global_expt = tl.where(map_mask, global_expt, 0) + local_expt = tl.load(expert_map_ptr + safe_global_expt, mask=map_mask, other=-1).to( + tl.int32 + ) + # Match reference semantics: invalid experts are redirected to bucket 0 + # and later zeroed in gate_scal. + expt = tl.where(local_expt >= 0, local_expt, 0) + else: + expt = global_expt hist = tl.histogram(expt, BLOCK_E, mask=item_mask) @@ -99,6 +115,8 @@ def _fused_routing_from_topk_place_kernel( # inputs topk_ids_ptr, # [NK] int32 — flattened topk_ids topk_weights_ptr, # [NK] (any float dtype) — flattened topk_weights + expert_map_ptr, # [N_EXPERTS_GLOBAL] int32 or identity map fallback + expert_map_numel, # runtime int — bounds for expert_map_ptr offset_ptr, # [E] int32 — exclusive prefix sums from the offset kernel # outputs topk_indx_ptr, # [NK] int32 — output gather_indx.src_indx @@ -106,6 +124,7 @@ def _fused_routing_from_topk_place_kernel( gate_scal_ptr, # [NK] same dtype as topk_weights # shapes NK, # runtime int — actual valid item count (≤ BLOCK_NK) + HAS_EXPERT_MAP: tl.constexpr, BLOCK_NK: tl.constexpr, # padded to next pow2 of NK ): """Phase C: place items. @@ -122,8 +141,19 @@ def _fused_routing_from_topk_place_kernel( item_offs = tl.arange(0, BLOCK_NK) item_mask = item_offs < NK safe_item = tl.where(item_mask, item_offs, 0) - expt = tl.load(topk_ids_ptr + safe_item, mask=item_mask, other=0).to(tl.int32) + global_expt = tl.load(topk_ids_ptr + safe_item, mask=item_mask, other=0).to(tl.int32) weights = tl.load(topk_weights_ptr + safe_item, mask=item_mask, other=0.0) + if HAS_EXPERT_MAP: + map_mask = item_mask & (global_expt >= 0) & (global_expt < expert_map_numel) + safe_global_expt = tl.where(map_mask, global_expt, 0) + local_expt = tl.load(expert_map_ptr + safe_global_expt, mask=map_mask, other=-1).to( + tl.int32 + ) + invalid = local_expt < 0 + expt = tl.where(invalid, 0, local_expt) + weights = tl.where(invalid, 0.0, weights) + else: + expt = global_expt pos = tl.atomic_add(offset_ptr + expt, 1, mask=item_mask) diff --git a/aiter/ops/triton/fusions/fused_routing_from_topk.py b/aiter/ops/triton/fusions/fused_routing_from_topk.py index 83c746c693..4ff5aabb8f 100644 --- a/aiter/ops/triton/fusions/fused_routing_from_topk.py +++ b/aiter/ops/triton/fusions/fused_routing_from_topk.py @@ -4,7 +4,7 @@ # Fused replacement for the multi-kernel "topk → routing data" chain that # bridges FusedMoE.select_experts to triton_kernels.matmul_ogs. See the # accompanying _triton_kernels/fused_routing_from_topk.py for the kernel. -from typing import Tuple +from typing import Optional, Tuple import torch import triton @@ -30,6 +30,7 @@ def fused_routing_from_topk( topk_weights: torch.Tensor, topk_ids: torch.Tensor, n_expts_tot: int, + expert_map: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """Sort (token, slot) pairs by their expert id via a single Triton kernel. @@ -44,6 +45,10 @@ def fused_routing_from_topk( topk_ids: ``[n_tokens, n_expts_act]`` selected expert ids; values in ``[0, n_expts_tot)``. n_expts_tot: Total number of routed experts (= ``E``). + expert_map: Optional global→local expert map. When provided, + ``topk_ids`` are treated as global ids and remapped inside fused + kernels. Entries mapped to ``< 0`` are masked to zero weight and + redirected to local expert ``0`` for routing safety. Returns: Tuple ``(hist, topk_indx, gate_indx, gate_scal)``: @@ -94,6 +99,12 @@ def fused_routing_from_topk( # are no-ops. topk_ids_flat = topk_ids.contiguous().reshape(-1).to(torch.int32) topk_weights_flat = topk_weights.contiguous().reshape(-1) + expert_map_numel = 0 + expert_map_flat = topk_ids_flat + has_expert_map = expert_map is not None + if has_expert_map: + expert_map_flat = expert_map.contiguous().reshape(-1).to(torch.int32) + expert_map_numel = int(expert_map_flat.numel()) topk_indx = torch.empty(n_gates_pad, dtype=torch.int32, device=device) gate_indx = torch.empty(n_gates_pad, dtype=torch.int32, device=device) @@ -109,9 +120,12 @@ def fused_routing_from_topk( # single wave, matching the CTA-local design of the original kernel. _fused_routing_from_topk_hist_kernel[(1,)]( topk_ids_flat, + expert_map_flat, + expert_map_numel, hist, n_gates_pad, E=n_expts_tot, + HAS_EXPERT_MAP=has_expert_map, BLOCK_NK=BLOCK_NK, BLOCK_E=BLOCK_E, num_warps=1, @@ -132,11 +146,14 @@ def fused_routing_from_topk( _fused_routing_from_topk_place_kernel[(1,)]( topk_ids_flat, topk_weights_flat, + expert_map_flat, + expert_map_numel, offset_scratch, topk_indx, gate_indx, gate_scal, n_gates_pad, + HAS_EXPERT_MAP=has_expert_map, BLOCK_NK=BLOCK_NK, num_warps=1, ) diff --git a/op_tests/triton_tests/fusions/test_fused_routing_from_topk.py b/op_tests/triton_tests/fusions/test_fused_routing_from_topk.py index ea2c1552ef..05dc705b49 100644 --- a/op_tests/triton_tests/fusions/test_fused_routing_from_topk.py +++ b/op_tests/triton_tests/fusions/test_fused_routing_from_topk.py @@ -28,7 +28,7 @@ # Returns ``(hist, topk_indx, gate_indx, gate_scal)`` for direct comparison # against the fused kernel. # --------------------------------------------------------------------------- -def routing_from_topk_reference(topk_weights, topk_ids, n_expts_tot): +def routing_from_topk_reference(topk_weights, topk_ids, n_expts_tot, expert_map=None): """Multi-kernel torch reference for fused_routing_from_topk. Per-row sort of ``topk_ids`` followed by a stable global argsort by @@ -37,6 +37,12 @@ def routing_from_topk_reference(topk_weights, topk_ids, n_expts_tot): version), unlike the fused kernel which is non-deterministic at intra-expert ordering. """ + if expert_map is not None: + local_ids = expert_map[topk_ids.long()] + invalid = local_ids < 0 + topk_weights = topk_weights.masked_fill(invalid, 0.0) + topk_ids = local_ids.masked_fill(invalid, 0).to(torch.int32) + expt_indx_sorted, sort_indices = torch.sort(topk_ids.int(), dim=1) expt_scal_sorted = torch.gather(topk_weights, 1, sort_indices.long()) @@ -251,3 +257,64 @@ def test_fused_routing_from_topk(n_tokens, n_expts_act, n_expts_tot, dtype): test_hist, test_topk_indx, test_gate_scal, n_expts_act ) _compare_buckets(ref_buckets, test_buckets) + + +@pytest.mark.parametrize( + "n_tokens, n_expts_act, n_expts_tot,n_expts_global", + [ + (16, 6, 64, 256), + (64, 6, 128, 256), + ], +) +@pytest.mark.parametrize("dtype", [torch.float32]) +def test_fused_routing_from_topk_with_expert_map( + n_tokens, n_expts_act, n_expts_tot, n_expts_global, dtype +): + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + torch.manual_seed(0) + topk_ids, topk_weights = _make_inputs( + n_tokens, n_expts_act, n_expts_global, dtype, DEVICE, seed=0 + ) + + expert_map = torch.full((n_expts_global,), -1, dtype=torch.int32, device=DEVICE) + expert_map[: n_expts_tot // 2] = torch.arange( + n_expts_tot // 2, dtype=torch.int32, device=DEVICE + ) + + ref_hist, ref_topk_indx, ref_gate_indx, ref_gate_scal = routing_from_topk_reference( + topk_weights, topk_ids, n_expts_tot, expert_map=expert_map + ) + _check_routing_invariants( + ref_hist, + ref_topk_indx, + ref_gate_indx, + ref_gate_scal, + topk_ids, + n_expts_tot, + bucket_unsorted_layout=False, + ) + + test_hist, test_topk_indx, test_gate_indx, test_gate_scal = fused_routing_from_topk( + topk_weights, topk_ids, n_expts_tot, expert_map=expert_map + ) + _check_routing_invariants( + test_hist, + test_topk_indx, + test_gate_indx, + test_gate_scal, + topk_ids, + n_expts_tot, + bucket_unsorted_layout=False, + ) + + assert torch.equal(ref_hist, test_hist), f"hist mismatch:\n ref={ref_hist}\n fused={test_hist}" + + # Intra-expert ordering can differ between fused and reference, + # especially in expert-0 bucket where invalid experts are redirected. + # Compare zeroed-weight cardinality instead of elementwise positions. + ref_zero_count = int((ref_gate_scal == 0).sum().item()) + test_zero_count = int((test_gate_scal == 0).sum().item()) + assert ( + ref_zero_count == test_zero_count + ), f"zero-masked count mismatch: ref={ref_zero_count}, fused={test_zero_count}" From e255e6a82d6edfc8af61fc925b416a4f1ab45c75 Mon Sep 17 00:00:00 2001 From: amd-ruitang3 Date: Tue, 26 May 2026 01:58:45 -0500 Subject: [PATCH 2/3] black format --- .../fusions/fused_routing_from_topk.py | 20 +++++++++++-------- .../fusions/test_fused_routing_from_topk.py | 4 +++- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/aiter/ops/triton/_triton_kernels/fusions/fused_routing_from_topk.py b/aiter/ops/triton/_triton_kernels/fusions/fused_routing_from_topk.py index 597b6f969c..0c174005d8 100644 --- a/aiter/ops/triton/_triton_kernels/fusions/fused_routing_from_topk.py +++ b/aiter/ops/triton/_triton_kernels/fusions/fused_routing_from_topk.py @@ -67,13 +67,15 @@ def _fused_routing_from_topk_hist_kernel( # Clamp the offset for masked-out lanes to 0 so the pointer arithmetic # below stays within the allocated buffers. safe_item = tl.where(item_mask, item_offs, 0) - global_expt = tl.load(topk_ids_ptr + safe_item, mask=item_mask, other=0).to(tl.int32) + global_expt = tl.load(topk_ids_ptr + safe_item, mask=item_mask, other=0).to( + tl.int32 + ) if HAS_EXPERT_MAP: map_mask = item_mask & (global_expt >= 0) & (global_expt < expert_map_numel) safe_global_expt = tl.where(map_mask, global_expt, 0) - local_expt = tl.load(expert_map_ptr + safe_global_expt, mask=map_mask, other=-1).to( - tl.int32 - ) + local_expt = tl.load( + expert_map_ptr + safe_global_expt, mask=map_mask, other=-1 + ).to(tl.int32) # Match reference semantics: invalid experts are redirected to bucket 0 # and later zeroed in gate_scal. expt = tl.where(local_expt >= 0, local_expt, 0) @@ -141,14 +143,16 @@ def _fused_routing_from_topk_place_kernel( item_offs = tl.arange(0, BLOCK_NK) item_mask = item_offs < NK safe_item = tl.where(item_mask, item_offs, 0) - global_expt = tl.load(topk_ids_ptr + safe_item, mask=item_mask, other=0).to(tl.int32) + global_expt = tl.load(topk_ids_ptr + safe_item, mask=item_mask, other=0).to( + tl.int32 + ) weights = tl.load(topk_weights_ptr + safe_item, mask=item_mask, other=0.0) if HAS_EXPERT_MAP: map_mask = item_mask & (global_expt >= 0) & (global_expt < expert_map_numel) safe_global_expt = tl.where(map_mask, global_expt, 0) - local_expt = tl.load(expert_map_ptr + safe_global_expt, mask=map_mask, other=-1).to( - tl.int32 - ) + local_expt = tl.load( + expert_map_ptr + safe_global_expt, mask=map_mask, other=-1 + ).to(tl.int32) invalid = local_expt < 0 expt = tl.where(invalid, 0, local_expt) weights = tl.where(invalid, 0.0, weights) diff --git a/op_tests/triton_tests/fusions/test_fused_routing_from_topk.py b/op_tests/triton_tests/fusions/test_fused_routing_from_topk.py index 05dc705b49..adc36b14cc 100644 --- a/op_tests/triton_tests/fusions/test_fused_routing_from_topk.py +++ b/op_tests/triton_tests/fusions/test_fused_routing_from_topk.py @@ -308,7 +308,9 @@ def test_fused_routing_from_topk_with_expert_map( bucket_unsorted_layout=False, ) - assert torch.equal(ref_hist, test_hist), f"hist mismatch:\n ref={ref_hist}\n fused={test_hist}" + assert torch.equal( + ref_hist, test_hist + ), f"hist mismatch:\n ref={ref_hist}\n fused={test_hist}" # Intra-expert ordering can differ between fused and reference, # especially in expert-0 bucket where invalid experts are redirected. From c651c8a658d4bb934d1850c31088016f9c74b0f4 Mon Sep 17 00:00:00 2001 From: amd-ruitang3 Date: Tue, 26 May 2026 11:16:23 -0500 Subject: [PATCH 3/3] update_with_shao-chun_comments --- .../triton/fusions/fused_routing_from_topk.py | 22 ++- .../fusions/test_fused_routing_from_topk.py | 137 +++++++----------- 2 files changed, 65 insertions(+), 94 deletions(-) diff --git a/aiter/ops/triton/fusions/fused_routing_from_topk.py b/aiter/ops/triton/fusions/fused_routing_from_topk.py index 4ff5aabb8f..e144f82f8a 100644 --- a/aiter/ops/triton/fusions/fused_routing_from_topk.py +++ b/aiter/ops/triton/fusions/fused_routing_from_topk.py @@ -42,13 +42,15 @@ def fused_routing_from_topk( Args: topk_weights: ``[n_tokens, n_expts_act]`` per-token routing weights. + Must be contiguous. topk_ids: ``[n_tokens, n_expts_act]`` selected expert ids; values - in ``[0, n_expts_tot)``. + in ``[0, n_expts_tot)``. Must be contiguous int32. n_expts_tot: Total number of routed experts (= ``E``). expert_map: Optional global→local expert map. When provided, ``topk_ids`` are treated as global ids and remapped inside fused kernels. Entries mapped to ``< 0`` are masked to zero weight and - redirected to local expert ``0`` for routing safety. + redirected to local expert ``0`` for routing safety. Must be + contiguous int32 when provided. Returns: Tuple ``(hist, topk_indx, gate_indx, gate_scal)``: @@ -94,16 +96,20 @@ def fused_routing_from_topk( device = topk_weights.device weights_dtype = topk_weights.dtype - # Triton kernel needs flat int32 inputs. .reshape on a contiguous tensor - # is a view; .contiguous() / .to(int32) on already-canonical tensors - # are no-ops. - topk_ids_flat = topk_ids.contiguous().reshape(-1).to(torch.int32) - topk_weights_flat = topk_weights.contiguous().reshape(-1) + assert ( + topk_ids.is_contiguous() and topk_ids.dtype == torch.int32 + ), "topk_ids must be contiguous int32" + assert topk_weights.is_contiguous(), "topk_weights must be contiguous" + topk_ids_flat = topk_ids.reshape(-1) + topk_weights_flat = topk_weights.reshape(-1) expert_map_numel = 0 expert_map_flat = topk_ids_flat has_expert_map = expert_map is not None if has_expert_map: - expert_map_flat = expert_map.contiguous().reshape(-1).to(torch.int32) + assert ( + expert_map.is_contiguous() and expert_map.dtype == torch.int32 + ), "expert_map must be contiguous int32" + expert_map_flat = expert_map.reshape(-1) expert_map_numel = int(expert_map_flat.numel()) topk_indx = torch.empty(n_gates_pad, dtype=torch.int32, device=device) diff --git a/op_tests/triton_tests/fusions/test_fused_routing_from_topk.py b/op_tests/triton_tests/fusions/test_fused_routing_from_topk.py index adc36b14cc..098b889a43 100644 --- a/op_tests/triton_tests/fusions/test_fused_routing_from_topk.py +++ b/op_tests/triton_tests/fusions/test_fused_routing_from_topk.py @@ -189,98 +189,48 @@ def _compare_buckets(ref_buckets, test_buckets, atol=1e-6): # tests # --------------------------------------------------------------------------- @pytest.mark.parametrize( - "n_tokens, n_expts_act, n_expts_tot", + "n_tokens, n_expts_act, n_expts_tot, n_expts_global", [ - # V4-Flash decode shapes (E=256, K=6). - (1, 6, 256), - (16, 6, 256), - (64, 6, 256), - (256, 6, 256), + # V4-Flash decode shapes (E=256, K=6). n_expts_global ignored when + # has_expert_map=False. + (1, 6, 256, 256), + (16, 6, 256, 256), + (64, 6, 256, 256), + (256, 6, 256, 256), # Generic decode shapes used by other MoE configs. - (1, 8, 384), - (4, 8, 384), - (64, 8, 384), - (256, 8, 384), + (1, 8, 384, 384), + (4, 8, 384, 384), + (64, 8, 384, 384), + (256, 8, 384, 384), # Edge: small E. - (32, 4, 16), + (32, 4, 16, 16), # Boundary: NK at the kernel's MAX_NK = 4096. - (512, 8, 384), - ], -) -@pytest.mark.parametrize("dtype", [torch.float32]) -def test_fused_routing_from_topk(n_tokens, n_expts_act, n_expts_tot, dtype): - if not torch.cuda.is_available(): - pytest.skip("CUDA not available") - torch.manual_seed(0) - topk_ids, topk_weights = _make_inputs( - n_tokens, n_expts_act, n_expts_tot, dtype, DEVICE, seed=0 - ) - - ref_hist, ref_topk_indx, ref_gate_indx, ref_gate_scal = routing_from_topk_reference( - topk_weights, topk_ids, n_expts_tot - ) - _check_routing_invariants( - ref_hist, - ref_topk_indx, - ref_gate_indx, - ref_gate_scal, - topk_ids, - n_expts_tot, - bucket_unsorted_layout=False, # ref uses per-row-sorted layout - ) - ground_buckets = _ground_truth_buckets(topk_ids, topk_weights) - ref_buckets = _per_expert_triples( - ref_hist, ref_topk_indx, ref_gate_scal, n_expts_act - ) - _compare_buckets(ground_buckets, ref_buckets) - - test_hist, test_topk_indx, test_gate_indx, test_gate_scal = fused_routing_from_topk( - topk_weights, topk_ids, n_expts_tot - ) - _check_routing_invariants( - test_hist, - test_topk_indx, - test_gate_indx, - test_gate_scal, - topk_ids, - n_expts_tot, - bucket_unsorted_layout=True, # fused uses unsorted topk_ids layout - ) - - # hist must match the reference exactly. - assert torch.equal( - ref_hist, test_hist - ), f"hist mismatch:\n ref={ref_hist}\n fused={test_hist}" - - # Per-expert (token, weight) multisets match the reference. - test_buckets = _per_expert_triples( - test_hist, test_topk_indx, test_gate_scal, n_expts_act - ) - _compare_buckets(ref_buckets, test_buckets) - - -@pytest.mark.parametrize( - "n_tokens, n_expts_act, n_expts_tot,n_expts_global", - [ + (512, 8, 384, 384), + # Expert-parallel shapes: n_expts_global > n_expts_tot, requires map. (16, 6, 64, 256), (64, 6, 128, 256), ], ) +@pytest.mark.parametrize("has_expert_map", [False, True]) @pytest.mark.parametrize("dtype", [torch.float32]) -def test_fused_routing_from_topk_with_expert_map( - n_tokens, n_expts_act, n_expts_tot, n_expts_global, dtype +def test_fused_routing_from_topk( + n_tokens, n_expts_act, n_expts_tot, n_expts_global, has_expert_map, dtype ): if not torch.cuda.is_available(): pytest.skip("CUDA not available") torch.manual_seed(0) + + id_range = n_expts_global if has_expert_map else n_expts_tot topk_ids, topk_weights = _make_inputs( - n_tokens, n_expts_act, n_expts_global, dtype, DEVICE, seed=0 + n_tokens, n_expts_act, id_range, dtype, DEVICE, seed=0 ) - expert_map = torch.full((n_expts_global,), -1, dtype=torch.int32, device=DEVICE) - expert_map[: n_expts_tot // 2] = torch.arange( - n_expts_tot // 2, dtype=torch.int32, device=DEVICE - ) + expert_map = None + if has_expert_map: + expert_map = torch.full((n_expts_global,), -1, dtype=torch.int32, device=DEVICE) + expert_map[: n_expts_tot // 2] = torch.arange( + n_expts_tot // 2, dtype=torch.int32, device=DEVICE + ) ref_hist, ref_topk_indx, ref_gate_indx, ref_gate_scal = routing_from_topk_reference( topk_weights, topk_ids, n_expts_tot, expert_map=expert_map @@ -292,7 +242,7 @@ def test_fused_routing_from_topk_with_expert_map( ref_gate_scal, topk_ids, n_expts_tot, - bucket_unsorted_layout=False, + bucket_unsorted_layout=False, # ref uses per-row-sorted layout ) test_hist, test_topk_indx, test_gate_indx, test_gate_scal = fused_routing_from_topk( @@ -305,18 +255,33 @@ def test_fused_routing_from_topk_with_expert_map( test_gate_scal, topk_ids, n_expts_tot, - bucket_unsorted_layout=False, + bucket_unsorted_layout=not has_expert_map, ) + # hist must match the reference exactly. assert torch.equal( ref_hist, test_hist ), f"hist mismatch:\n ref={ref_hist}\n fused={test_hist}" - # Intra-expert ordering can differ between fused and reference, - # especially in expert-0 bucket where invalid experts are redirected. - # Compare zeroed-weight cardinality instead of elementwise positions. - ref_zero_count = int((ref_gate_scal == 0).sum().item()) - test_zero_count = int((test_gate_scal == 0).sum().item()) - assert ( - ref_zero_count == test_zero_count - ), f"zero-masked count mismatch: ref={ref_zero_count}, fused={test_zero_count}" + if has_expert_map: + # Intra-expert ordering can differ between fused and reference, + # especially in expert-0 bucket where invalid experts are redirected. + # Compare zeroed-weight cardinality instead of elementwise positions. + ref_zero_count = int((ref_gate_scal == 0).sum().item()) + test_zero_count = int((test_gate_scal == 0).sum().item()) + assert ref_zero_count == test_zero_count, ( + f"zero-masked count mismatch: " + f"ref={ref_zero_count}, fused={test_zero_count}" + ) + else: + ground_buckets = _ground_truth_buckets(topk_ids, topk_weights) + ref_buckets = _per_expert_triples( + ref_hist, ref_topk_indx, ref_gate_scal, n_expts_act + ) + _compare_buckets(ground_buckets, ref_buckets) + + # Per-expert (token, weight) multisets match the reference. + test_buckets = _per_expert_triples( + test_hist, test_topk_indx, test_gate_scal, n_expts_act + ) + _compare_buckets(ref_buckets, test_buckets)