
Commit 1faa013

MoE Triton Kernel (#1179)
## Summary

Triton implementation of a fused MoE expert kernel for training, inspired by SonicMoE: https://arxiv.org/abs/2512.14080

Related to: #958

## Details

**Known limitation:** The current implementation scales well with the number of experts (more parallelism) but not with the number of tokens. For models with few experts (e.g., Mixtral-8x7B) at large token counts, the eager per-expert loop is likely faster.

**Convergence:** The convergence test uses wider tolerances for MoE models patched with this kernel.

## Benchmarking Results

- Model: Qwen3-30B-A3B
- #Tokens: 8192 | #Experts: 128 | top-K: 8 | hidden: 2048 | intermediate: 768

| Sweep Dim | #Tokens | #Experts |
| :--- | --- | --- |
| Memory | <img width="1000" height="600" alt="fused_moe_memory_full_token_length" src="https://github.com/user-attachments/assets/5c55e185-ff15-4d1a-9aab-c0c315f94ea0" /> | <img width="1000" height="600" alt="fused_moe_memory_full_token_length" src="https://github.com/user-attachments/assets/f274f84f-8a7e-4c67-900f-1b96ca5b80d0" /> |
| Speed (Full) | <img width="1000" height="600" alt="fused_moe_speed_full_token_length" src="https://github.com/user-attachments/assets/4350afdc-2820-4748-bd0c-7d05812ddc60" /> | <img width="1000" height="600" alt="fused_moe_speed_full_token_length" src="https://github.com/user-attachments/assets/d537c312-af6d-407f-9271-962b54284cd1" /> |
| Speed (Forward) | <img width="1000" height="600" alt="fused_moe_speed_forward_token_length" src="https://github.com/user-attachments/assets/8256910e-1910-4117-adcf-44cd91e6125a" /> | <img width="1000" height="600" alt="fused_moe_speed_forward_token_length" src="https://github.com/user-attachments/assets/4ece35dd-215d-4c86-a6b6-f92d60c8c22c" /> |
| Speed (Backward) | <img width="1000" height="600" alt="fused_moe_speed_backward_token_length" src="https://github.com/user-attachments/assets/2a12f652-9424-48be-bce8-b5e8c1a8289d" /> | <img width="1000" height="600" alt="fused_moe_speed_backward_token_length" src="https://github.com/user-attachments/assets/69e308bb-51e0-4ec4-bb08-e6cf3f87c0ce" /> |

### Benchmarking with patching into Qwen3-30B-A3B

- Parameters
  - Max sequence length: 32768
  - Per-device batch size: 1
  - Gradient accumulation steps: 8
  - GPU: 2 * H100_8
- Results
  - Processes **8.24x** more tokens per second
  - Training speedup: 8.19x
  - Slightly lower memory: -1%
  - Eval speedup: 4.1x

Checklist:

- Hardware Type: H100
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence
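For reference, a minimal usage sketch of the new autograd entry point. Shapes and the router construction follow the benchmark's input generator (`_make_moe_inputs` in the new benchmark script below); this illustrates the call signature only, not the model-patching path:

```python
import torch

from liger_kernel.ops.fused_moe import LigerFusedMoEFunction

# Dimensions as in the Qwen3-30B-A3B benchmark: T tokens, E experts,
# H hidden size, I expert intermediate size, top-K routing.
T, E, H, I, K = 8192, 128, 2048, 768, 8
dtype, device = torch.bfloat16, "cuda"

x = torch.randn(T, H, dtype=dtype, device=device, requires_grad=True)
gate_up_proj = torch.randn(E, 2 * I, H, dtype=dtype, device=device).mul_(0.02).requires_grad_()
down_proj = torch.randn(E, H, I, dtype=dtype, device=device).mul_(0.02).requires_grad_()

# Router outputs: top-K expert ids (int32) and softmax-normalized routing weights.
logits = torch.randn(T, E, device=device)
top_k_index = torch.topk(logits, K, dim=-1).indices.to(torch.int32)
top_k_weights = torch.softmax(torch.gather(logits, 1, top_k_index.long()), dim=-1).to(dtype).requires_grad_()

out = LigerFusedMoEFunction.apply(x, gate_up_proj, down_proj, top_k_index, top_k_weights)  # (T, H)
out.sum().backward()  # gradients flow to x, both expert weight tensors, and the routing weights
```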
1 parent 76a0821 commit 1faa013

12 files changed

Lines changed: 2102 additions & 51 deletions


benchmark/data/all_benchmark_data.csv

Lines changed: 168 additions & 0 deletions
Large diffs are not rendered by default.
benchmark/scripts/benchmark_fused_moe.py

Lines changed: 295 additions & 0 deletions
```python
"""
Benchmark: LigerFusedMoEFunction vs. HuggingFace Python loop.

Integrates with the Liger benchmark framework (SingleBenchmarkRunInput/Output,
run_benchmarks, CSV output to all_benchmark_data.csv).

Usage:
    python benchmark_fused_moe.py                           # T sweep, Qwen3-MoE-30B
    python benchmark_fused_moe.py --sweep-dim num_experts
    python benchmark_fused_moe.py --overwrite
"""

import argparse
import math

import torch
import torch.nn as nn

from benchmark_model_configs import QWEN3_MOE_30B
from benchmark_model_configs import estimate_kernel_peak_memory
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import run_benchmarks
from utils import run_memory_benchmark
from utils import run_speed_benchmark

from liger_kernel.ops.fused_moe import LigerFusedMoEFunction
from liger_kernel.utils import get_total_gpu_memory
from liger_kernel.utils import infer_device

device = infer_device()


# ---------------------------------------------------------------------------
# HuggingFace reference: Python loop per expert
# Matches transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock
# ---------------------------------------------------------------------------


def _huggingface_moe_forward(x, gate_up_proj, down_proj, top_k_index, top_k_weights):
    T, H = x.shape
    E = gate_up_proj.shape[0]
    final = torch.zeros_like(x)
    with torch.no_grad():
        expert_mask = torch.nn.functional.one_hot(top_k_index.long(), num_classes=E)
        expert_mask = expert_mask.permute(2, 1, 0)
        expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
    for eh in expert_hit:
        eidx = eh[0]
        top_k_pos, token_idx = torch.where(expert_mask[eidx])
        curr = x[token_idx]
        gate, up = nn.functional.linear(curr, gate_up_proj[eidx]).chunk(2, dim=-1)
        curr = nn.functional.silu(gate) * up
        curr = nn.functional.linear(curr, down_proj[eidx])
        curr = curr * top_k_weights[token_idx, top_k_pos, None]
        final.index_add_(0, token_idx, curr.to(final.dtype))
    return final


# Expert counts used in the num_experts sweep (independent of model).
EXPERT_SWEEP_VALUES = [8, 16, 32, 64, 128]


# ---------------------------------------------------------------------------
# Input generation
# ---------------------------------------------------------------------------


def _make_moe_inputs(T, E, H, intermediate_dim, K, dtype, requires_grad=True):
    torch.manual_seed(42)
    x = torch.randn(T, H, dtype=dtype, device=device, requires_grad=requires_grad)
    gate_up_proj = (
        torch.randn(E, 2 * intermediate_dim, H, dtype=dtype, device=device, requires_grad=requires_grad) * 0.02
    )
    down_proj = torch.randn(E, H, intermediate_dim, dtype=dtype, device=device, requires_grad=requires_grad) * 0.02
    logits = torch.randn(T, E, device=device)
    top_k_index = torch.topk(logits, K, dim=-1).indices.to(torch.int32)
    top_k_weights = (
        torch.softmax(torch.gather(logits, 1, top_k_index.long()), dim=-1).to(dtype).requires_grad_(requires_grad)
    )
    return x, gate_up_proj, down_proj, top_k_index, top_k_weights


# ---------------------------------------------------------------------------
# Framework-integrated benchmark functions
# ---------------------------------------------------------------------------


def _setup_fused_moe(input: SingleBenchmarkRunInput):
    """Return (fwd_fn, grad_tensors) for the given provider and config.

    extra_benchmark_config keys:
        sweep_dim              : "T" or "E" — which dim input.x varies
        T, E                   : fixed values for the dimension not being swept (None when swept)
        H, intermediate_dim, K : model dimensions
        dtype                  : torch.dtype
    """
    cfg = input.extra_benchmark_config
    T = int(input.x) if cfg["sweep_dim"] == "T" else cfg["T"]
    E = int(input.x) if cfg["sweep_dim"] == "E" else cfg["E"]
    H, intermediate_dim, K = cfg["H"], cfg["intermediate_dim"], cfg["K"]
    dtype = cfg["dtype"]

    x, gup, dn, idx, wts = _make_moe_inputs(T, E, H, intermediate_dim, K, dtype, requires_grad=True)

    if input.kernel_provider == "liger":

        def fwd_fn():
            return LigerFusedMoEFunction.apply(x, gup, dn, idx, wts)
    elif input.kernel_provider == "huggingface":

        def fwd_fn():
            return _huggingface_moe_forward(x, gup, dn, idx, wts)
    else:
        raise ValueError(f"Unknown provider: {input.kernel_provider}")

    return fwd_fn, [x, gup, dn, wts]


def bench_speed_fused_moe(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    fwd_fn, grad_tensors = _setup_fused_moe(input)
    return run_speed_benchmark(fwd_fn, input.kernel_operation_mode, grad_tensors)


def bench_memory_fused_moe(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    fwd_fn, _ = _setup_fused_moe(input)
    return run_memory_benchmark(fwd_fn, input.kernel_operation_mode)


# ---------------------------------------------------------------------------
# Autotune warmup
# ---------------------------------------------------------------------------


def _warmup_liger(T, E, H, intermediate_dim, K, dtype, sweep_dim):
    """Run one full fwd+bwd to exhaust Triton autotune for (H, intermediate_dim).

    Triton autotune key is (H_dim, I_dim), so a single call is sufficient to
    cache the best config for all subsequent calls with the same H and intermediate_dim.
    For the num_experts sweep we also call this once per E value to warm up
    CUDA caches for each expert count before do_bench starts timing.
    """
    warmup_input = SingleBenchmarkRunInput(
        x=T if sweep_dim == "T" else E,
        kernel_provider="liger",
        extra_benchmark_config={
            "sweep_dim": sweep_dim,
            "T": T,
            "E": E,
            "H": H,
            "intermediate_dim": intermediate_dim,
            "K": K,
            "dtype": dtype,
        },
    )
    warmup_fn, _ = _setup_fused_moe(warmup_input)
    warmup_out = warmup_fn()
    warmup_out.sum().backward()
    del warmup_out
    torch.cuda.synchronize()


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Benchmark LigerFusedMoEFunction")
    parser.add_argument(
        "--overwrite",
        action="store_true",
        help="Overwrite existing CSV benchmark data",
    )
    parser.add_argument(
        "--sweep-dim",
        choices=["num_tokens", "num_experts"],
        default="num_tokens",
        help="Dimension to sweep (default: num_tokens)",
    )
    args = parser.parse_args()

    moe_cfg = QWEN3_MOE_30B
    E = moe_cfg.E
    H = moe_cfg.H
    intermediate_dim = moe_cfg.intermediate_dim
    K = moe_cfg.K
    probe_T = moe_cfg.T  # representative token count for probing and warmup
    dtype = torch.bfloat16

    print(
        f"Model: {moe_cfg.name} — E={E}, H={H}, intermediate_dim={intermediate_dim}, K={K}, "
        f"T_base={probe_T}, dtype={dtype}"
    )

    # Memory probe using huggingface (no Triton, higher footprint = safe upper bound).
    def _probe():
        probe_input = SingleBenchmarkRunInput(
            x=probe_T,
            kernel_provider="huggingface",
            extra_benchmark_config={
                "sweep_dim": "T",
                "T": None,
                "E": E,
                "H": H,
                "intermediate_dim": intermediate_dim,
                "K": K,
                "dtype": dtype,
            },
        )
        fwd_fn, _ = _setup_fused_moe(probe_input)
        return fwd_fn()

    peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe)
    kernel_bpt = peak_bytes // probe_T

    # Pre-warm Liger's Triton autotune before benchmarks start.
    #
    # Autotune key is (H_dim, I_dim) — one warmup per (H, intermediate_dim) pair is sufficient
    # to cache the best config for the entire sweep.
    #
    # For num_tokens sweep: one pass with the model's base T is enough.
    # For num_experts sweep: one pass per E value in EXPERT_SWEEP_VALUES to also
    # warm up CUDA caches for each expert count, since weight tensor sizes differ.
    print(f"Pre-warming Liger autotune (H={H}, intermediate_dim={intermediate_dim})...")

    if args.sweep_dim == "num_tokens":
        _warmup_liger(probe_T, E, H, intermediate_dim, K, dtype, sweep_dim="T")
    else:  # num_experts
        for e_val in EXPERT_SWEEP_VALUES:
            print(f" warmup E={e_val}...")
            _warmup_liger(probe_T, e_val, H, intermediate_dim, K, dtype, sweep_dim="E")

    torch.cuda.synchronize()
    print("Autotune warmup complete.\n")

    if args.sweep_dim == "num_tokens":
        # Derive a memory-safe upper bound for T from the probe measurement.
        # Target 40% GPU memory utilisation to leave headroom for framework overhead.
        usable_bytes = get_total_gpu_memory() * (1024**3) * 0.4
        max_T = min(32768, max(256, int(usable_bytes / kernel_bpt)))
        # Round down to nearest power-of-two for clean x-axis values.
        max_T = 2 ** int(math.log2(max_T)) if max_T >= 256 else 256
        x_values = [2**i for i in range(7, int(math.log2(max_T)) + 1)]
        extra_configs = [
            {
                "sweep_dim": "T",
                "T": None,  # varied by framework
                "E": E,
                "H": H,
                "intermediate_dim": intermediate_dim,
                "K": K,
                "dtype": dtype,
            }
        ]
        x_name, x_label = "T", "num_tokens"
    else:  # num_experts
        x_values = EXPERT_SWEEP_VALUES
        extra_configs = [
            {
                "sweep_dim": "E",
                "T": probe_T,  # fixed at model's base token count
                "E": None,  # varied by framework
                "H": H,
                "intermediate_dim": intermediate_dim,
                "K": K,
                "dtype": dtype,
            }
        ]
        x_name, x_label = "E", "num_experts"

    common_configs = {
        "kernel_name": "fused_moe",
        "x_name": x_name,
        "x_label": x_label,
        "x_values": x_values,
        "kernel_providers": ["liger", "huggingface"],
        "extra_benchmark_configs": extra_configs,
        "overwrite": args.overwrite,
    }

    run_benchmarks(
        bench_test_fn=bench_speed_fused_moe,
        kernel_operation_modes=["full", "forward", "backward"],
        metric_name="speed",
        metric_unit="ms",
        **common_configs,
    )
    run_benchmarks(
        bench_test_fn=bench_memory_fused_moe,
        kernel_operation_modes=["full", "forward", "backward"],
        metric_name="memory",
        metric_unit="MB",
        **common_configs,
    )
```
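To make the token-sweep sizing in `__main__` concrete, here is a worked instance of the `max_T` arithmetic with hypothetical numbers (an 80 GB device and an assumed probe peak; neither figure is measured in this PR):

```python
import math

# Assumption: the HuggingFace probe forward peaks at 12 GiB for probe_T = 8192 tokens.
peak_bytes = 12 * 1024**3
probe_T = 8192
kernel_bpt = peak_bytes // probe_T  # 1.5 MiB per token (safe upper bound)

usable_bytes = 80 * 1024**3 * 0.4  # 40% of an 80 GB GPU = 32 GiB
max_T = min(32768, max(256, int(usable_bytes / kernel_bpt)))  # 21845
max_T = 2 ** int(math.log2(max_T))  # round down to a power of two: 16384
x_values = [2**i for i in range(7, int(math.log2(max_T)) + 1)]
print(max_T, x_values)  # 16384 [128, 256, ..., 16384]
```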

benchmark/scripts/benchmark_model_configs.py

Lines changed: 29 additions & 1 deletion
```diff
@@ -91,7 +91,35 @@ class ModelConfigSweepConfig:
     seq_len: int
 
 
-# ── Model Profiles ──────────────────────────────────────────────────────────
+@dataclass(frozen=True)
+class MoEModelConfig:
+    """MoE model architecture profile for fused MoE benchmarks.
+
+    EP-adjusted values should be baked in: T = total_tokens / ep_size,
+    E = total_experts / ep_size.
+    """
+
+    name: str
+    T: int  # tokens per GPU (EP-adjusted)
+    E: int  # experts per GPU (EP-adjusted)
+    H: int  # hidden size
+    intermediate_dim: int  # expert intermediate size
+    K: int  # top-k
+
+
+# ── MoE Model Profiles ───────────────────────────────────────────────────────
+
+QWEN3_MOE_30B = MoEModelConfig(
+    name="qwen3_moe_30b",
+    T=8192,
+    E=128,
+    H=2048,
+    intermediate_dim=768,
+    K=8,
+)
+
+
+# ── Dense Model Profiles ─────────────────────────────────────────────────────
 
 LLAMA_2_7B = ModelConfig(
     name="llama_2_7b",
```

src/liger_kernel/ops/__init__.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -47,6 +47,7 @@
 from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction  # noqa: F401
 from liger_kernel.ops.fused_linear_jsd import fused_linear_jsd_backward  # noqa: F401
 from liger_kernel.ops.fused_linear_jsd import fused_linear_jsd_forward  # noqa: F401
+from liger_kernel.ops.fused_moe import LigerFusedMoEFunction  # noqa: F401
 from liger_kernel.ops.fused_neighborhood_attention import LigerFusedNeighborhoodAttentionFunction  # noqa: F401
 from liger_kernel.ops.geglu import LigerGELUMulFunction  # noqa: F401
 from liger_kernel.ops.geglu import geglu_backward  # noqa: F401
```
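As with the neighboring `# noqa: F401` re-exports, this should make the function available from the flat package namespace as well (a usage note inferred from the re-export pattern, not part of the diff):

```python
# Both import paths resolve to the same autograd function after this change:
from liger_kernel.ops import LigerFusedMoEFunction
```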
