-
Notifications
You must be signed in to change notification settings - Fork 131
Expand file tree
/
Copy pathMoE.py
More file actions
467 lines (383 loc) · 14.5 KB
/
MoE.py
File metadata and controls
467 lines (383 loc) · 14.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
# SPDX-FileCopyrightText: Copyright (c) <2025> NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
import argparse
import math
import torch
import torch.nn.functional as F
import cuda.tile as ct
# Short aliases for ct.Constant-annotated kernel parameters; used below for the
# tile sizes (TILE_M/TILE_N/TILE_K) and the `mul_routed_weight` flag.
# NOTE(review): presumably ct.Constant marks values fixed at kernel compile
# time -- confirm against the cuTile documentation.
ConstInt = ct.Constant[int]
ConstBool = ct.Constant[bool]
@ct.kernel
def fused_moe_kernel(
    A,
    B,
    C,
    topk_weights,
    sorted_token_ids,
    sorted_expert_ids,
    num_token_replicas: int,
    mul_routed_weight: ConstBool,
    TILE_M: ConstInt,
    TILE_N: ConstInt,
    TILE_K: ConstInt,
):
    """
    Fused MoE kernel that multiplies tokens by their assigned expert weights.

    Each block computes one (TILE_M, TILE_N) output tile: it gathers TILE_M rows
    of A selected through `sorted_token_ids`, multiplies them with the weight
    slab of the single expert that owns the tile, optionally scales each row by
    its router weight, and scatters the result into C.

    Args:
        A: Input tokens, shape (batch, K).
        B: Expert weights, shape (num_experts, N, K).
        C: Output tensor, shape (num_tokens * topk, N).
        topk_weights: Router weights for each token-expert pair, shape (num_tokens * topk,).
        sorted_token_ids: Token indices sorted by expert assignment, replicated topk times,
            and padded to align with TILE_M.
        sorted_expert_ids: Expert index for each TILE_M, sorted.
        num_token_replicas: Replication factor applied to each token row in A (topk or 1).
        mul_routed_weight: Whether to multiply output by router weights.
        TILE_M: Tile extent along the (padded) token dimension.
        TILE_N: Tile extent along the output-feature dimension N.
        TILE_K: Tile extent along the reduction dimension K.

    Token ids are sorted and padded to ensure each expert processes a multiple of TILE_M tokens,
    enabling efficient tiled matrix multiplication.

    NOTE(review): `moe_align_tile_size_torch` fills padding slots with the
    sentinel index num_tokens * topk, which is one past the last row of C and
    the last entry of `topk_weights`. This kernel therefore appears to rely on
    ct.gather / ct.scatter treating out-of-range indices safely -- confirm
    against the cuTile documentation.
    """
    # Problem extents: M is the padded replica count, N and K come from B.
    M = sorted_token_ids.shape[0]
    N = B.shape[1]
    K = B.shape[2]
    # Number of row-tiles grouped together when mapping the 1-D block id to a
    # 2-D (bid_m, bid_n) pair (see swizzle_2d).
    GROUP_SIZE_M = 8
    bid_m, bid_n = swizzle_2d(M, N, TILE_M, TILE_N, GROUP_SIZE_M)
    zero_pad = ct.PaddingMode.ZERO
    # Gather replicated/padded token indices handled by this block pair (bid_m, bid_n).
    token_id_indices = bid_m * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
    token_ids = ct.gather(sorted_token_ids, token_id_indices)
    # Collapse the replica dimension to recover the source row in A for each entry.
    a_row_indices = token_ids // num_token_replicas
    # Each TILE_M block is homogenous in expert assignment; fetch the expert id once.
    expert_id = ct.load(sorted_expert_ids, index=bid_m, shape=())
    # Accumulate in float32 regardless of the input dtype; cast back before the store.
    accumulator = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
    for k in range(0, ct.cdiv(K, TILE_K)):
        # Rows of A are gathered by index because they are not contiguous
        # after sorting by expert.
        a_col_indices = k * TILE_K + ct.arange(TILE_K, dtype=ct.int32)
        a = ct.gather(A, (a_row_indices[:, None], a_col_indices[None, :]))
        # NOTE(review): presumably order=(0, 2, 1) reads the expert's (N, K)
        # slab transposed so the tile comes out as (TILE_K, TILE_N) -- confirm
        # against the ct.load documentation.
        b = ct.load(B, (expert_id, k, bid_n), shape=(1, TILE_K, TILE_N),
                    order=(0, 2, 1), padding_mode=zero_pad).reshape((TILE_K, TILE_N))
        accumulator = ct.mma(a, b, accumulator)
    if mul_routed_weight:
        # Scale each output row by the router weight of its token-expert pair.
        moe_weight = ct.gather(topk_weights, token_ids)
        accumulator = accumulator * moe_weight[:, None]
    # Compute the column span this block covers and scatter the tile back into C.
    c_col_indices = bid_n * TILE_N + ct.arange(TILE_N, dtype=ct.int32)
    accumulator = ct.astype(accumulator, C.dtype)
    ct.scatter(C, (token_ids[:, None], c_col_indices[None, :]), accumulator)
@ct.kernel
def silu_and_mul_kernel(A, B, C, TILE_N: ConstInt):
    """
    Row-wise element-wise kernel computing C = SiLU(A) * B.

    Each block handles the row selected by its block id along grid axis 0 and
    loads a single (1, TILE_N) tile from A and from B. The arithmetic is done
    in float32 and the result is cast to C's dtype on store.

    Args:
        A: Input tensor passed through SiLU.
        B: Input tensor multiplied with SiLU(A).
        C: Output tensor.
        TILE_N: Width of the tile loaded from each row.
    """
    row = ct.bid(0)
    x = ct.load(A, (row, 0), (1, TILE_N)).astype(ct.float32)
    y = ct.load(B, (row, 0), (1, TILE_N)).astype(ct.float32)
    # sigmoid(x) = 1 / (1 + exp(-x)), using an approximate division.
    one_plus_exp = ct.add(1, ct.exp(-x), flush_to_zero=True)
    sigmoid_x = ct.truediv(
        1.0,
        one_plus_exp,
        flush_to_zero=True,
        rounding_mode=ct.RoundingMode.APPROX,
    )
    # SiLU(x) = x * sigmoid(x); then gate the second operand with it.
    silu_x = ct.mul(x, sigmoid_x, flush_to_zero=True)
    gated = ct.mul(silu_x, y, flush_to_zero=True)
    ct.store(C, (row, 0), gated.astype(C.dtype))
def moe_align_tile_size_torch(
    topk_ids: torch.Tensor, tile_m: int, num_experts: int
):
    """
    Sort, replicate, and pad token indices by expert so every expert processes a
    TILE_M-aligned tile when launching the fused_moe_kernel.

    The layout is built with vectorised tensor operations and a single
    host read-back (the total block count needed to size the outputs), instead
    of a Python loop that slices with per-expert device scalars.

    Args:
        topk_ids: Router-selected expert ids per token (num_tokens, topk).
        tile_m: Tile size used along the M dimension by the kernel.
        num_experts: Total number of experts present in w1/w2 tensors.

    Returns:
        sorted_token_ids: 1-D int32 tensor containing the flattened token-replica
            indices sorted by expert; remaining slots are filled with a sentinel
            index (num_tokens * topk) for padding.
        sorted_expert_ids: 1-D int32 tensor giving, for each block, the expert id
            that owns the corresponding TILE_M slice in `sorted_token_ids`.

    Raises:
        ValueError: If `tile_m` is not positive, or if `topk_ids` contains an
            expert id that is >= `num_experts`.
    """
    if tile_m <= 0:
        raise ValueError(f"tile_m must be a positive integer, got {tile_m}")

    device = topk_ids.device
    num_tokens, topk = topk_ids.shape
    total_tokens = num_tokens * topk

    # Flatten expert ids (num_tokens * topk) and sort by experts. The stable
    # sort keeps replicas of the same expert in their original order.
    flat_expert_ids = topk_ids.reshape(-1)
    sorted_token_indices = torch.argsort(flat_expert_ids, stable=True)

    # Determine how many replicas each expert owns and how many TILE_M blocks
    # are needed once padded to TILE_M alignment.
    expert_token_counts = torch.bincount(flat_expert_ids, minlength=num_experts)
    # bincount grows past `minlength` only when an id >= num_experts is present;
    # reject that instead of silently mislabelling the trailing blocks.
    if expert_token_counts.numel() > num_experts:
        raise ValueError(
            f"topk_ids contains expert ids outside [0, {num_experts})"
        )
    expert_block_counts = (expert_token_counts + tile_m - 1) // tile_m
    # The only host read-back: the output sizes must be known on the host.
    total_blocks = int(expert_block_counts.sum())

    # Allocate the token buffer pre-filled with the sentinel (total_tokens).
    sorted_token_ids = torch.full(
        (total_blocks * tile_m,), total_tokens, device=device, dtype=torch.int32
    )
    # Map each TILE_M block to its owning expert id.
    sorted_expert_ids = torch.repeat_interleave(
        torch.arange(num_experts, device=device, dtype=torch.int32),
        expert_block_counts,
        output_size=total_blocks,
    )

    # Exclusive prefix sums: where each expert starts in the padded layout and
    # in the densely sorted layout.
    padded_counts = expert_block_counts * tile_m
    padded_starts = torch.cumsum(padded_counts, dim=0) - padded_counts
    token_starts = torch.cumsum(expert_token_counts, dim=0) - expert_token_counts

    # Every sorted replica moves right by the padding accumulated before its
    # expert; residual slots keep the sentinel value.
    sorted_experts = flat_expert_ids[sorted_token_indices].long()
    destinations = (
        torch.arange(total_tokens, device=device)
        + (padded_starts - token_starts)[sorted_experts]
    )
    sorted_token_ids[destinations] = sorted_token_indices.to(torch.int32)

    return sorted_token_ids, sorted_expert_ids
def swizzle_2d_from_bid(M, N, tm, tn, GROUP_SIZE_M, bid):
    # Map a block id from a 1D grid onto 2D tile coordinates (bid_m, bid_n).
    # Row-tiles are taken in groups of up to GROUP_SIZE_M; within a group the
    # 1D id walks down the rows of the group first, then across column-tiles.
    tiles_m = ct.cdiv(M, tm)
    tiles_n = ct.cdiv(N, tn)
    bids_per_group = GROUP_SIZE_M * tiles_n
    group_first_m = (bid // bids_per_group) * GROUP_SIZE_M
    # The last group may hold fewer than GROUP_SIZE_M row-tiles.
    rows_in_group = min(tiles_m - group_first_m, GROUP_SIZE_M)
    tile_m_index = group_first_m + (bid % rows_in_group)
    tile_n_index = (bid % bids_per_group) // rows_in_group
    return tile_m_index, tile_n_index
def swizzle_2d(M, N, tm, tn, GROUP_SIZE_M):
    # Resolve the 2D tile coordinates (bid_m, bid_n) of the current block,
    # identified by its id along axis 0 of the 1D launch grid.
    return swizzle_2d_from_bid(M, N, tm, tn, GROUP_SIZE_M, ct.bid(0))
# --- cuTile MoE Wrapper ------------------------------------------------------
def cutile_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    tile_m: int,
    tile_n: int,
    tile_k: int,
) -> torch.Tensor:
    """
    Run a Mixture-of-Experts (MoE) forward pass with the fused cuTile kernels.

    The pass has three stages: the gate+up projection (`w1`), SiLU-and-multiply,
    and the router-weighted down projection (`w2`). The per-replica outputs are
    then summed over the topk dimension.

    Args:
        hidden_states: Token activations, shape (num_tokens, hidden_size)
        w1: Expert gate+up projection weights,
            shape (num_experts, intermediate_size * 2, hidden_size)
        w2: Expert down projection weights,
            shape (num_experts, hidden_size, intermediate_size)
        topk_weights: Router weights per token, shape (num_tokens, topk)
        topk_ids: Expert indices per token, shape (num_tokens, topk)
        tile_m/n/k: Tile sizes for cuTile kernel launch

    Returns:
        Tensor with the same shape/dtype as `hidden_states`.

    Raises:
        ValueError: If `w1.shape[1]` is not twice the intermediate size of `w2`.
    """
    num_tokens, hidden_size = hidden_states.shape
    num_experts, _, intermediate_size = w2.shape
    _, topk = topk_ids.shape

    gate_up_size = intermediate_size * 2
    if w1.shape[1] != gate_up_size:
        raise ValueError("w1 must have 2 * intermediate_size rows (gate + up projection)")

    # All scratch buffers share the activations' device and dtype.
    buffer_options = {"device": hidden_states.device, "dtype": hidden_states.dtype}
    gate_up_out = torch.zeros((num_tokens, topk, gate_up_size), **buffer_options)
    activated = torch.zeros((num_tokens * topk, intermediate_size), **buffer_options)
    down_out = torch.zeros((num_tokens, topk, hidden_size), **buffer_options)

    sorted_token_ids, sorted_expert_ids = moe_align_tile_size_torch(
        topk_ids, tile_m, num_experts
    )

    # Stage 1: gate+up projection. Each row of `hidden_states` serves `topk`
    # replicas, so replica ids are divided by topk to find the source row.
    invoke_fused_moe_kernel(
        hidden_states,
        w1,
        gate_up_out,
        topk_weights,
        sorted_token_ids,
        sorted_expert_ids,
        mul_routed_weight=False,
        num_token_replicas=topk,
        tile_m=tile_m,
        tile_n=tile_n,
        tile_k=tile_k,
    )

    # Stage 2: SiLU(gate) * up on the flattened (num_tokens * topk) rows.
    invoke_silu_and_mul_kernel(
        gate_up_out.view(-1, gate_up_out.shape[-1]),
        activated,
    )

    # Stage 3: down projection. The input already has one row per replica, and
    # the router weights are applied here.
    invoke_fused_moe_kernel(
        activated,
        w2,
        down_out,
        topk_weights,
        sorted_token_ids,
        sorted_expert_ids,
        mul_routed_weight=True,
        num_token_replicas=1,
        tile_m=tile_m,
        tile_n=tile_n,
        tile_k=tile_k,
    )

    # Combine the topk expert contributions of every token.
    return down_out.sum(dim=1)
# --- Torch Reference Implementation -----------------------------------------
def torch_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
) -> torch.Tensor:
    """
    Naive torch implementation of MoE for correctness checks.

    For every expert that received at least one token, the matched tokens go
    through the gate/up projections (the two halves of `w1` along dim 1),
    SiLU-gating, and the down projection (`w2`). The result is scaled by the
    router weight and accumulated into the output row of the owning token.

    Returns:
        Tensor with the same shape and dtype as `hidden_states`.
    """
    expert_count = w1.shape[0]
    gate_weights, up_weights = torch.chunk(w1, 2, dim=1)

    # One-hot assignment laid out as (num_experts, topk, num_tokens).
    assignment = F.one_hot(topk_ids, num_classes=expert_count).permute(2, 1, 0)
    has_tokens = assignment.sum(dim=(-1, -2)) > 0

    output = torch.zeros_like(hidden_states)
    for expert in has_tokens.nonzero().squeeze(-1):
        # Positions (topk slot, token) routed to this expert.
        slot_ids, token_ids = torch.where(assignment[expert])
        tokens = hidden_states[token_ids]

        gate_out = tokens @ gate_weights[expert].T
        up_out = tokens @ up_weights[expert].T
        projected = (F.silu(gate_out) * up_out) @ w2[expert].T

        scaled = projected * topk_weights[token_ids, slot_ids].unsqueeze(-1)
        output.index_add_(0, token_ids, scaled.to(hidden_states.dtype))
    return output
# --- Helper Utilities -------------------------------------------------------
def invoke_fused_moe_kernel(
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    topk_weights: torch.Tensor,
    sorted_token_ids: torch.Tensor,
    sorted_expert_ids: torch.Tensor,
    mul_routed_weight: bool,
    num_token_replicas: int,
    tile_m: int,
    tile_n: int,
    tile_k: int,
) -> None:
    """
    Launch `fused_moe_kernel` on the current CUDA stream.

    The launch grid is 1D with one block per (row-tile, column-tile) pair, where
    rows span the padded `sorted_token_ids` and columns span `B.shape[1]`.
    `topk_weights` is flattened and the 3-D `C` is viewed as 2-D with its last
    dimension kept, so the kernel writes into the caller's buffer.
    """
    row_tiles = math.ceil(sorted_token_ids.shape[0] / tile_m)
    col_tiles = math.ceil(B.shape[1] / tile_n)

    flat_weights = topk_weights.view(-1)
    out_2d = C.view(-1, C.shape[2])

    kernel_args = (
        A,
        B,
        out_2d,
        flat_weights,
        sorted_token_ids,
        sorted_expert_ids,
        num_token_replicas,
        mul_routed_weight,
        tile_m,
        tile_n,
        tile_k,
    )
    ct.launch(
        torch.cuda.current_stream(),
        (row_tiles * col_tiles,),
        fused_moe_kernel,
        kernel_args,
    )
def invoke_silu_and_mul_kernel(
    AB: torch.Tensor,
    C: torch.Tensor
):
    """
    Launch `silu_and_mul_kernel` on the current CUDA stream.

    `AB` is split into two halves along its last dimension. The first half goes
    through SiLU and is multiplied by the second, with the result written to `C`.
    One block is launched per row of `AB`, and the tile width is the width of
    `C` rounded up to a power of two.
    """
    silu_input, multiplier = torch.chunk(AB, 2, dim=-1)
    tile_n = next_power_of_2(C.shape[-1])
    ct.launch(
        torch.cuda.current_stream(),
        (AB.shape[0],),
        silu_and_mul_kernel,
        (silu_input, multiplier, C, tile_n),
    )
def next_power_of_2(n: int) -> int:
    """
    Return the smallest power of 2 greater than or equal to n.

    Args:
        n: Integer (or integer-like object supporting `__index__`).

    Returns:
        The smallest value 2**k (k >= 0) with 2**k >= n. Any n <= 1, including
        zero and negative values, yields 1.

    Raises:
        TypeError: If `n` is not an integer-like object.
    """
    # Function-scope import keeps this helper self-contained.
    import operator

    n = operator.index(n)
    if n <= 1:
        return 1
    # (n - 1).bit_length() is the number of bits needed to represent n - 1, so
    # shifting 1 by it gives the first power of two that is >= n. Unlike a
    # fixed ladder of shifts, this works for integers of any size.
    return 1 << (n - 1).bit_length()
if __name__ == "__main__":
    # Sample driver: builds random MoE inputs on the GPU, runs the cuTile MoE
    # forward pass, and optionally validates it against the torch reference.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--correctness-check",
        action="store_true",
        help="Check the correctness of the cuTile MoE output against a torch reference.",
    )
    args = parser.parse_args()

    print("--- Running cuTile Mixture-of-Experts (MoE) Sample ---")

    # Problem configuration.
    num_tokens = 48
    hidden_size = 512
    num_experts = 64
    intermediate_size = 1024
    topk = 8
    dtype = torch.bfloat16
    device = "cuda"

    print(
        f"Tokens: {num_tokens}, Hidden: {hidden_size}, "
        f"Experts: {num_experts}, Intermediate: {intermediate_size}, "
        f"TopK: {topk}, dtype: {dtype}"
    )

    # Random activations and expert weights drawn from zero-mean normals.
    hidden_states = torch.empty(
        num_tokens, hidden_size, device=device, dtype=dtype
    ).normal_(0, 0.5)
    w1 = torch.empty(
        num_experts, intermediate_size * 2, hidden_size, device=device, dtype=dtype
    ).normal_(0, 0.1)
    w2 = torch.empty(
        num_experts, hidden_size, intermediate_size, device=device, dtype=dtype
    ).normal_(0, 0.1)

    # Unique expert IDs for each token (no repeating elements per row)
    topk_ids = torch.stack([
        torch.randperm(num_experts, device=device)[:topk]
        for _ in range(num_tokens)
    ])
    # Router weights normalised per token with a softmax over the topk slots.
    topk_weights = torch.softmax(
        torch.randn(num_tokens, topk, device=device), dim=-1
    ).to(dtype)

    print("\n--- Executing cuTile MoE ---")
    output_cutile = cutile_moe(hidden_states, w1, w2, topk_weights, topk_ids,
                               tile_m=128, tile_n=128, tile_k=64)
    # Both halves of the message must be f-strings; otherwise the dtype
    # placeholder is printed literally instead of being interpolated.
    print(f"cuTile MoE output shape: {output_cutile.shape}, "
          f"dtype: {output_cutile.dtype}")

    if args.correctness_check:
        print("\n--- Running correctness check against torch reference ---")
        ref_output = torch_moe(hidden_states, w1, w2, topk_weights, topk_ids)
        # Loose tolerances: inputs and intermediates are bfloat16.
        torch.testing.assert_close(output_cutile, ref_output, rtol=1e-1, atol=1e-1)
        print("Correctness check passed")
    else:
        print("Correctness check disabled (use --correctness-check to enable)")

    print("\n--- cuTile Mixture-of-Experts (MoE) Sample complete ---")