Skip to content

[TRITON] moe routing support expert_map for expert parallelism#3348

Open
amd-ruitang3 wants to merge 3 commits into
mainfrom
triton_moe_routing_expert_parallel
Open

[TRITON] moe routing support expert_map for expert parallelism#3348
amd-ruitang3 wants to merge 3 commits into
mainfrom
triton_moe_routing_expert_parallel

Conversation

@amd-ruitang3
Copy link
Copy Markdown
Contributor

Motivation

Technical Details

Test Plan

Test Result

Submission Checklist

@github-actions
Copy link
Copy Markdown
Contributor

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests on MI35X (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

Label Tests
ci:triton-300x Run an additional Triton test job on MI300X in PRs; main branch always runs both MI35X and MI300X
ci:sglang SGLang integration tests: DeepSeek-R1-MXFP4 accuracy, Qwen 3.5 accuracy
ci:atom ATOM benchmark: DeepSeek-R1-0528, GPT-OSS-120B
ci:atom_full ATOM accuracy suite for PR and main models from ATOM models_accuracy.json
ci:vllm vLLM benchmark: GPT-OSS-120B, DeepSeek-R1-0528, Kimi-K2.5
ci:all All standard extended tests (excludes ci:atom_full)

Only add ci:atom_full for FlyDSL or Triton upgrades.
Add labels via the sidebar or gh pr edit 3348 --add-label <label>

@amd-ruitang3 amd-ruitang3 requested a review from k50112113 May 26, 2026 06:56
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR extends the Triton fused “topk → routing data” path to support expert-parallel setups via an optional expert_map that remaps global expert IDs to local expert IDs, and adds a corresponding correctness test.

Changes:

  • Add expert_map: Optional[torch.Tensor] to fused_routing_from_topk and plumb it into the Triton histogram/place kernels.
  • Implement expert-map remapping + invalid-expert masking (redirect to expert 0, weight=0) inside the fused Triton kernels.
  • Extend the torch reference path and add a new test covering expert_map behavior.

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 6 comments.

File Description
op_tests/triton_tests/fusions/test_fused_routing_from_topk.py Adds expert-map-aware reference logic and a new test case for remapped routing.
aiter/ops/triton/fusions/fused_routing_from_topk.py Extends the Python wrapper API to accept expert_map and pass it to Triton kernels.
aiter/ops/triton/_triton_kernels/fusions/fused_routing_from_topk.py Adds expert-map remapping logic to the Triton histogram and placement kernels.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +45 to +51
topk_ids: ``[n_tokens, n_expts_act]`` selected expert ids; values
in ``[0, n_expts_tot)``.
n_expts_tot: Total number of routed experts (= ``E``).
expert_map: Optional global→local expert map. When provided,
``topk_ids`` are treated as global ids and remapped inside fused
kernels. Entries mapped to ``< 0`` are masked to zero weight and
redirected to local expert ``0`` for routing safety.
expert_map_numel = 0
expert_map_flat = topk_ids_flat
has_expert_map = expert_map is not None
if has_expert_map:
)
# Match reference semantics: invalid experts are redirected to bucket 0
# and later zeroed in gate_scal.
expt = tl.where(local_expt >= 0, local_expt, 0)
local_expt = tl.load(expert_map_ptr + safe_global_expt, mask=map_mask, other=-1).to(
tl.int32
)
invalid = local_expt < 0
"""
if expert_map is not None:
local_ids = expert_map[topk_ids.long()]
invalid = local_ids < 0
Comment on lines +298 to +309
test_hist, test_topk_indx, test_gate_indx, test_gate_scal = fused_routing_from_topk(
topk_weights, topk_ids, n_expts_tot, expert_map=expert_map
)
_check_routing_invariants(
test_hist,
test_topk_indx,
test_gate_indx,
test_gate_scal,
topk_ids,
n_expts_tot,
bucket_unsorted_layout=False,
)
Comment thread aiter/ops/triton/fusions/fused_routing_from_topk.py Outdated
Comment thread aiter/ops/triton/fusions/fused_routing_from_topk.py Outdated
Comment thread aiter/ops/triton/fusions/fused_routing_from_topk.py Outdated
Comment thread op_tests/triton_tests/fusions/test_fused_routing_from_topk.py Outdated
Copy link
Copy Markdown
Contributor

@k50112113 k50112113 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, Thanks for the addition

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants