Skip to content

Commit 3917cd8

Browse files
authored
perf: EXL3 performance tuning on GeForce Blackwell (#1652)
* perf: tune EXL3 GEMM selection for Blackwell
  Signed-off-by: AlpinDale <alpindale@gmail.com>
* perf: expand EXL3 Blackwell GEMM overrides
  Signed-off-by: AlpinDale <alpindale@gmail.com>
* perf: tune EXL3 MoE mgemm on Blackwell
  Signed-off-by: AlpinDale <alpindale@gmail.com>
* perf: scope EXL3 Blackwell tuning to MoE mgemm
  Signed-off-by: AlpinDale <alpindale@gmail.com>
* perf: tune EXL3 dense mgemm on Blackwell
  Signed-off-by: AlpinDale <alpindale@gmail.com>
* perf: reduce EXL3 MoE decode overhead
  Signed-off-by: AlpinDale <alpindale@gmail.com>
* perf: fuse EXL3 MoE gate up projection
  Signed-off-by: AlpinDale <alpindale@gmail.com>
---------
Signed-off-by: AlpinDale <alpindale@gmail.com>
1 parent 61aad7c commit 3917cd8

7 files changed

Lines changed: 383 additions & 84 deletions

File tree

aphrodite/_custom_ops.py

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -860,6 +860,18 @@ def silu_and_mul_per_block_quant(
860860
return output, scales
861861

862862

863+
def silu_mul(out: torch.Tensor, gate: torch.Tensor, up: torch.Tensor) -> None:
864+
torch.ops._C.silu_mul(out, gate, up)
865+
866+
867+
def make_gate_up_indices(
868+
out: torch.Tensor,
869+
indices: torch.Tensor,
870+
offset: int,
871+
) -> None:
872+
torch.ops._C.make_gate_up_indices(out, indices, offset)
873+
874+
863875
# quantization ops
864876
# awq
865877
def awq_dequantize(

aphrodite/model_executor/layers/quantization/exl3.py

Lines changed: 149 additions & 79 deletions
Original file line number | Diff line number | Diff line change
@@ -1054,11 +1054,6 @@ def apply(
10541054
else:
10551055
x_2d = x_2d.contiguous()
10561056

1057-
output = torch.zeros(
1058-
(x_2d.shape[0], layer.hidden_size),
1059-
device=x_2d.device,
1060-
dtype=torch.float32,
1061-
)
10621057
topk_ids = topk_ids.to(torch.long)
10631058
topk_weights = topk_weights.to(torch.float16)
10641059
total_assignments = x_2d.shape[0] * topk_ids.shape[-1]
@@ -1083,6 +1078,11 @@ def apply(
10831078
x.shape[:-1],
10841079
)
10851080

1081+
output = torch.zeros(
1082+
(x_2d.shape[0], layer.hidden_size),
1083+
device=x_2d.device,
1084+
dtype=torch.float32,
1085+
)
10861086
flat_expert = topk_ids.reshape(-1)
10871087
flat_weight = topk_weights.reshape(-1)
10881088
flat_token = torch.arange(x_2d.shape[0], device=x_2d.device)
@@ -1195,41 +1195,69 @@ def _apply_single_token(
11951195
) -> torch.Tensor:
11961196
x_3d = x_2d.unsqueeze(0)
11971197

1198-
ops.exl3_mgemm(
1199-
x_3d,
1200-
layer.exl3_gate_ptrs_trellis,
1198+
if layer.exl3_fuse_gate_up:
1199+
ops.make_gate_up_indices(
1200+
layer.exl3_small_gate_up_ids,
1201+
topk_ids,
1202+
layer.local_num_experts,
1203+
)
1204+
ops.exl3_mgemm(
1205+
x_3d,
1206+
layer.exl3_gate_up_ptrs_trellis,
1207+
layer.exl3_small_interm_gu,
1208+
layer.exl3_gate_up_ptrs_suh,
1209+
layer.exl3_small_yh_gu,
1210+
layer.exl3_gate_up_ptrs_svh,
1211+
layer.exl3_small_gate_up_ids,
1212+
None,
1213+
layer.exl3_moe_k_gate,
1214+
-1,
1215+
layer.exl3_gate_mcg,
1216+
layer.exl3_gate_mul1,
1217+
-1,
1218+
-1,
1219+
0,
1220+
)
1221+
else:
1222+
ops.exl3_mgemm(
1223+
x_3d,
1224+
layer.exl3_gate_ptrs_trellis,
1225+
layer.exl3_small_interm_g,
1226+
layer.exl3_gate_ptrs_suh,
1227+
layer.exl3_small_yh,
1228+
layer.exl3_gate_ptrs_svh,
1229+
topk_ids,
1230+
None,
1231+
layer.exl3_moe_k_gate,
1232+
-1,
1233+
layer.exl3_gate_mcg,
1234+
layer.exl3_gate_mul1,
1235+
-1,
1236+
-1,
1237+
0,
1238+
)
1239+
ops.exl3_mgemm(
1240+
x_3d,
1241+
layer.exl3_up_ptrs_trellis,
1242+
layer.exl3_small_interm_u,
1243+
layer.exl3_up_ptrs_suh,
1244+
layer.exl3_small_yh,
1245+
layer.exl3_up_ptrs_svh,
1246+
topk_ids,
1247+
None,
1248+
layer.exl3_moe_k_up,
1249+
-1,
1250+
layer.exl3_up_mcg,
1251+
layer.exl3_up_mul1,
1252+
-1,
1253+
-1,
1254+
0,
1255+
)
1256+
ops.silu_mul(
1257+
layer.exl3_small_interm_a,
12011258
layer.exl3_small_interm_g,
1202-
layer.exl3_gate_ptrs_suh,
1203-
layer.exl3_small_yh,
1204-
layer.exl3_gate_ptrs_svh,
1205-
topk_ids,
1206-
None,
1207-
layer.exl3_moe_k_gate,
1208-
-1,
1209-
layer.exl3_gate_mcg,
1210-
layer.exl3_gate_mul1,
1211-
-1,
1212-
-1,
1213-
0,
1214-
)
1215-
ops.exl3_mgemm(
1216-
x_3d,
1217-
layer.exl3_up_ptrs_trellis,
12181259
layer.exl3_small_interm_u,
1219-
layer.exl3_up_ptrs_suh,
1220-
layer.exl3_small_yh,
1221-
layer.exl3_up_ptrs_svh,
1222-
topk_ids,
1223-
None,
1224-
layer.exl3_moe_k_up,
1225-
-1,
1226-
layer.exl3_up_mcg,
1227-
layer.exl3_up_mul1,
1228-
-1,
1229-
-1,
1230-
0,
12311260
)
1232-
layer.exl3_small_interm_a.copy_(torch.nn.functional.silu(layer.exl3_small_interm_g) * layer.exl3_small_interm_u)
12331261
ops.exl3_mgemm(
12341262
layer.exl3_small_interm_a,
12351263
layer.exl3_down_ptrs_trellis,
@@ -1261,52 +1289,74 @@ def _apply_small_batch(
12611289
original_dtype: torch.dtype,
12621290
original_shape: tuple[int, ...],
12631291
) -> torch.Tensor:
1264-
output = torch.empty(
1265-
(x_2d.shape[0], layer.hidden_size),
1266-
device=x_2d.device,
1267-
dtype=torch.float32,
1268-
)
1292+
output = layer.exl3_small_batch_out[: x_2d.shape[0]]
12691293
x_3d = x_2d.unsqueeze(1).unsqueeze(1)
12701294
topk_ids_3d = topk_ids.unsqueeze(1)
12711295
topk_weights_3d = topk_weights.unsqueeze(1)
12721296

12731297
for i in range(x_2d.shape[0]):
1274-
ops.exl3_mgemm(
1275-
x_3d[i],
1276-
layer.exl3_gate_ptrs_trellis,
1298+
if layer.exl3_fuse_gate_up:
1299+
ops.make_gate_up_indices(
1300+
layer.exl3_small_gate_up_ids,
1301+
topk_ids_3d[i],
1302+
layer.local_num_experts,
1303+
)
1304+
ops.exl3_mgemm(
1305+
x_3d[i],
1306+
layer.exl3_gate_up_ptrs_trellis,
1307+
layer.exl3_small_interm_gu,
1308+
layer.exl3_gate_up_ptrs_suh,
1309+
layer.exl3_small_yh_gu,
1310+
layer.exl3_gate_up_ptrs_svh,
1311+
layer.exl3_small_gate_up_ids,
1312+
None,
1313+
layer.exl3_moe_k_gate,
1314+
-1,
1315+
layer.exl3_gate_mcg,
1316+
layer.exl3_gate_mul1,
1317+
-1,
1318+
-1,
1319+
0,
1320+
)
1321+
else:
1322+
ops.exl3_mgemm(
1323+
x_3d[i],
1324+
layer.exl3_gate_ptrs_trellis,
1325+
layer.exl3_small_interm_g,
1326+
layer.exl3_gate_ptrs_suh,
1327+
layer.exl3_small_yh,
1328+
layer.exl3_gate_ptrs_svh,
1329+
topk_ids_3d[i],
1330+
None,
1331+
layer.exl3_moe_k_gate,
1332+
-1,
1333+
layer.exl3_gate_mcg,
1334+
layer.exl3_gate_mul1,
1335+
-1,
1336+
-1,
1337+
0,
1338+
)
1339+
ops.exl3_mgemm(
1340+
x_3d[i],
1341+
layer.exl3_up_ptrs_trellis,
1342+
layer.exl3_small_interm_u,
1343+
layer.exl3_up_ptrs_suh,
1344+
layer.exl3_small_yh,
1345+
layer.exl3_up_ptrs_svh,
1346+
topk_ids_3d[i],
1347+
None,
1348+
layer.exl3_moe_k_up,
1349+
-1,
1350+
layer.exl3_up_mcg,
1351+
layer.exl3_up_mul1,
1352+
-1,
1353+
-1,
1354+
0,
1355+
)
1356+
ops.silu_mul(
1357+
layer.exl3_small_interm_a,
12771358
layer.exl3_small_interm_g,
1278-
layer.exl3_gate_ptrs_suh,
1279-
layer.exl3_small_yh,
1280-
layer.exl3_gate_ptrs_svh,
1281-
topk_ids_3d[i],
1282-
None,
1283-
layer.exl3_moe_k_gate,
1284-
-1,
1285-
layer.exl3_gate_mcg,
1286-
layer.exl3_gate_mul1,
1287-
-1,
1288-
-1,
1289-
0,
1290-
)
1291-
ops.exl3_mgemm(
1292-
x_3d[i],
1293-
layer.exl3_up_ptrs_trellis,
12941359
layer.exl3_small_interm_u,
1295-
layer.exl3_up_ptrs_suh,
1296-
layer.exl3_small_yh,
1297-
layer.exl3_up_ptrs_svh,
1298-
topk_ids_3d[i],
1299-
None,
1300-
layer.exl3_moe_k_up,
1301-
-1,
1302-
layer.exl3_up_mcg,
1303-
layer.exl3_up_mul1,
1304-
-1,
1305-
-1,
1306-
0,
1307-
)
1308-
layer.exl3_small_interm_a.copy_(
1309-
torch.nn.functional.silu(layer.exl3_small_interm_g) * layer.exl3_small_interm_u
13101360
)
13111361
ops.exl3_mgemm(
13121362
layer.exl3_small_interm_a,
@@ -1357,6 +1407,9 @@ def ptr_tensor(prefix: str, attr: str, shard_id: str):
13571407
layer.exl3_up_ptrs_trellis = ptr_tensor("w13", "trellis", "w3")
13581408
layer.exl3_up_ptrs_suh = ptr_tensor("w13", "suh", "w3")
13591409
layer.exl3_up_ptrs_svh = ptr_tensor("w13", "svh", "w3")
1410+
layer.exl3_gate_up_ptrs_trellis = torch.cat([layer.exl3_gate_ptrs_trellis, layer.exl3_up_ptrs_trellis])
1411+
layer.exl3_gate_up_ptrs_suh = torch.cat([layer.exl3_gate_ptrs_suh, layer.exl3_up_ptrs_suh])
1412+
layer.exl3_gate_up_ptrs_svh = torch.cat([layer.exl3_gate_ptrs_svh, layer.exl3_up_ptrs_svh])
13601413
layer.exl3_down_ptrs_trellis = ptr_tensor("w2", "trellis", "w2")
13611414
layer.exl3_down_ptrs_suh = ptr_tensor("w2", "suh", "w2")
13621415
layer.exl3_down_ptrs_svh = ptr_tensor("w2", "svh", "w2")
@@ -1376,16 +1429,33 @@ def ptr_tensor(prefix: str, attr: str, shard_id: str):
13761429
layer.exl3_up_mul1 = (0, "w3") in layer.w13_mul1.exl3_tensors
13771430
layer.exl3_down_mcg = (0, "w2") in layer.w2_mcg.exl3_tensors
13781431
layer.exl3_down_mul1 = (0, "w2") in layer.w2_mul1.exl3_tensors
1432+
layer.exl3_fuse_gate_up = (
1433+
layer.exl3_moe_k_gate == layer.exl3_moe_k_up
1434+
and layer.exl3_gate_mcg == layer.exl3_up_mcg
1435+
and layer.exl3_gate_mul1 == layer.exl3_up_mul1
1436+
)
13791437

13801438
layer.exl3_small_batch_threshold = min(
13811439
layer.local_num_experts // layer.top_k,
13821440
_EXL3_MOE_MAX_EXPERTS_PER_TOKEN,
13831441
)
1384-
layer.exl3_small_yh = torch.empty((layer.top_k, 1, layer.hidden_size), dtype=torch.float16, device=device)
1385-
layer.exl3_small_interm_g = torch.empty((layer.top_k, 1, intermediate_size), dtype=torch.float16, device=device)
1386-
layer.exl3_small_interm_u = torch.empty((layer.top_k, 1, intermediate_size), dtype=torch.float16, device=device)
1442+
layer.exl3_small_yh_gu = torch.empty(
1443+
(layer.top_k * 2, 1, layer.hidden_size), dtype=torch.float16, device=device
1444+
)
1445+
layer.exl3_small_interm_gu = torch.empty(
1446+
(layer.top_k * 2, 1, intermediate_size), dtype=torch.float16, device=device
1447+
)
1448+
layer.exl3_small_yh = layer.exl3_small_yh_gu[: layer.top_k]
1449+
layer.exl3_small_interm_g = layer.exl3_small_interm_gu[: layer.top_k]
1450+
layer.exl3_small_interm_u = layer.exl3_small_interm_gu[layer.top_k :]
1451+
layer.exl3_small_gate_up_ids = torch.empty((1, layer.top_k * 2), dtype=torch.long, device=device)
13871452
layer.exl3_small_interm_a = torch.empty((layer.top_k, 1, intermediate_size), dtype=torch.float16, device=device)
13881453
layer.exl3_small_out_d = torch.empty((layer.top_k, 1, layer.hidden_size), dtype=torch.float32, device=device)
1454+
layer.exl3_small_batch_out = torch.empty(
1455+
(layer.exl3_small_batch_threshold, layer.hidden_size),
1456+
dtype=torch.float32,
1457+
device=device,
1458+
)
13891459

13901460
concurrency = max(
13911461
1,

0 commit comments

Comments
 (0)