Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"_fused_routing_from_topk_hist_kernel",
[
"E",
"HAS_EXPERT_MAP",
"BLOCK_NK",
"BLOCK_E",
],
Expand All @@ -31,6 +32,7 @@
_fused_routing_from_topk_place_kernel_repr = make_kernel_repr(
"_fused_routing_from_topk_place_kernel",
[
"HAS_EXPERT_MAP",
"BLOCK_NK",
],
)
Expand All @@ -40,11 +42,14 @@
def _fused_routing_from_topk_hist_kernel(
# inputs
topk_ids_ptr, # [NK] int32 — flattened topk_ids
expert_map_ptr, # [N_EXPERTS_GLOBAL] int32 global→local map; placeholder pointer when HAS_EXPERT_MAP is False
expert_map_numel, # runtime int — bounds for expert_map_ptr
# outputs
hist_ptr, # [E] int32 — tokens-per-expert histogram
# shapes
NK, # runtime int — actual valid item count (≤ BLOCK_NK)
E: tl.constexpr,
HAS_EXPERT_MAP: tl.constexpr,
BLOCK_NK: tl.constexpr, # padded to next pow2 of NK
BLOCK_E: tl.constexpr, # padded to next pow2 of E (tl.histogram needs pow2)
):
Expand All @@ -62,7 +67,20 @@ def _fused_routing_from_topk_hist_kernel(
# Clamp the offset for masked-out lanes to 0 so the pointer arithmetic
# below stays within the allocated buffers.
safe_item = tl.where(item_mask, item_offs, 0)
expt = tl.load(topk_ids_ptr + safe_item, mask=item_mask, other=0).to(tl.int32)
global_expt = tl.load(topk_ids_ptr + safe_item, mask=item_mask, other=0).to(
tl.int32
)
if HAS_EXPERT_MAP:
map_mask = item_mask & (global_expt >= 0) & (global_expt < expert_map_numel)
safe_global_expt = tl.where(map_mask, global_expt, 0)
local_expt = tl.load(
expert_map_ptr + safe_global_expt, mask=map_mask, other=-1
).to(tl.int32)
# Match reference semantics: invalid experts are redirected to bucket 0
# and later zeroed in gate_scal.
expt = tl.where(local_expt >= 0, local_expt, 0)
else:
expt = global_expt

hist = tl.histogram(expt, BLOCK_E, mask=item_mask)

Expand Down Expand Up @@ -99,13 +117,16 @@ def _fused_routing_from_topk_place_kernel(
# inputs
topk_ids_ptr, # [NK] int32 — flattened topk_ids
topk_weights_ptr, # [NK] (any float dtype) — flattened topk_weights
expert_map_ptr, # [N_EXPERTS_GLOBAL] int32 global→local map; placeholder pointer when HAS_EXPERT_MAP is False
expert_map_numel, # runtime int — bounds for expert_map_ptr
offset_ptr, # [E] int32 — exclusive prefix sums from the offset kernel
# outputs
topk_indx_ptr, # [NK] int32 — output gather_indx.src_indx
gate_indx_ptr, # [NK] int32 — output gather_indx.dst_indx
gate_scal_ptr, # [NK] same dtype as topk_weights
# shapes
NK, # runtime int — actual valid item count (≤ BLOCK_NK)
HAS_EXPERT_MAP: tl.constexpr,
BLOCK_NK: tl.constexpr, # padded to next pow2 of NK
):
"""Phase C: place items.
Expand All @@ -122,8 +143,21 @@ def _fused_routing_from_topk_place_kernel(
item_offs = tl.arange(0, BLOCK_NK)
item_mask = item_offs < NK
safe_item = tl.where(item_mask, item_offs, 0)
expt = tl.load(topk_ids_ptr + safe_item, mask=item_mask, other=0).to(tl.int32)
global_expt = tl.load(topk_ids_ptr + safe_item, mask=item_mask, other=0).to(
tl.int32
)
weights = tl.load(topk_weights_ptr + safe_item, mask=item_mask, other=0.0)
if HAS_EXPERT_MAP:
map_mask = item_mask & (global_expt >= 0) & (global_expt < expert_map_numel)
safe_global_expt = tl.where(map_mask, global_expt, 0)
local_expt = tl.load(
expert_map_ptr + safe_global_expt, mask=map_mask, other=-1
).to(tl.int32)
invalid = local_expt < 0
expt = tl.where(invalid, 0, local_expt)
weights = tl.where(invalid, 0.0, weights)
else:
expt = global_expt

pos = tl.atomic_add(offset_ptr + expt, 1, mask=item_mask)

Expand Down
37 changes: 30 additions & 7 deletions aiter/ops/triton/fusions/fused_routing_from_topk.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# Fused replacement for the multi-kernel "topk → routing data" chain that
# bridges FusedMoE.select_experts to triton_kernels.matmul_ogs. See the
# accompanying _triton_kernels/fused_routing_from_topk.py for the kernel.
from typing import Tuple
from typing import Optional, Tuple

import torch
import triton
Expand All @@ -30,6 +30,7 @@ def fused_routing_from_topk(
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
n_expts_tot: int,
expert_map: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Sort (token, slot) pairs by their expert id via a single Triton kernel.

Expand All @@ -41,9 +42,15 @@ def fused_routing_from_topk(

Args:
topk_weights: ``[n_tokens, n_expts_act]`` per-token routing weights.
Must be contiguous.
topk_ids: ``[n_tokens, n_expts_act]`` selected expert ids; values
in ``[0, n_expts_tot)``.
in ``[0, n_expts_tot)``, or global expert ids when ``expert_map`` is given. Must be contiguous int32.
n_expts_tot: Total number of routed experts (= ``E``).
expert_map: Optional global→local expert map. When provided,
``topk_ids`` are treated as global ids and remapped inside fused
kernels. Entries mapped to ``< 0`` are masked to zero weight and
redirected to local expert ``0`` for routing safety. Must be
contiguous int32 when provided.

Returns:
Tuple ``(hist, topk_indx, gate_indx, gate_scal)``:
Expand Down Expand Up @@ -89,11 +96,21 @@ def fused_routing_from_topk(
device = topk_weights.device
weights_dtype = topk_weights.dtype

# Triton kernel needs flat int32 inputs. .reshape on a contiguous tensor
# is a view; .contiguous() / .to(int32) on already-canonical tensors
# are no-ops.
topk_ids_flat = topk_ids.contiguous().reshape(-1).to(torch.int32)
topk_weights_flat = topk_weights.contiguous().reshape(-1)
assert (
topk_ids.is_contiguous() and topk_ids.dtype == torch.int32
), "topk_ids must be contiguous int32"
assert topk_weights.is_contiguous(), "topk_weights must be contiguous"
topk_ids_flat = topk_ids.reshape(-1)
topk_weights_flat = topk_weights.reshape(-1)
expert_map_numel = 0
expert_map_flat = topk_ids_flat
has_expert_map = expert_map is not None
if has_expert_map:
assert (
expert_map.is_contiguous() and expert_map.dtype == torch.int32
), "expert_map must be contiguous int32"
expert_map_flat = expert_map.reshape(-1)
expert_map_numel = int(expert_map_flat.numel())

topk_indx = torch.empty(n_gates_pad, dtype=torch.int32, device=device)
gate_indx = torch.empty(n_gates_pad, dtype=torch.int32, device=device)
Expand All @@ -109,9 +126,12 @@ def fused_routing_from_topk(
# single wave, matching the CTA-local design of the original kernel.
_fused_routing_from_topk_hist_kernel[(1,)](
topk_ids_flat,
expert_map_flat,
expert_map_numel,
hist,
n_gates_pad,
E=n_expts_tot,
HAS_EXPERT_MAP=has_expert_map,
BLOCK_NK=BLOCK_NK,
BLOCK_E=BLOCK_E,
num_warps=1,
Expand All @@ -132,11 +152,14 @@ def fused_routing_from_topk(
_fused_routing_from_topk_place_kernel[(1,)](
topk_ids_flat,
topk_weights_flat,
expert_map_flat,
expert_map_numel,
offset_scratch,
topk_indx,
gate_indx,
gate_scal,
n_gates_pad,
HAS_EXPERT_MAP=has_expert_map,
BLOCK_NK=BLOCK_NK,
num_warps=1,
)
Expand Down
90 changes: 62 additions & 28 deletions op_tests/triton_tests/fusions/test_fused_routing_from_topk.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
# Returns ``(hist, topk_indx, gate_indx, gate_scal)`` for direct comparison
# against the fused kernel.
# ---------------------------------------------------------------------------
def routing_from_topk_reference(topk_weights, topk_ids, n_expts_tot):
def routing_from_topk_reference(topk_weights, topk_ids, n_expts_tot, expert_map=None):
"""Multi-kernel torch reference for fused_routing_from_topk.

Per-row sort of ``topk_ids`` followed by a stable global argsort by
Expand All @@ -37,6 +37,12 @@ def routing_from_topk_reference(topk_weights, topk_ids, n_expts_tot):
version), unlike the fused kernel which is non-deterministic at
intra-expert ordering.
"""
if expert_map is not None:
local_ids = expert_map[topk_ids.long()]
invalid = local_ids < 0
topk_weights = topk_weights.masked_fill(invalid, 0.0)
topk_ids = local_ids.masked_fill(invalid, 0).to(torch.int32)

expt_indx_sorted, sort_indices = torch.sort(topk_ids.int(), dim=1)
expt_scal_sorted = torch.gather(topk_weights, 1, sort_indices.long())

Expand Down Expand Up @@ -183,35 +189,51 @@ def _compare_buckets(ref_buckets, test_buckets, atol=1e-6):
# tests
# ---------------------------------------------------------------------------
@pytest.mark.parametrize(
"n_tokens, n_expts_act, n_expts_tot",
"n_tokens, n_expts_act, n_expts_tot, n_expts_global",
[
# V4-Flash decode shapes (E=256, K=6).
(1, 6, 256),
(16, 6, 256),
(64, 6, 256),
(256, 6, 256),
# V4-Flash decode shapes (E=256, K=6). n_expts_global ignored when
# has_expert_map=False.
(1, 6, 256, 256),
(16, 6, 256, 256),
(64, 6, 256, 256),
(256, 6, 256, 256),
# Generic decode shapes used by other MoE configs.
(1, 8, 384),
(4, 8, 384),
(64, 8, 384),
(256, 8, 384),
(1, 8, 384, 384),
(4, 8, 384, 384),
(64, 8, 384, 384),
(256, 8, 384, 384),
# Edge: small E.
(32, 4, 16),
(32, 4, 16, 16),
# Boundary: NK at the kernel's MAX_NK = 4096.
(512, 8, 384),
(512, 8, 384, 384),
# Expert-parallel shapes: n_expts_global > n_expts_tot, requires map.
(16, 6, 64, 256),
(64, 6, 128, 256),
],
)
@pytest.mark.parametrize("has_expert_map", [False, True])
@pytest.mark.parametrize("dtype", [torch.float32])
def test_fused_routing_from_topk(n_tokens, n_expts_act, n_expts_tot, dtype):
def test_fused_routing_from_topk(
n_tokens, n_expts_act, n_expts_tot, n_expts_global, has_expert_map, dtype
):
if not torch.cuda.is_available():
pytest.skip("CUDA not available")
torch.manual_seed(0)

id_range = n_expts_global if has_expert_map else n_expts_tot
topk_ids, topk_weights = _make_inputs(
n_tokens, n_expts_act, n_expts_tot, dtype, DEVICE, seed=0
n_tokens, n_expts_act, id_range, dtype, DEVICE, seed=0
)

expert_map = None
if has_expert_map:
expert_map = torch.full((n_expts_global,), -1, dtype=torch.int32, device=DEVICE)
expert_map[: n_expts_tot // 2] = torch.arange(
n_expts_tot // 2, dtype=torch.int32, device=DEVICE
)

ref_hist, ref_topk_indx, ref_gate_indx, ref_gate_scal = routing_from_topk_reference(
topk_weights, topk_ids, n_expts_tot
topk_weights, topk_ids, n_expts_tot, expert_map=expert_map
)
_check_routing_invariants(
ref_hist,
Expand All @@ -222,14 +244,9 @@ def test_fused_routing_from_topk(n_tokens, n_expts_act, n_expts_tot, dtype):
n_expts_tot,
bucket_unsorted_layout=False, # ref uses per-row-sorted layout
)
ground_buckets = _ground_truth_buckets(topk_ids, topk_weights)
ref_buckets = _per_expert_triples(
ref_hist, ref_topk_indx, ref_gate_scal, n_expts_act
)
_compare_buckets(ground_buckets, ref_buckets)

test_hist, test_topk_indx, test_gate_indx, test_gate_scal = fused_routing_from_topk(
topk_weights, topk_ids, n_expts_tot
topk_weights, topk_ids, n_expts_tot, expert_map=expert_map
)
_check_routing_invariants(
test_hist,
Expand All @@ -238,16 +255,33 @@ def test_fused_routing_from_topk(n_tokens, n_expts_act, n_expts_tot, dtype):
test_gate_scal,
topk_ids,
n_expts_tot,
bucket_unsorted_layout=True, # fused uses unsorted topk_ids layout
bucket_unsorted_layout=not has_expert_map,
)
Comment on lines +248 to +259

# hist must match the reference exactly.
assert torch.equal(
ref_hist, test_hist
), f"hist mismatch:\n ref={ref_hist}\n fused={test_hist}"

# Per-expert (token, weight) multisets match the reference.
test_buckets = _per_expert_triples(
test_hist, test_topk_indx, test_gate_scal, n_expts_act
)
_compare_buckets(ref_buckets, test_buckets)
if has_expert_map:
# Intra-expert ordering can differ between the fused kernel and the reference,
# especially in the expert-0 bucket, where invalid experts are redirected.
# Compare the number of zero-weight entries instead of elementwise positions.
ref_zero_count = int((ref_gate_scal == 0).sum().item())
test_zero_count = int((test_gate_scal == 0).sum().item())
assert ref_zero_count == test_zero_count, (
f"zero-masked count mismatch: "
f"ref={ref_zero_count}, fused={test_zero_count}"
)
else:
ground_buckets = _ground_truth_buckets(topk_ids, topk_weights)
ref_buckets = _per_expert_triples(
ref_hist, ref_topk_indx, ref_gate_scal, n_expts_act
)
_compare_buckets(ground_buckets, ref_buckets)

# Per-expert (token, weight) multisets match the reference.
test_buckets = _per_expert_triples(
test_hist, test_topk_indx, test_gate_scal, n_expts_act
)
_compare_buckets(ref_buckets, test_buckets)
Loading