
[NPU] Add NPU Fused MoE kernel#1183

Merged
Tcc0403 merged 6 commits into linkedin:main from zheliuyu:main
Apr 24, 2026

Conversation

@zheliuyu
Contributor

@zheliuyu zheliuyu commented Apr 3, 2026

Motivation

This PR ports fused_moe.py and fused_moe_kernels.py to an NPU-affine implementation while preserving the original math. The computational definition is unchanged: the forward pass remains W1 (gate/up) -> SwiGLU -> W2 -> token-weighted gather, and the backward pass still follows dA' = dO @ W2^T to produce d_pre_act / dS / dW2 / dX / dW1.
The main changes are execution-strategy optimizations for NPU.
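For context, the unfused math the kernel must reproduce can be sketched in plain PyTorch. This is a minimal reference sketch, not the PR's implementation; the function name and the (E, H, 2I) gate/up and (E, I, H) down-projection layouts are illustrative assumptions:

```python
import torch
import torch.nn.functional as F


def moe_forward_reference(x, gate_up_proj, down_proj, top_k_index, top_k_weights):
    """Unfused MoE forward: W1 (gate/up) -> SwiGLU -> W2 -> token-weighted sum.

    x:             (T, H) tokens
    gate_up_proj:  (E, H, 2*I) per-expert W1 (gate and up halves concatenated)
    down_proj:     (E, I, H)   per-expert W2
    top_k_index:   (T, K) expert ids; top_k_weights: (T, K) routing weights
    """
    T, K = top_k_index.shape
    out = torch.zeros_like(x)
    for t in range(T):
        for k in range(K):
            e = int(top_k_index[t, k])
            gu = x[t] @ gate_up_proj[e]          # (2*I,)
            gate, up = gu.chunk(2)
            h = F.silu(gate) * up                # SwiGLU activation, (I,)
            out[t] += top_k_weights[t, k] * (h @ down_proj[e])  # (H,)
    return out
```

Because the output is linear in the routing weights, doubling `top_k_weights` doubles the output, which makes this reference easy to sanity-check.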

Note: Use the Skill

For this fused_moe kernel migration, we followed the skill document from #1197.

Testing Done

  • Hardware Type: Ascend 910B2
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

🤖 Generated with Cursor.

@zheliuyu zheliuyu changed the title [WIP] Support npu fused moe [NPU] Add NPU Fused MoE kernel Apr 21, 2026
@zheliuyu zheliuyu reopened this Apr 21, 2026
@zheliuyu
Contributor Author

zheliuyu commented Apr 21, 2026

Test

We also added all shapes from benchmark_fused_moe.py to test_fused_moe.py to validate kernel generalization more broadly.

Extra code

# Benchmark num_tokens sweep in benchmark/scripts/benchmark_fused_moe.py.
BENCHMARK_T_VALUES = [128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768]

@pytest.mark.parametrize("T", BENCHMARK_T_VALUES)
def test_benchmark_token_shapes_forward_smoke_bfloat16(T):
    """Smoke-test benchmark token sweep shapes on fused_moe forward path."""
    dtype = torch.bfloat16
    # Keep dimensions modest so CI runtime remains bounded while covering T sweep.
    E, H, intermediate_dim, K = 16, 256, 128, 8
    x, gate_up_proj, down_proj, top_k_index, top_k_weights = _make_inputs(T, E, H, intermediate_dim, K, dtype, device)
    out = LigerFusedMoEFunction.apply(x, gate_up_proj, down_proj, top_k_index, top_k_weights)
    assert out.shape == (T, H)
    assert torch.isfinite(out).all()


@pytest.mark.parametrize("T", BENCHMARK_T_VALUES)
def test_benchmark_token_shapes_forward_backward_correctness_float32(T):
    """Validate forward/backward numerical correctness on benchmark token shapes."""
    dtype = torch.float32
    # Keep dimensions small enough for full-shape gradient checks to finish quickly.
    E, H, intermediate_dim, K = 8, 64, 32, 2

    x, gate_up_proj, down_proj, top_k_index, top_k_weights = _make_inputs(T, E, H, intermediate_dim, K, dtype, device)

    ref = _reference_moe_forward(x, gate_up_proj, down_proj, top_k_index, top_k_weights)
    out = LigerFusedMoEFunction.apply(x, gate_up_proj, down_proj, top_k_index, top_k_weights)
    torch.testing.assert_close(out, ref, atol=1e-3, rtol=1e-4)

    x1 = x.clone().requires_grad_(True)
    gup1 = gate_up_proj.clone().requires_grad_(True)
    dn1 = down_proj.clone().requires_grad_(True)
    wts1 = top_k_weights.clone().requires_grad_(True)
    x2 = x.clone().requires_grad_(True)
    gup2 = gate_up_proj.clone().requires_grad_(True)
    dn2 = down_proj.clone().requires_grad_(True)
    wts2 = top_k_weights.clone().requires_grad_(True)

    out_ref = _reference_moe_forward(x1, gup1, dn1, top_k_index, wts1)
    out_ref.sum().backward()
    out_fused = LigerFusedMoEFunction.apply(x2, gup2, dn2, top_k_index, wts2)
    out_fused.sum().backward()

    b_atol, b_rtol = 3e-3, 1e-2
    torch.testing.assert_close(wts2.grad, wts1.grad, atol=b_atol, rtol=b_rtol)
    torch.testing.assert_close(dn2.grad, dn1.grad, atol=b_atol, rtol=b_rtol)
    torch.testing.assert_close(x2.grad, x1.grad, atol=b_atol, rtol=b_rtol)
    torch.testing.assert_close(gup2.grad, gup1.grad, atol=b_atol, rtol=b_rtol)
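The tests above rely on a `_make_inputs` helper that is not shown. A plausible shape-compatible sketch, assuming the layouts implied by the test code (the name, scaling, and routing construction here are illustrative, not the actual helper in test_fused_moe.py):

```python
import torch


def make_moe_inputs(T, E, H, intermediate_dim, K, dtype, device="cpu"):
    """Build random fused-MoE inputs with the shapes the tests expect.

    Returns x (T, H), per-expert gate/up W1 (E, H, 2*I), per-expert W2
    (E, I, H), top-k expert ids (T, K) int32, and normalized routing
    weights (T, K).
    """
    x = torch.randn(T, H, dtype=dtype, device=device)
    gate_up_proj = torch.randn(E, H, 2 * intermediate_dim, dtype=dtype, device=device) * 0.02
    down_proj = torch.randn(E, intermediate_dim, H, dtype=dtype, device=device) * 0.02
    # Derive top-k routing from random logits, then renormalize the weights.
    logits = torch.randn(T, E, device=device)
    top_k_weights, top_k_index = torch.topk(torch.softmax(logits, dim=-1), K, dim=-1)
    top_k_weights = (top_k_weights / top_k_weights.sum(-1, keepdim=True)).to(dtype)
    return x, gate_up_proj, down_proj, top_k_index.to(torch.int32), top_k_weights
```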

Pytest results

==================================================================================== test session starts ====================================================================================
platform linux -- Python 3.10.8, pytest-9.0.2, pluggy-1.6.0
rootdir: /root/Liger-Kernel-dev
configfile: pyproject.toml
plugins: xdist-3.8.0, cov-7.0.0, rerunfailures-16.1, asyncio-1.3.0, anyio-4.9.0
asyncio: mode=auto, debug=False, asyncio_default_fixture_loop_scope=None, asyncio_default_test_loop_scope=function
collected 37 items                                                                                                                                                                          

test/transformers/test_fused_moe.py::test_routing_metadata_invariants[64-8-2] PASSED                                                                                                  [  2%]
test/transformers/test_fused_moe.py::test_routing_metadata_invariants[128-16-4] PASSED                                                                                                [  5%]
test/transformers/test_fused_moe.py::test_routing_metadata_invariants[100-8-2] PASSED                                                                                                 [  8%]
test/transformers/test_fused_moe.py::test_routing_metadata_invariants[256-32-8] PASSED                                                                                                [ 10%]
test/transformers/test_fused_moe.py::test_correctness[dtype0-0.001-0.0001-7-4-64-32-2] PASSED                                                                                         [ 13%]
test/transformers/test_fused_moe.py::test_correctness[dtype0-0.001-0.0001-512-8-256-128-2] PASSED                                                                                     [ 16%]
test/transformers/test_fused_moe.py::test_correctness[dtype0-0.001-0.0001-512-8-97-47-2] PASSED                                                                                       [ 18%]
test/transformers/test_fused_moe.py::test_correctness[dtype0-0.001-0.0001-512-7-128-64-3] PASSED                                                                                      [ 21%]
test/transformers/test_fused_moe.py::test_correctness[dtype0-0.001-0.0001-512-8-256-64-1] PASSED                                                                                      [ 24%]
test/transformers/test_fused_moe.py::test_correctness[dtype0-0.001-0.0001-128-8-256-64-8] PASSED                                                                                      [ 27%]
test/transformers/test_fused_moe.py::test_correctness[dtype1-0.1-0.01-7-4-64-32-2] PASSED                                                                                             [ 29%]
test/transformers/test_fused_moe.py::test_correctness[dtype1-0.1-0.01-512-8-256-128-2] PASSED                                                                                         [ 32%]
test/transformers/test_fused_moe.py::test_correctness[dtype1-0.1-0.01-512-8-97-47-2] PASSED                                                                                           [ 35%]
test/transformers/test_fused_moe.py::test_correctness[dtype1-0.1-0.01-512-7-128-64-3] PASSED                                                                                          [ 37%]
test/transformers/test_fused_moe.py::test_correctness[dtype1-0.1-0.01-512-8-256-64-1] PASSED                                                                                          [ 40%]
test/transformers/test_fused_moe.py::test_correctness[dtype1-0.1-0.01-128-8-256-64-8] PASSED                                                                                          [ 43%]
test/transformers/test_fused_moe.py::test_all_tokens_to_one_expert PASSED                                                                                                             [ 45%]
test/transformers/test_fused_moe.py::test_single_token PASSED                                                                                                                         [ 48%]
test/transformers/test_fused_moe.py::test_K_equals_E PASSED                                                                                                                           [ 51%]
test/transformers/test_fused_moe.py::test_benchmark_token_shapes_forward_smoke_bfloat16[128] PASSED                                                                                   [ 54%]
test/transformers/test_fused_moe.py::test_benchmark_token_shapes_forward_smoke_bfloat16[256] PASSED                                                                                   [ 56%]
test/transformers/test_fused_moe.py::test_benchmark_token_shapes_forward_smoke_bfloat16[512] PASSED                                                                                   [ 59%]
test/transformers/test_fused_moe.py::test_benchmark_token_shapes_forward_smoke_bfloat16[1024] PASSED                                                                                  [ 62%]
test/transformers/test_fused_moe.py::test_benchmark_token_shapes_forward_smoke_bfloat16[2048] PASSED                                                                                  [ 64%]
test/transformers/test_fused_moe.py::test_benchmark_token_shapes_forward_smoke_bfloat16[4096] PASSED                                                                                  [ 67%]
test/transformers/test_fused_moe.py::test_benchmark_token_shapes_forward_smoke_bfloat16[8192] PASSED                                                                                  [ 70%]
test/transformers/test_fused_moe.py::test_benchmark_token_shapes_forward_smoke_bfloat16[16384] PASSED                                                                                 [ 72%]
test/transformers/test_fused_moe.py::test_benchmark_token_shapes_forward_smoke_bfloat16[32768] PASSED                                                                                 [ 75%]
test/transformers/test_fused_moe.py::test_benchmark_token_shapes_forward_backward_correctness_float32[128] PASSED                                                                     [ 78%]
test/transformers/test_fused_moe.py::test_benchmark_token_shapes_forward_backward_correctness_float32[256] PASSED                                                                     [ 81%]
test/transformers/test_fused_moe.py::test_benchmark_token_shapes_forward_backward_correctness_float32[512] PASSED                                                                     [ 83%]
test/transformers/test_fused_moe.py::test_benchmark_token_shapes_forward_backward_correctness_float32[1024] PASSED                                                                    [ 86%]
test/transformers/test_fused_moe.py::test_benchmark_token_shapes_forward_backward_correctness_float32[2048] PASSED                                                                    [ 89%]
test/transformers/test_fused_moe.py::test_benchmark_token_shapes_forward_backward_correctness_float32[4096] PASSED                                                                    [ 91%]
test/transformers/test_fused_moe.py::test_benchmark_token_shapes_forward_backward_correctness_float32[8192] PASSED                                                                    [ 94%]
test/transformers/test_fused_moe.py::test_benchmark_token_shapes_forward_backward_correctness_float32[16384] PASSED                                                                   [ 97%]
test/transformers/test_fused_moe.py::test_benchmark_token_shapes_forward_backward_correctness_float32[32768] PASSED                                                                   [100%]

Benchmark results

Benchmark visualization

[Plots: fused_moe_speed_full_token_length (ms), fused_moe_speed_forward_token_length (ms), fused_moe_speed_backward_token_length (ms), fused_moe_memory_full_token_length (MB)]

Raw benchmark data

Model: qwen3_moe_30b (E=128, H=2048, intermediate_dim=768, K=8, T_base=8192, dtype=torch.bfloat16)
Pre-warming Liger autotune (H=2048, intermediate_dim=768)...
Autotune warmup complete.

**************************************
     BENCHMARKING SPEED for FUSED_MOE
**************************************

********** Benchmark Data **********
[
  {
    "kernel_name": "fused_moe",
    "kernel_provider": "liger",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "Ascend910B2",
    "x_name": "T",
    "x_label": "num_tokens",
    "x_values": [
      128,
      256,
      512,
      1024,
      2048,
      4096,
      8192,
      16384,
      32768
    ],
    "y_values_50": [
      51.98942184448242,
      39.075538635253906,
      41.29703903198242,
      45.648860931396484,
      68.03790283203125,
      88.77642059326172,
      144.06663513183594,
      253.5931854248047,
      465.77532958984375
    ],
    "y_values_20": [
      51.98942184448242,
      39.075538635253906,
      41.29703903198242,
      45.648860931396484,
      68.03790283203125,
      88.77642059326172,
      144.06663513183594,
      253.5931854248047,
      465.77532958984375
    ],
    "y_values_80": [
      51.98942184448242,
      39.075538635253906,
      41.29703903198242,
      45.648860931396484,
      68.03790283203125,
      88.77642059326172,
      144.06663513183594,
      253.5931854248047,
      465.77532958984375
    ],
    "timestamp": "2026-04-20 20:23:01",
    "kernel_operation_mode": "full",
    "extra_benchmark_config_str": "{\"sweep_dim\": \"T\", \"T\": null, \"E\": 128, \"H\": 2048, \"intermediate_dim\": 768, \"K\": 8, \"dtype\": \"torch.bfloat16\"}",
    "liger_version": "0.7.0"
  },
  {
    "kernel_name": "fused_moe",
    "kernel_provider": "huggingface",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "Ascend910B2",
    "x_name": "T",
    "x_label": "num_tokens",
    "x_values": [
      128,
      256,
      512,
      1024,
      2048,
      4096,
      8192,
      16384,
      32768
    ],
    "y_values_50": [
      728.71826171875,
      732.6692504882812,
      744.4974975585938,
      767.8123168945312,
      766.8602905273438,
      793.347900390625,
      845.9445190429688,
      956.5780029296875,
      1098.6083984375
    ],
    "y_values_20": [
      728.71826171875,
      732.6692504882812,
      744.4974975585938,
      767.8123168945312,
      766.8602905273438,
      793.347900390625,
      845.9445190429688,
      956.5780029296875,
      1098.6083984375
    ],
    "y_values_80": [
      728.71826171875,
      732.6692504882812,
      744.4974975585938,
      767.8123168945312,
      766.8602905273438,
      793.347900390625,
      845.9445190429688,
      956.5780029296875,
      1098.6083984375
    ],
    "timestamp": "2026-04-20 20:24:02",
    "kernel_operation_mode": "full",
    "extra_benchmark_config_str": "{\"sweep_dim\": \"T\", \"T\": null, \"E\": 128, \"H\": 2048, \"intermediate_dim\": 768, \"K\": 8, \"dtype\": \"torch.bfloat16\"}",
    "liger_version": "0.7.0"
  },
  {
    "kernel_name": "fused_moe",
    "kernel_provider": "liger",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "Ascend910B2",
    "x_name": "T",
    "x_label": "num_tokens",
    "x_values": [
      128,
      256,
      512,
      1024,
      2048,
      4096,
      8192,
      16384,
      32768
    ],
    "y_values_50": [
      11.595239639282227,
      9.2535400390625,
      10.660460472106934,
      12.010939598083496,
      18.729719161987305,
      27.991579055786133,
      52.14693832397461,
      97.36244201660156,
      188.6189422607422
    ],
    "y_values_20": [
      11.595239639282227,
      9.2535400390625,
      10.660460472106934,
      12.010939598083496,
      18.729719161987305,
      27.991579055786133,
      52.14693832397461,
      97.36244201660156,
      188.6189422607422
    ],
    "y_values_80": [
      11.595239639282227,
      9.2535400390625,
      10.660460472106934,
      12.010939598083496,
      18.729719161987305,
      27.991579055786133,
      52.14693832397461,
      97.36244201660156,
      188.6189422607422
    ],
    "timestamp": "2026-04-20 20:24:06",
    "kernel_operation_mode": "forward",
    "extra_benchmark_config_str": "{\"sweep_dim\": \"T\", \"T\": null, \"E\": 128, \"H\": 2048, \"intermediate_dim\": 768, \"K\": 8, \"dtype\": \"torch.bfloat16\"}",
    "liger_version": "0.7.0"
  },
  {
    "kernel_name": "fused_moe",
    "kernel_provider": "huggingface",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "Ascend910B2",
    "x_name": "T",
    "x_label": "num_tokens",
    "x_values": [
      128,
      256,
      512,
      1024,
      2048,
      4096,
      8192,
      16384,
      32768
    ],
    "y_values_50": [
      168.9834442138672,
      170.3657989501953,
      177.6240234375,
      182.25244140625,
      191.5980224609375,
      208.39393615722656,
      225.64047241210938,
      272.00885009765625,
      379.7675476074219
    ],
    "y_values_20": [
      168.9834442138672,
      170.3657989501953,
      177.6240234375,
      182.25244140625,
      191.5980224609375,
      208.39393615722656,
      225.64047241210938,
      272.00885009765625,
      379.7675476074219
    ],
    "y_values_80": [
      168.9834442138672,
      170.3657989501953,
      177.6240234375,
      182.25244140625,
      191.5980224609375,
      208.39393615722656,
      225.64047241210938,
      272.00885009765625,
      379.7675476074219
    ],
    "timestamp": "2026-04-20 20:24:23",
    "kernel_operation_mode": "forward",
    "extra_benchmark_config_str": "{\"sweep_dim\": \"T\", \"T\": null, \"E\": 128, \"H\": 2048, \"intermediate_dim\": 768, \"K\": 8, \"dtype\": \"torch.bfloat16\"}",
    "liger_version": "0.7.0"
  },
  {
    "kernel_name": "fused_moe",
    "kernel_provider": "liger",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "Ascend910B2",
    "x_name": "T",
    "x_label": "num_tokens",
    "x_values": [
      128,
      256,
      512,
      1024,
      2048,
      4096,
      8192,
      16384,
      32768
    ],
    "y_values_50": [
      36.21335983276367,
      29.728300094604492,
      30.529579162597656,
      33.48516082763672,
      48.53886032104492,
      58.386600494384766,
      90.72882080078125,
      155.96798706054688,
      275.62823486328125
    ],
    "y_values_20": [
      36.21335983276367,
      29.728300094604492,
      30.529579162597656,
      33.48516082763672,
      48.53886032104492,
      58.386600494384766,
      90.72882080078125,
      155.96798706054688,
      275.62823486328125
    ],
    "y_values_80": [
      36.21335983276367,
      29.728300094604492,
      30.529579162597656,
      33.48516082763672,
      48.53886032104492,
      58.386600494384766,
      90.72882080078125,
      155.96798706054688,
      275.62823486328125
    ],
    "timestamp": "2026-04-20 20:24:30",
    "kernel_operation_mode": "backward",
    "extra_benchmark_config_str": "{\"sweep_dim\": \"T\", \"T\": null, \"E\": 128, \"H\": 2048, \"intermediate_dim\": 768, \"K\": 8, \"dtype\": \"torch.bfloat16\"}",
    "liger_version": "0.7.0"
  },
  {
    "kernel_name": "fused_moe",
    "kernel_provider": "huggingface",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "Ascend910B2",
    "x_name": "T",
    "x_label": "num_tokens",
    "x_values": [
      128,
      256,
      512,
      1024,
      2048,
      4096,
      8192,
      16384,
      32768
    ],
    "y_values_50": [
      550.6949462890625,
      553.7078857421875,
      560.1070556640625,
      567.1722412109375,
      575.6559448242188,
      589.7853393554688,
      617.0770263671875,
      677.5838012695312,
      714.6703491210938
    ],
    "y_values_20": [
      550.6949462890625,
      553.7078857421875,
      560.1070556640625,
      567.1722412109375,
      575.6559448242188,
      589.7853393554688,
      617.0770263671875,
      677.5838012695312,
      714.6703491210938
    ],
    "y_values_80": [
      550.6949462890625,
      553.7078857421875,
      560.1070556640625,
      567.1722412109375,
      575.6559448242188,
      589.7853393554688,
      617.0770263671875,
      677.5838012695312,
      714.6703491210938
    ],
    "timestamp": "2026-04-20 20:25:16",
    "kernel_operation_mode": "backward",
    "extra_benchmark_config_str": "{\"sweep_dim\": \"T\", \"T\": null, \"E\": 128, \"H\": 2048, \"intermediate_dim\": 768, \"K\": 8, \"dtype\": \"torch.bfloat16\"}",
    "liger_version": "0.7.0"
  }
]
**************************************
     BENCHMARKING MEMORY for FUSED_MOE
**************************************
********** Benchmark Data **********
[
  {
    "kernel_name": "fused_moe",
    "kernel_provider": "liger",
    "metric_name": "memory",
    "metric_unit": "MB",
    "gpu_name": "Ascend910B2",
    "x_name": "T",
    "x_label": "num_tokens",
    "x_values": [
      128,
      256,
      512,
      1024,
      2048,
      4096,
      8192,
      16384,
      32768
    ],
    "y_values_50": [
      4997.0322265625,
      5005.048828125,
      5012.0908203125,
      5032.169921875,
      5072.328125,
      5168.70947265625,
      5617.40478515625,
      6514.79541015625,
      8313.572265625
    ],
    "y_values_20": [
      4997.0322265625,
      5005.048828125,
      5012.0908203125,
      5032.169921875,
      5072.328125,
      5168.70947265625,
      5617.40478515625,
      6514.79541015625,
      8313.572265625
    ],
    "y_values_80": [
      4997.0322265625,
      5005.048828125,
      5012.0908203125,
      5032.169921875,
      5072.328125,
      5168.70947265625,
      5617.40478515625,
      6514.79541015625,
      8313.572265625
    ],
    "timestamp": "2026-04-20 20:25:28",
    "kernel_operation_mode": "full",
    "extra_benchmark_config_str": "{\"sweep_dim\": \"T\", \"T\": null, \"E\": 128, \"H\": 2048, \"intermediate_dim\": 768, \"K\": 8, \"dtype\": \"torch.bfloat16\"}",
    "liger_version": "0.7.0"
  },
  {
    "kernel_name": "fused_moe",
    "kernel_provider": "huggingface",
    "metric_name": "memory",
    "metric_unit": "MB",
    "gpu_name": "Ascend910B2",
    "x_name": "T",
    "x_label": "num_tokens",
    "x_values": [
      128,
      256,
      512,
      1024,
      2048,
      4096,
      8192,
      16384,
      32768
    ],
    "y_values_50": [
      5405.087890625,
      5431.63232421875,
      5472.73046875,
      5562.9296875,
      5743.27734375,
      6111.76904296875,
      6830.26171875,
      8274.890625,
      11159.884765625
    ],
    "y_values_20": [
      5405.087890625,
      5431.63232421875,
      5472.73046875,
      5562.9296875,
      5743.27734375,
      6111.76904296875,
      6830.26171875,
      8274.890625,
      11159.884765625
    ],
    "y_values_80": [
      5405.087890625,
      5431.63232421875,
      5472.73046875,
      5562.9296875,
      5743.27734375,
      6111.76904296875,
      6830.26171875,
      8274.890625,
      11159.884765625
    ],
    "timestamp": "2026-04-20 20:26:42",
    "kernel_operation_mode": "full",
    "extra_benchmark_config_str": "{\"sweep_dim\": \"T\", \"T\": null, \"E\": 128, \"H\": 2048, \"intermediate_dim\": 768, \"K\": 8, \"dtype\": \"torch.bfloat16\"}",
    "liger_version": "0.7.0"
  },
  {
    "kernel_name": "fused_moe",
    "kernel_provider": "liger",
    "metric_name": "memory",
    "metric_unit": "MB",
    "gpu_name": "Ascend910B2",
    "x_name": "T",
    "x_label": "num_tokens",
    "x_values": [
      128,
      256,
      512,
      1024,
      2048,
      4096,
      8192,
      16384,
      32768
    ],
    "y_values_50": [
      2337.52685546875,
      2342.04541015625,
      2365.64306640625,
      2412.40185546875,
      2510.95068359375,
      2701.04833984375,
      3086.93115234375,
      3856.07177734375,
      5398.56787109375
    ],
    "y_values_20": [
      2337.52685546875,
      2342.04541015625,
      2365.64306640625,
      2412.40185546875,
      2510.95068359375,
      2701.04833984375,
      3086.93115234375,
      3856.07177734375,
      5398.56787109375
    ],
    "y_values_80": [
      2337.52685546875,
      2342.04541015625,
      2365.64306640625,
      2412.40185546875,
      2510.95068359375,
      2701.04833984375,
      3086.93115234375,
      3856.07177734375,
      5398.56787109375
    ],
    "timestamp": "2026-04-20 20:26:47",
    "kernel_operation_mode": "forward",
    "extra_benchmark_config_str": "{\"sweep_dim\": \"T\", \"T\": null, \"E\": 128, \"H\": 2048, \"intermediate_dim\": 768, \"K\": 8, \"dtype\": \"torch.bfloat16\"}",
    "liger_version": "0.7.0"
  },
  {
    "kernel_name": "fused_moe",
    "kernel_provider": "huggingface",
    "metric_name": "memory",
    "metric_unit": "MB",
    "gpu_name": "Ascend910B2",
    "x_name": "T",
    "x_label": "num_tokens",
    "x_values": [
      128,
      256,
      512,
      1024,
      2048,
      4096,
      8192,
      16384,
      32768
    ],
    "y_values_50": [
      2343.51416015625,
      2365.52001953125,
      2408.53271484375,
      2496.556640625,
      2672.60546875,
      3030.6533203125,
      3734.263671875,
      5142.11279296875,
      7958.04248046875
    ],
    "y_values_20": [
      2343.51416015625,
      2365.52001953125,
      2408.53271484375,
      2496.556640625,
      2672.60546875,
      3030.6533203125,
      3734.263671875,
      5142.11279296875,
      7958.04248046875
    ],
    "y_values_80": [
      2343.51416015625,
      2365.52001953125,
      2408.53271484375,
      2496.556640625,
      2672.60546875,
      3030.6533203125,
      3734.263671875,
      5142.11279296875,
      7958.04248046875
    ],
    "timestamp": "2026-04-20 20:27:08",
    "kernel_operation_mode": "forward",
    "extra_benchmark_config_str": "{\"sweep_dim\": \"T\", \"T\": null, \"E\": 128, \"H\": 2048, \"intermediate_dim\": 768, \"K\": 8, \"dtype\": \"torch.bfloat16\"}",
    "liger_version": "0.7.0"
  },
  {
    "kernel_name": "fused_moe",
    "kernel_provider": "liger",
    "metric_name": "memory",
    "metric_unit": "MB",
    "gpu_name": "Ascend910B2",
    "x_name": "T",
    "x_label": "num_tokens",
    "x_values": [
      128,
      256,
      512,
      1024,
      2048,
      4096,
      8192,
      16384,
      32768
    ],
    "y_values_50": [
      4997.0322265625,
      5005.048828125,
      5012.0908203125,
      5032.169921875,
      5072.328125,
      5168.70947265625,
      5617.40478515625,
      6514.79541015625,
      8313.572265625
    ],
    "y_values_20": [
      4997.0322265625,
      5005.048828125,
      5012.0908203125,
      5032.169921875,
      5072.328125,
      5168.70947265625,
      5617.40478515625,
      6514.79541015625,
      8313.572265625
    ],
    "y_values_80": [
      4997.0322265625,
      5005.048828125,
      5012.0908203125,
      5032.169921875,
      5072.328125,
      5168.70947265625,
      5617.40478515625,
      6514.79541015625,
      8313.572265625
    ],
    "timestamp": "2026-04-20 20:27:13",
    "kernel_operation_mode": "backward",
    "extra_benchmark_config_str": "{\"sweep_dim\": \"T\", \"T\": null, \"E\": 128, \"H\": 2048, \"intermediate_dim\": 768, \"K\": 8, \"dtype\": \"torch.bfloat16\"}",
    "liger_version": "0.7.0"
  },
  {
    "kernel_name": "fused_moe",
    "kernel_provider": "huggingface",
    "metric_name": "memory",
    "metric_unit": "MB",
    "gpu_name": "Ascend910B2",
    "x_name": "T",
    "x_label": "num_tokens",
    "x_values": [
      128,
      256,
      512,
      1024,
      2048,
      4096,
      8192,
      16384,
      32768
    ],
    "y_values_50": [
      5405.087890625,
      5431.63232421875,
      5472.73046875,
      5562.9296875,
      5743.27734375,
      6110.7744140625,
      6830.7568359375,
      8272.939453125,
      11157.9033203125
    ],
    "y_values_20": [
      5405.087890625,
      5431.63232421875,
      5472.73046875,
      5562.9296875,
      5743.27734375,
      6110.7744140625,
      6830.7568359375,
      8272.939453125,
      11157.9033203125
    ],
    "y_values_80": [
      5405.087890625,
      5431.63232421875,
      5472.73046875,
      5562.9296875,
      5743.27734375,
      6110.7744140625,
      6830.7568359375,
      8272.939453125,
      11157.9033203125
    ],
    "timestamp": "2026-04-20 20:28:12",
    "kernel_operation_mode": "backward",
    "extra_benchmark_config_str": "{\"sweep_dim\": \"T\", \"T\": null, \"E\": 128, \"H\": 2048, \"intermediate_dim\": 768, \"K\": 8, \"dtype\": \"torch.bfloat16\"}",
    "liger_version": "0.7.0"
  }
]

@zheliuyu zheliuyu marked this pull request as ready for review April 21, 2026 06:55
@zheliuyu
Contributor Author

@Tcc0403 This PR is ready for review.

Collaborator

@Tcc0403 Tcc0403 left a comment


LGTM, just a tiny issue

Comment on lines -234 to +241

-    torch.cuda.synchronize()
+    if device == "cuda":
+        torch.cuda.synchronize()
+    elif device == "npu":
+        torch.npu.synchronize()


Collaborator


Great catch, we also have CPU support. Could you add it?

Comment on lines +160 to +165

if device == "cuda":
    torch.cuda.synchronize()
elif device == "npu":
    torch.npu.synchronize()
else:
    torch.cpu.synchronize()
Contributor Author


@Tcc0403 Thanks for the suggestion. torch provides a cpu equivalent, so I've added it here.

Collaborator


Sorry typo, meant to be xpu not cpu 😅

Contributor Author


Got it. Added torch.xpu.synchronize(), please take another look.
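The dispatch settled on in this thread can be sketched as a small helper. This is an illustrative sketch, not the PR's exact code (the PR inlines the branches in the test file, and the helper name is hypothetical):

```python
import torch


def synchronize(device: str) -> None:
    """Block until pending kernels on the given device type have finished."""
    if device == "cuda":
        torch.cuda.synchronize()
    elif device == "npu":
        # torch.npu is provided by the torch_npu plugin on Ascend builds
        torch.npu.synchronize()
    elif device == "xpu":
        torch.xpu.synchronize()
    elif hasattr(torch, "cpu") and hasattr(torch.cpu, "synchronize"):
        # CPU ops are eager/synchronous; this branch is effectively a no-op
        # guard, present only on torch versions that expose torch.cpu.synchronize
        torch.cpu.synchronize()
```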

Comment on lines +28 to +47
def compute_routing_metadata(topk_indices: torch.Tensor, E: int, block_m_token: int = BLOCK_M_TOKEN):
    """Compute token→expert routing permutation metadata via 3 Triton kernels.

    Also computes GPU tile metadata (tile_row_start, tile_expert) inside
    Kernel 3 — no CPU loop, one .item() sync for num_m_tiles allocation.

    Args:
        topk_indices: (T, K) int32 — pre-computed top-k expert indices per token
        E: number of experts
        block_m_token: BLOCK_M for token-dimension tiling (default BLOCK_M_TOKEN)

    Returns:
        expert_token_count: (E,) int32
        expert_start_idx: (E+1,) int32
        x_gather_idx: (TK,) int32
        s_scatter_idx: (TK,) int32
        s_reverse_scatter_idx: (TK,) int32
        tile_row_start: (num_m_tiles,) int32 — absolute row_start per M-tile
        tile_expert: (num_m_tiles,) int32 — expert index per M-tile
    """
Contributor Author

@zheliuyu zheliuyu Apr 23, 2026


Supplementary testing

Adding the Fused MoE kernel also involved tweaking the SwiGLU path; test_swiglu.py currently fails on NPU, and this PR's NPU Fused MoE changes fix it.

Before: [screenshot]

After: [screenshot]

Collaborator

@Tcc0403 Tcc0403 left a comment


LGTM

@Tcc0403 Tcc0403 added this pull request to the merge queue Apr 24, 2026
Merged via the queue into linkedin:main with commit e4831e4 Apr 24, 2026
5 of 7 checks passed
