diff --git a/aiter/ops/triton/_triton_kernels/fusions/fused_routing_from_topk.py b/aiter/ops/triton/_triton_kernels/fusions/fused_routing_from_topk.py index 0dcd54c6a5..0c174005d8 100644 --- a/aiter/ops/triton/_triton_kernels/fusions/fused_routing_from_topk.py +++ b/aiter/ops/triton/_triton_kernels/fusions/fused_routing_from_topk.py @@ -15,6 +15,7 @@ "_fused_routing_from_topk_hist_kernel", [ "E", + "HAS_EXPERT_MAP", "BLOCK_NK", "BLOCK_E", ], @@ -31,6 +32,7 @@ _fused_routing_from_topk_place_kernel_repr = make_kernel_repr( "_fused_routing_from_topk_place_kernel", [ + "HAS_EXPERT_MAP", "BLOCK_NK", ], ) @@ -40,11 +42,14 @@ def _fused_routing_from_topk_hist_kernel( # inputs topk_ids_ptr, # [NK] int32 — flattened topk_ids + expert_map_ptr, # [N_EXPERTS_GLOBAL] int32 or identity map fallback + expert_map_numel, # runtime int — bounds for expert_map_ptr # outputs hist_ptr, # [E] int32 — tokens-per-expert histogram # shapes NK, # runtime int — actual valid item count (≤ BLOCK_NK) E: tl.constexpr, + HAS_EXPERT_MAP: tl.constexpr, BLOCK_NK: tl.constexpr, # padded to next pow2 of NK BLOCK_E: tl.constexpr, # padded to next pow2 of E (tl.histogram needs pow2) ): @@ -62,7 +67,20 @@ def _fused_routing_from_topk_hist_kernel( # Clamp the offset for masked-out lanes to 0 so the pointer arithmetic # below stays within the allocated buffers. safe_item = tl.where(item_mask, item_offs, 0) - expt = tl.load(topk_ids_ptr + safe_item, mask=item_mask, other=0).to(tl.int32) + global_expt = tl.load(topk_ids_ptr + safe_item, mask=item_mask, other=0).to( + tl.int32 + ) + if HAS_EXPERT_MAP: + map_mask = item_mask & (global_expt >= 0) & (global_expt < expert_map_numel) + safe_global_expt = tl.where(map_mask, global_expt, 0) + local_expt = tl.load( + expert_map_ptr + safe_global_expt, mask=map_mask, other=-1 + ).to(tl.int32) + # Match reference semantics: invalid experts are redirected to bucket 0 + # and later zeroed in gate_scal. + expt = tl.where(local_expt >= 0, local_expt, 0) + else: + expt = global_expt hist = tl.histogram(expt, BLOCK_E, mask=item_mask) @@ -99,6 +117,8 @@ def _fused_routing_from_topk_place_kernel( # inputs topk_ids_ptr, # [NK] int32 — flattened topk_ids topk_weights_ptr, # [NK] (any float dtype) — flattened topk_weights + expert_map_ptr, # [N_EXPERTS_GLOBAL] int32 or identity map fallback + expert_map_numel, # runtime int — bounds for expert_map_ptr offset_ptr, # [E] int32 — exclusive prefix sums from the offset kernel # outputs topk_indx_ptr, # [NK] int32 — output gather_indx.src_indx @@ -106,6 +126,7 @@ def _fused_routing_from_topk_place_kernel( gate_scal_ptr, # [NK] same dtype as topk_weights # shapes NK, # runtime int — actual valid item count (≤ BLOCK_NK) + HAS_EXPERT_MAP: tl.constexpr, BLOCK_NK: tl.constexpr, # padded to next pow2 of NK ): """Phase C: place items. @@ -122,8 +143,21 @@ def _fused_routing_from_topk_place_kernel( item_offs = tl.arange(0, BLOCK_NK) item_mask = item_offs < NK safe_item = tl.where(item_mask, item_offs, 0) - expt = tl.load(topk_ids_ptr + safe_item, mask=item_mask, other=0).to(tl.int32) + global_expt = tl.load(topk_ids_ptr + safe_item, mask=item_mask, other=0).to( + tl.int32 + ) weights = tl.load(topk_weights_ptr + safe_item, mask=item_mask, other=0.0) + if HAS_EXPERT_MAP: + map_mask = item_mask & (global_expt >= 0) & (global_expt < expert_map_numel) + safe_global_expt = tl.where(map_mask, global_expt, 0) + local_expt = tl.load( + expert_map_ptr + safe_global_expt, mask=map_mask, other=-1 + ).to(tl.int32) + invalid = local_expt < 0 + expt = tl.where(invalid, 0, local_expt) + weights = tl.where(invalid, 0.0, weights) + else: + expt = global_expt pos = tl.atomic_add(offset_ptr + expt, 1, mask=item_mask) diff --git a/aiter/ops/triton/fusions/fused_routing_from_topk.py b/aiter/ops/triton/fusions/fused_routing_from_topk.py index 83c746c693..e144f82f8a 100644 --- a/aiter/ops/triton/fusions/fused_routing_from_topk.py +++ b/aiter/ops/triton/fusions/fused_routing_from_topk.py @@ -4,7 +4,7 @@ # Fused replacement for the multi-kernel "topk → routing data" chain that # bridges FusedMoE.select_experts to triton_kernels.matmul_ogs. See the # accompanying _triton_kernels/fused_routing_from_topk.py for the kernel. -from typing import Tuple +from typing import Optional, Tuple import torch import triton @@ -30,6 +30,7 @@ def fused_routing_from_topk( topk_weights: torch.Tensor, topk_ids: torch.Tensor, n_expts_tot: int, + expert_map: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """Sort (token, slot) pairs by their expert id via a single Triton kernel. @@ -41,9 +42,15 @@ def fused_routing_from_topk( Args: topk_weights: ``[n_tokens, n_expts_act]`` per-token routing weights. + Must be contiguous. topk_ids: ``[n_tokens, n_expts_act]`` selected expert ids; values - in ``[0, n_expts_tot)``. + in ``[0, n_expts_tot)``. Must be contiguous int32. n_expts_tot: Total number of routed experts (= ``E``). + expert_map: Optional global→local expert map. When provided, + ``topk_ids`` are treated as global ids and remapped inside fused + kernels. Entries mapped to ``< 0`` are masked to zero weight and + redirected to local expert ``0`` for routing safety. Must be + contiguous int32 when provided. Returns: Tuple ``(hist, topk_indx, gate_indx, gate_scal)``: @@ -89,11 +96,21 @@ def fused_routing_from_topk( device = topk_weights.device weights_dtype = topk_weights.dtype - # Triton kernel needs flat int32 inputs. .reshape on a contiguous tensor - # is a view; .contiguous() / .to(int32) on already-canonical tensors - # are no-ops. - topk_ids_flat = topk_ids.contiguous().reshape(-1).to(torch.int32) - topk_weights_flat = topk_weights.contiguous().reshape(-1) + assert ( + topk_ids.is_contiguous() and topk_ids.dtype == torch.int32 + ), "topk_ids must be contiguous int32" + assert topk_weights.is_contiguous(), "topk_weights must be contiguous" + topk_ids_flat = topk_ids.reshape(-1) + topk_weights_flat = topk_weights.reshape(-1) + expert_map_numel = 0 + expert_map_flat = topk_ids_flat + has_expert_map = expert_map is not None + if has_expert_map: + assert ( + expert_map.is_contiguous() and expert_map.dtype == torch.int32 + ), "expert_map must be contiguous int32" + expert_map_flat = expert_map.reshape(-1) + expert_map_numel = int(expert_map_flat.numel()) topk_indx = torch.empty(n_gates_pad, dtype=torch.int32, device=device) gate_indx = torch.empty(n_gates_pad, dtype=torch.int32, device=device) @@ -109,9 +126,12 @@ def fused_routing_from_topk( # single wave, matching the CTA-local design of the original kernel. _fused_routing_from_topk_hist_kernel[(1,)]( topk_ids_flat, + expert_map_flat, + expert_map_numel, hist, n_gates_pad, E=n_expts_tot, + HAS_EXPERT_MAP=has_expert_map, BLOCK_NK=BLOCK_NK, BLOCK_E=BLOCK_E, num_warps=1, @@ -132,11 +152,14 @@ def fused_routing_from_topk( _fused_routing_from_topk_place_kernel[(1,)]( topk_ids_flat, topk_weights_flat, + expert_map_flat, + expert_map_numel, offset_scratch, topk_indx, gate_indx, gate_scal, n_gates_pad, + HAS_EXPERT_MAP=has_expert_map, BLOCK_NK=BLOCK_NK, num_warps=1, ) diff --git a/op_tests/triton_tests/fusions/test_fused_routing_from_topk.py b/op_tests/triton_tests/fusions/test_fused_routing_from_topk.py index ea2c1552ef..098b889a43 100644 --- a/op_tests/triton_tests/fusions/test_fused_routing_from_topk.py +++ b/op_tests/triton_tests/fusions/test_fused_routing_from_topk.py @@ -28,7 +28,7 @@ # Returns ``(hist, topk_indx, gate_indx, gate_scal)`` for direct comparison # against the fused kernel. # --------------------------------------------------------------------------- -def routing_from_topk_reference(topk_weights, topk_ids, n_expts_tot): +def routing_from_topk_reference(topk_weights, topk_ids, n_expts_tot, expert_map=None): """Multi-kernel torch reference for fused_routing_from_topk. Per-row sort of ``topk_ids`` followed by a stable global argsort by @@ -37,6 +37,12 @@ def routing_from_topk_reference(topk_weights, topk_ids, n_expts_tot): version), unlike the fused kernel which is non-deterministic at intra-expert ordering. """ + if expert_map is not None: + local_ids = expert_map[topk_ids.long()] + invalid = local_ids < 0 + topk_weights = topk_weights.masked_fill(invalid, 0.0) + topk_ids = local_ids.masked_fill(invalid, 0).to(torch.int32) + expt_indx_sorted, sort_indices = torch.sort(topk_ids.int(), dim=1) expt_scal_sorted = torch.gather(topk_weights, 1, sort_indices.long()) @@ -183,35 +189,51 @@ def _compare_buckets(ref_buckets, test_buckets, atol=1e-6): # tests # --------------------------------------------------------------------------- @pytest.mark.parametrize( - "n_tokens, n_expts_act, n_expts_tot", + "n_tokens, n_expts_act, n_expts_tot, n_expts_global", [ - # V4-Flash decode shapes (E=256, K=6). - (1, 6, 256), - (16, 6, 256), - (64, 6, 256), - (256, 6, 256), + # V4-Flash decode shapes (E=256, K=6). n_expts_global ignored when + # has_expert_map=False. + (1, 6, 256, 256), + (16, 6, 256, 256), + (64, 6, 256, 256), + (256, 6, 256, 256), # Generic decode shapes used by other MoE configs. - (1, 8, 384), - (4, 8, 384), - (64, 8, 384), - (256, 8, 384), + (1, 8, 384, 384), + (4, 8, 384, 384), + (64, 8, 384, 384), + (256, 8, 384, 384), # Edge: small E. - (32, 4, 16), + (32, 4, 16, 16), # Boundary: NK at the kernel's MAX_NK = 4096. - (512, 8, 384), + (512, 8, 384, 384), + # Expert-parallel shapes: n_expts_global > n_expts_tot, requires map. + (16, 6, 64, 256), + (64, 6, 128, 256), ], ) +@pytest.mark.parametrize("has_expert_map", [False, True]) @pytest.mark.parametrize("dtype", [torch.float32]) -def test_fused_routing_from_topk(n_tokens, n_expts_act, n_expts_tot, dtype): +def test_fused_routing_from_topk( + n_tokens, n_expts_act, n_expts_tot, n_expts_global, has_expert_map, dtype +): if not torch.cuda.is_available(): pytest.skip("CUDA not available") torch.manual_seed(0) + + id_range = n_expts_global if has_expert_map else n_expts_tot topk_ids, topk_weights = _make_inputs( - n_tokens, n_expts_act, n_expts_tot, dtype, DEVICE, seed=0 + n_tokens, n_expts_act, id_range, dtype, DEVICE, seed=0 ) + expert_map = None + if has_expert_map: + expert_map = torch.full((n_expts_global,), -1, dtype=torch.int32, device=DEVICE) + expert_map[: n_expts_tot // 2] = torch.arange( + n_expts_tot // 2, dtype=torch.int32, device=DEVICE + ) + ref_hist, ref_topk_indx, ref_gate_indx, ref_gate_scal = routing_from_topk_reference( - topk_weights, topk_ids, n_expts_tot + topk_weights, topk_ids, n_expts_tot, expert_map=expert_map ) _check_routing_invariants( ref_hist, @@ -222,14 +244,9 @@ def test_fused_routing_from_topk(n_tokens, n_expts_act, n_expts_tot, dtype): n_expts_tot, bucket_unsorted_layout=False, # ref uses per-row-sorted layout ) - ground_buckets = _ground_truth_buckets(topk_ids, topk_weights) - ref_buckets = _per_expert_triples( - ref_hist, ref_topk_indx, ref_gate_scal, n_expts_act - ) - _compare_buckets(ground_buckets, ref_buckets) test_hist, test_topk_indx, test_gate_indx, test_gate_scal = fused_routing_from_topk( - topk_weights, topk_ids, n_expts_tot + topk_weights, topk_ids, n_expts_tot, expert_map=expert_map ) _check_routing_invariants( test_hist, @@ -238,7 +255,7 @@ def test_fused_routing_from_topk(n_tokens, n_expts_act, n_expts_tot, dtype): test_gate_scal, topk_ids, n_expts_tot, - bucket_unsorted_layout=True, # fused uses unsorted topk_ids layout + bucket_unsorted_layout=not has_expert_map, ) # hist must match the reference exactly. @@ -246,8 +263,25 @@ def test_fused_routing_from_topk(n_tokens, n_expts_act, n_expts_tot, dtype): ref_hist, test_hist ), f"hist mismatch:\n ref={ref_hist}\n fused={test_hist}" - # Per-expert (token, weight) multisets match the reference. - test_buckets = _per_expert_triples( - test_hist, test_topk_indx, test_gate_scal, n_expts_act - ) - _compare_buckets(ref_buckets, test_buckets) + if has_expert_map: + # Intra-expert ordering can differ between fused and reference, + # especially in expert-0 bucket where invalid experts are redirected. + # Compare zeroed-weight cardinality instead of elementwise positions. + ref_zero_count = int((ref_gate_scal == 0).sum().item()) + test_zero_count = int((test_gate_scal == 0).sum().item()) + assert ref_zero_count == test_zero_count, ( + f"zero-masked count mismatch: " + f"ref={ref_zero_count}, fused={test_zero_count}" + ) + else: + ground_buckets = _ground_truth_buckets(topk_ids, topk_weights) + ref_buckets = _per_expert_triples( + ref_hist, ref_topk_indx, ref_gate_scal, n_expts_act + ) + _compare_buckets(ground_buckets, ref_buckets) + + # Per-expert (token, weight) multisets match the reference. + test_buckets = _per_expert_triples( + test_hist, test_topk_indx, test_gate_scal, n_expts_act + ) + _compare_buckets(ref_buckets, test_buckets)