Skip to content

Commit c651c8a

Browse files
committed
update_with_shao-chun_comments
1 parent e255e6a commit c651c8a

2 files changed

Lines changed: 65 additions & 94 deletions

File tree

aiter/ops/triton/fusions/fused_routing_from_topk.py

Lines changed: 14 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -42,13 +42,15 @@ def fused_routing_from_topk(
4242
4343
Args:
4444
topk_weights: ``[n_tokens, n_expts_act]`` per-token routing weights.
45+
Must be contiguous.
4546
topk_ids: ``[n_tokens, n_expts_act]`` selected expert ids; values
46-
in ``[0, n_expts_tot)``.
47+
in ``[0, n_expts_tot)``. Must be contiguous int32.
4748
n_expts_tot: Total number of routed experts (= ``E``).
4849
expert_map: Optional global→local expert map. When provided,
4950
``topk_ids`` are treated as global ids and remapped inside fused
5051
kernels. Entries mapped to ``< 0`` are masked to zero weight and
51-
redirected to local expert ``0`` for routing safety.
52+
redirected to local expert ``0`` for routing safety. Must be
53+
contiguous int32 when provided.
5254
5355
Returns:
5456
Tuple ``(hist, topk_indx, gate_indx, gate_scal)``:
@@ -94,16 +96,20 @@ def fused_routing_from_topk(
9496
device = topk_weights.device
9597
weights_dtype = topk_weights.dtype
9698

97-
# Triton kernel needs flat int32 inputs. .reshape on a contiguous tensor
98-
# is a view; .contiguous() / .to(int32) on already-canonical tensors
99-
# are no-ops.
100-
topk_ids_flat = topk_ids.contiguous().reshape(-1).to(torch.int32)
101-
topk_weights_flat = topk_weights.contiguous().reshape(-1)
99+
assert (
100+
topk_ids.is_contiguous() and topk_ids.dtype == torch.int32
101+
), "topk_ids must be contiguous int32"
102+
assert topk_weights.is_contiguous(), "topk_weights must be contiguous"
103+
topk_ids_flat = topk_ids.reshape(-1)
104+
topk_weights_flat = topk_weights.reshape(-1)
102105
expert_map_numel = 0
103106
expert_map_flat = topk_ids_flat
104107
has_expert_map = expert_map is not None
105108
if has_expert_map:
106-
expert_map_flat = expert_map.contiguous().reshape(-1).to(torch.int32)
109+
assert (
110+
expert_map.is_contiguous() and expert_map.dtype == torch.int32
111+
), "expert_map must be contiguous int32"
112+
expert_map_flat = expert_map.reshape(-1)
107113
expert_map_numel = int(expert_map_flat.numel())
108114

109115
topk_indx = torch.empty(n_gates_pad, dtype=torch.int32, device=device)

op_tests/triton_tests/fusions/test_fused_routing_from_topk.py

Lines changed: 51 additions & 86 deletions
Original file line number | Diff line number | Diff line change
@@ -189,98 +189,48 @@ def _compare_buckets(ref_buckets, test_buckets, atol=1e-6):
189189
# tests
190190
# ---------------------------------------------------------------------------
191191
@pytest.mark.parametrize(
192-
"n_tokens, n_expts_act, n_expts_tot",
192+
"n_tokens, n_expts_act, n_expts_tot, n_expts_global",
193193
[
194-
# V4-Flash decode shapes (E=256, K=6).
195-
(1, 6, 256),
196-
(16, 6, 256),
197-
(64, 6, 256),
198-
(256, 6, 256),
194+
# V4-Flash decode shapes (E=256, K=6). n_expts_global ignored when
195+
# has_expert_map=False.
196+
(1, 6, 256, 256),
197+
(16, 6, 256, 256),
198+
(64, 6, 256, 256),
199+
(256, 6, 256, 256),
199200
# Generic decode shapes used by other MoE configs.
200-
(1, 8, 384),
201-
(4, 8, 384),
202-
(64, 8, 384),
203-
(256, 8, 384),
201+
(1, 8, 384, 384),
202+
(4, 8, 384, 384),
203+
(64, 8, 384, 384),
204+
(256, 8, 384, 384),
204205
# Edge: small E.
205-
(32, 4, 16),
206+
(32, 4, 16, 16),
206207
# Boundary: NK at the kernel's MAX_NK = 4096.
207-
(512, 8, 384),
208-
],
209-
)
210-
@pytest.mark.parametrize("dtype", [torch.float32])
211-
def test_fused_routing_from_topk(n_tokens, n_expts_act, n_expts_tot, dtype):
212-
if not torch.cuda.is_available():
213-
pytest.skip("CUDA not available")
214-
torch.manual_seed(0)
215-
topk_ids, topk_weights = _make_inputs(
216-
n_tokens, n_expts_act, n_expts_tot, dtype, DEVICE, seed=0
217-
)
218-
219-
ref_hist, ref_topk_indx, ref_gate_indx, ref_gate_scal = routing_from_topk_reference(
220-
topk_weights, topk_ids, n_expts_tot
221-
)
222-
_check_routing_invariants(
223-
ref_hist,
224-
ref_topk_indx,
225-
ref_gate_indx,
226-
ref_gate_scal,
227-
topk_ids,
228-
n_expts_tot,
229-
bucket_unsorted_layout=False, # ref uses per-row-sorted layout
230-
)
231-
ground_buckets = _ground_truth_buckets(topk_ids, topk_weights)
232-
ref_buckets = _per_expert_triples(
233-
ref_hist, ref_topk_indx, ref_gate_scal, n_expts_act
234-
)
235-
_compare_buckets(ground_buckets, ref_buckets)
236-
237-
test_hist, test_topk_indx, test_gate_indx, test_gate_scal = fused_routing_from_topk(
238-
topk_weights, topk_ids, n_expts_tot
239-
)
240-
_check_routing_invariants(
241-
test_hist,
242-
test_topk_indx,
243-
test_gate_indx,
244-
test_gate_scal,
245-
topk_ids,
246-
n_expts_tot,
247-
bucket_unsorted_layout=True, # fused uses unsorted topk_ids layout
248-
)
249-
250-
# hist must match the reference exactly.
251-
assert torch.equal(
252-
ref_hist, test_hist
253-
), f"hist mismatch:\n ref={ref_hist}\n fused={test_hist}"
254-
255-
# Per-expert (token, weight) multisets match the reference.
256-
test_buckets = _per_expert_triples(
257-
test_hist, test_topk_indx, test_gate_scal, n_expts_act
258-
)
259-
_compare_buckets(ref_buckets, test_buckets)
260-
261-
262-
@pytest.mark.parametrize(
263-
"n_tokens, n_expts_act, n_expts_tot,n_expts_global",
264-
[
208+
(512, 8, 384, 384),
209+
# Expert-parallel shapes: n_expts_global > n_expts_tot, requires map.
265210
(16, 6, 64, 256),
266211
(64, 6, 128, 256),
267212
],
268213
)
214+
@pytest.mark.parametrize("has_expert_map", [False, True])
269215
@pytest.mark.parametrize("dtype", [torch.float32])
270-
def test_fused_routing_from_topk_with_expert_map(
271-
n_tokens, n_expts_act, n_expts_tot, n_expts_global, dtype
216+
def test_fused_routing_from_topk(
217+
n_tokens, n_expts_act, n_expts_tot, n_expts_global, has_expert_map, dtype
272218
):
273219
if not torch.cuda.is_available():
274220
pytest.skip("CUDA not available")
275221
torch.manual_seed(0)
222+
223+
id_range = n_expts_global if has_expert_map else n_expts_tot
276224
topk_ids, topk_weights = _make_inputs(
277-
n_tokens, n_expts_act, n_expts_global, dtype, DEVICE, seed=0
225+
n_tokens, n_expts_act, id_range, dtype, DEVICE, seed=0
278226
)
279227

280-
expert_map = torch.full((n_expts_global,), -1, dtype=torch.int32, device=DEVICE)
281-
expert_map[: n_expts_tot // 2] = torch.arange(
282-
n_expts_tot // 2, dtype=torch.int32, device=DEVICE
283-
)
228+
expert_map = None
229+
if has_expert_map:
230+
expert_map = torch.full((n_expts_global,), -1, dtype=torch.int32, device=DEVICE)
231+
expert_map[: n_expts_tot // 2] = torch.arange(
232+
n_expts_tot // 2, dtype=torch.int32, device=DEVICE
233+
)
284234

285235
ref_hist, ref_topk_indx, ref_gate_indx, ref_gate_scal = routing_from_topk_reference(
286236
topk_weights, topk_ids, n_expts_tot, expert_map=expert_map
@@ -292,7 +242,7 @@ def test_fused_routing_from_topk_with_expert_map(
292242
ref_gate_scal,
293243
topk_ids,
294244
n_expts_tot,
295-
bucket_unsorted_layout=False,
245+
bucket_unsorted_layout=False, # ref uses per-row-sorted layout
296246
)
297247

298248
test_hist, test_topk_indx, test_gate_indx, test_gate_scal = fused_routing_from_topk(
@@ -305,18 +255,33 @@ def test_fused_routing_from_topk_with_expert_map(
305255
test_gate_scal,
306256
topk_ids,
307257
n_expts_tot,
308-
bucket_unsorted_layout=False,
258+
bucket_unsorted_layout=not has_expert_map,
309259
)
310260

261+
# hist must match the reference exactly.
311262
assert torch.equal(
312263
ref_hist, test_hist
313264
), f"hist mismatch:\n ref={ref_hist}\n fused={test_hist}"
314265

315-
# Intra-expert ordering can differ between fused and reference,
316-
# especially in expert-0 bucket where invalid experts are redirected.
317-
# Compare zeroed-weight cardinality instead of elementwise positions.
318-
ref_zero_count = int((ref_gate_scal == 0).sum().item())
319-
test_zero_count = int((test_gate_scal == 0).sum().item())
320-
assert (
321-
ref_zero_count == test_zero_count
322-
), f"zero-masked count mismatch: ref={ref_zero_count}, fused={test_zero_count}"
266+
if has_expert_map:
267+
# Intra-expert ordering can differ between fused and reference,
268+
# especially in expert-0 bucket where invalid experts are redirected.
269+
# Compare zeroed-weight cardinality instead of elementwise positions.
270+
ref_zero_count = int((ref_gate_scal == 0).sum().item())
271+
test_zero_count = int((test_gate_scal == 0).sum().item())
272+
assert ref_zero_count == test_zero_count, (
273+
f"zero-masked count mismatch: "
274+
f"ref={ref_zero_count}, fused={test_zero_count}"
275+
)
276+
else:
277+
ground_buckets = _ground_truth_buckets(topk_ids, topk_weights)
278+
ref_buckets = _per_expert_triples(
279+
ref_hist, ref_topk_indx, ref_gate_scal, n_expts_act
280+
)
281+
_compare_buckets(ground_buckets, ref_buckets)
282+
283+
# Per-expert (token, weight) multisets match the reference.
284+
test_buckets = _per_expert_triples(
285+
test_hist, test_topk_indx, test_gate_scal, n_expts_act
286+
)
287+
_compare_buckets(ref_buckets, test_buckets)

0 commit comments

Comments
 (0)