@@ -1054,11 +1054,6 @@ def apply(
10541054 else :
10551055 x_2d = x_2d .contiguous ()
10561056
1057- output = torch .zeros (
1058- (x_2d .shape [0 ], layer .hidden_size ),
1059- device = x_2d .device ,
1060- dtype = torch .float32 ,
1061- )
10621057 topk_ids = topk_ids .to (torch .long )
10631058 topk_weights = topk_weights .to (torch .float16 )
10641059 total_assignments = x_2d .shape [0 ] * topk_ids .shape [- 1 ]
@@ -1083,6 +1078,11 @@ def apply(
10831078 x .shape [:- 1 ],
10841079 )
10851080
1081+ output = torch .zeros (
1082+ (x_2d .shape [0 ], layer .hidden_size ),
1083+ device = x_2d .device ,
1084+ dtype = torch .float32 ,
1085+ )
10861086 flat_expert = topk_ids .reshape (- 1 )
10871087 flat_weight = topk_weights .reshape (- 1 )
10881088 flat_token = torch .arange (x_2d .shape [0 ], device = x_2d .device )
@@ -1195,41 +1195,69 @@ def _apply_single_token(
11951195 ) -> torch .Tensor :
11961196 x_3d = x_2d .unsqueeze (0 )
11971197
1198- ops .exl3_mgemm (
1199- x_3d ,
1200- layer .exl3_gate_ptrs_trellis ,
1198+ if layer .exl3_fuse_gate_up :
1199+ ops .make_gate_up_indices (
1200+ layer .exl3_small_gate_up_ids ,
1201+ topk_ids ,
1202+ layer .local_num_experts ,
1203+ )
1204+ ops .exl3_mgemm (
1205+ x_3d ,
1206+ layer .exl3_gate_up_ptrs_trellis ,
1207+ layer .exl3_small_interm_gu ,
1208+ layer .exl3_gate_up_ptrs_suh ,
1209+ layer .exl3_small_yh_gu ,
1210+ layer .exl3_gate_up_ptrs_svh ,
1211+ layer .exl3_small_gate_up_ids ,
1212+ None ,
1213+ layer .exl3_moe_k_gate ,
1214+ - 1 ,
1215+ layer .exl3_gate_mcg ,
1216+ layer .exl3_gate_mul1 ,
1217+ - 1 ,
1218+ - 1 ,
1219+ 0 ,
1220+ )
1221+ else :
1222+ ops .exl3_mgemm (
1223+ x_3d ,
1224+ layer .exl3_gate_ptrs_trellis ,
1225+ layer .exl3_small_interm_g ,
1226+ layer .exl3_gate_ptrs_suh ,
1227+ layer .exl3_small_yh ,
1228+ layer .exl3_gate_ptrs_svh ,
1229+ topk_ids ,
1230+ None ,
1231+ layer .exl3_moe_k_gate ,
1232+ - 1 ,
1233+ layer .exl3_gate_mcg ,
1234+ layer .exl3_gate_mul1 ,
1235+ - 1 ,
1236+ - 1 ,
1237+ 0 ,
1238+ )
1239+ ops .exl3_mgemm (
1240+ x_3d ,
1241+ layer .exl3_up_ptrs_trellis ,
1242+ layer .exl3_small_interm_u ,
1243+ layer .exl3_up_ptrs_suh ,
1244+ layer .exl3_small_yh ,
1245+ layer .exl3_up_ptrs_svh ,
1246+ topk_ids ,
1247+ None ,
1248+ layer .exl3_moe_k_up ,
1249+ - 1 ,
1250+ layer .exl3_up_mcg ,
1251+ layer .exl3_up_mul1 ,
1252+ - 1 ,
1253+ - 1 ,
1254+ 0 ,
1255+ )
1256+ ops .silu_mul (
1257+ layer .exl3_small_interm_a ,
12011258 layer .exl3_small_interm_g ,
1202- layer .exl3_gate_ptrs_suh ,
1203- layer .exl3_small_yh ,
1204- layer .exl3_gate_ptrs_svh ,
1205- topk_ids ,
1206- None ,
1207- layer .exl3_moe_k_gate ,
1208- - 1 ,
1209- layer .exl3_gate_mcg ,
1210- layer .exl3_gate_mul1 ,
1211- - 1 ,
1212- - 1 ,
1213- 0 ,
1214- )
1215- ops .exl3_mgemm (
1216- x_3d ,
1217- layer .exl3_up_ptrs_trellis ,
12181259 layer .exl3_small_interm_u ,
1219- layer .exl3_up_ptrs_suh ,
1220- layer .exl3_small_yh ,
1221- layer .exl3_up_ptrs_svh ,
1222- topk_ids ,
1223- None ,
1224- layer .exl3_moe_k_up ,
1225- - 1 ,
1226- layer .exl3_up_mcg ,
1227- layer .exl3_up_mul1 ,
1228- - 1 ,
1229- - 1 ,
1230- 0 ,
12311260 )
1232- layer .exl3_small_interm_a .copy_ (torch .nn .functional .silu (layer .exl3_small_interm_g ) * layer .exl3_small_interm_u )
12331261 ops .exl3_mgemm (
12341262 layer .exl3_small_interm_a ,
12351263 layer .exl3_down_ptrs_trellis ,
@@ -1261,52 +1289,74 @@ def _apply_small_batch(
12611289 original_dtype : torch .dtype ,
12621290 original_shape : tuple [int , ...],
12631291 ) -> torch .Tensor :
1264- output = torch .empty (
1265- (x_2d .shape [0 ], layer .hidden_size ),
1266- device = x_2d .device ,
1267- dtype = torch .float32 ,
1268- )
1292+ output = layer .exl3_small_batch_out [: x_2d .shape [0 ]]
12691293 x_3d = x_2d .unsqueeze (1 ).unsqueeze (1 )
12701294 topk_ids_3d = topk_ids .unsqueeze (1 )
12711295 topk_weights_3d = topk_weights .unsqueeze (1 )
12721296
12731297 for i in range (x_2d .shape [0 ]):
1274- ops .exl3_mgemm (
1275- x_3d [i ],
1276- layer .exl3_gate_ptrs_trellis ,
1298+ if layer .exl3_fuse_gate_up :
1299+ ops .make_gate_up_indices (
1300+ layer .exl3_small_gate_up_ids ,
1301+ topk_ids_3d [i ],
1302+ layer .local_num_experts ,
1303+ )
1304+ ops .exl3_mgemm (
1305+ x_3d [i ],
1306+ layer .exl3_gate_up_ptrs_trellis ,
1307+ layer .exl3_small_interm_gu ,
1308+ layer .exl3_gate_up_ptrs_suh ,
1309+ layer .exl3_small_yh_gu ,
1310+ layer .exl3_gate_up_ptrs_svh ,
1311+ layer .exl3_small_gate_up_ids ,
1312+ None ,
1313+ layer .exl3_moe_k_gate ,
1314+ - 1 ,
1315+ layer .exl3_gate_mcg ,
1316+ layer .exl3_gate_mul1 ,
1317+ - 1 ,
1318+ - 1 ,
1319+ 0 ,
1320+ )
1321+ else :
1322+ ops .exl3_mgemm (
1323+ x_3d [i ],
1324+ layer .exl3_gate_ptrs_trellis ,
1325+ layer .exl3_small_interm_g ,
1326+ layer .exl3_gate_ptrs_suh ,
1327+ layer .exl3_small_yh ,
1328+ layer .exl3_gate_ptrs_svh ,
1329+ topk_ids_3d [i ],
1330+ None ,
1331+ layer .exl3_moe_k_gate ,
1332+ - 1 ,
1333+ layer .exl3_gate_mcg ,
1334+ layer .exl3_gate_mul1 ,
1335+ - 1 ,
1336+ - 1 ,
1337+ 0 ,
1338+ )
1339+ ops .exl3_mgemm (
1340+ x_3d [i ],
1341+ layer .exl3_up_ptrs_trellis ,
1342+ layer .exl3_small_interm_u ,
1343+ layer .exl3_up_ptrs_suh ,
1344+ layer .exl3_small_yh ,
1345+ layer .exl3_up_ptrs_svh ,
1346+ topk_ids_3d [i ],
1347+ None ,
1348+ layer .exl3_moe_k_up ,
1349+ - 1 ,
1350+ layer .exl3_up_mcg ,
1351+ layer .exl3_up_mul1 ,
1352+ - 1 ,
1353+ - 1 ,
1354+ 0 ,
1355+ )
1356+ ops .silu_mul (
1357+ layer .exl3_small_interm_a ,
12771358 layer .exl3_small_interm_g ,
1278- layer .exl3_gate_ptrs_suh ,
1279- layer .exl3_small_yh ,
1280- layer .exl3_gate_ptrs_svh ,
1281- topk_ids_3d [i ],
1282- None ,
1283- layer .exl3_moe_k_gate ,
1284- - 1 ,
1285- layer .exl3_gate_mcg ,
1286- layer .exl3_gate_mul1 ,
1287- - 1 ,
1288- - 1 ,
1289- 0 ,
1290- )
1291- ops .exl3_mgemm (
1292- x_3d [i ],
1293- layer .exl3_up_ptrs_trellis ,
12941359 layer .exl3_small_interm_u ,
1295- layer .exl3_up_ptrs_suh ,
1296- layer .exl3_small_yh ,
1297- layer .exl3_up_ptrs_svh ,
1298- topk_ids_3d [i ],
1299- None ,
1300- layer .exl3_moe_k_up ,
1301- - 1 ,
1302- layer .exl3_up_mcg ,
1303- layer .exl3_up_mul1 ,
1304- - 1 ,
1305- - 1 ,
1306- 0 ,
1307- )
1308- layer .exl3_small_interm_a .copy_ (
1309- torch .nn .functional .silu (layer .exl3_small_interm_g ) * layer .exl3_small_interm_u
13101360 )
13111361 ops .exl3_mgemm (
13121362 layer .exl3_small_interm_a ,
@@ -1357,6 +1407,9 @@ def ptr_tensor(prefix: str, attr: str, shard_id: str):
13571407 layer .exl3_up_ptrs_trellis = ptr_tensor ("w13" , "trellis" , "w3" )
13581408 layer .exl3_up_ptrs_suh = ptr_tensor ("w13" , "suh" , "w3" )
13591409 layer .exl3_up_ptrs_svh = ptr_tensor ("w13" , "svh" , "w3" )
1410+ layer .exl3_gate_up_ptrs_trellis = torch .cat ([layer .exl3_gate_ptrs_trellis , layer .exl3_up_ptrs_trellis ])
1411+ layer .exl3_gate_up_ptrs_suh = torch .cat ([layer .exl3_gate_ptrs_suh , layer .exl3_up_ptrs_suh ])
1412+ layer .exl3_gate_up_ptrs_svh = torch .cat ([layer .exl3_gate_ptrs_svh , layer .exl3_up_ptrs_svh ])
13601413 layer .exl3_down_ptrs_trellis = ptr_tensor ("w2" , "trellis" , "w2" )
13611414 layer .exl3_down_ptrs_suh = ptr_tensor ("w2" , "suh" , "w2" )
13621415 layer .exl3_down_ptrs_svh = ptr_tensor ("w2" , "svh" , "w2" )
@@ -1376,16 +1429,33 @@ def ptr_tensor(prefix: str, attr: str, shard_id: str):
13761429 layer .exl3_up_mul1 = (0 , "w3" ) in layer .w13_mul1 .exl3_tensors
13771430 layer .exl3_down_mcg = (0 , "w2" ) in layer .w2_mcg .exl3_tensors
13781431 layer .exl3_down_mul1 = (0 , "w2" ) in layer .w2_mul1 .exl3_tensors
1432+ layer .exl3_fuse_gate_up = (
1433+ layer .exl3_moe_k_gate == layer .exl3_moe_k_up
1434+ and layer .exl3_gate_mcg == layer .exl3_up_mcg
1435+ and layer .exl3_gate_mul1 == layer .exl3_up_mul1
1436+ )
13791437
13801438 layer .exl3_small_batch_threshold = min (
13811439 layer .local_num_experts // layer .top_k ,
13821440 _EXL3_MOE_MAX_EXPERTS_PER_TOKEN ,
13831441 )
1384- layer .exl3_small_yh = torch .empty ((layer .top_k , 1 , layer .hidden_size ), dtype = torch .float16 , device = device )
1385- layer .exl3_small_interm_g = torch .empty ((layer .top_k , 1 , intermediate_size ), dtype = torch .float16 , device = device )
1386- layer .exl3_small_interm_u = torch .empty ((layer .top_k , 1 , intermediate_size ), dtype = torch .float16 , device = device )
1442+ layer .exl3_small_yh_gu = torch .empty (
1443+ (layer .top_k * 2 , 1 , layer .hidden_size ), dtype = torch .float16 , device = device
1444+ )
1445+ layer .exl3_small_interm_gu = torch .empty (
1446+ (layer .top_k * 2 , 1 , intermediate_size ), dtype = torch .float16 , device = device
1447+ )
1448+ layer .exl3_small_yh = layer .exl3_small_yh_gu [: layer .top_k ]
1449+ layer .exl3_small_interm_g = layer .exl3_small_interm_gu [: layer .top_k ]
1450+ layer .exl3_small_interm_u = layer .exl3_small_interm_gu [layer .top_k :]
1451+ layer .exl3_small_gate_up_ids = torch .empty ((1 , layer .top_k * 2 ), dtype = torch .long , device = device )
13871452 layer .exl3_small_interm_a = torch .empty ((layer .top_k , 1 , intermediate_size ), dtype = torch .float16 , device = device )
13881453 layer .exl3_small_out_d = torch .empty ((layer .top_k , 1 , layer .hidden_size ), dtype = torch .float32 , device = device )
1454+ layer .exl3_small_batch_out = torch .empty (
1455+ (layer .exl3_small_batch_threshold , layer .hidden_size ),
1456+ dtype = torch .float32 ,
1457+ device = device ,
1458+ )
13891459
13901460 concurrency = max (
13911461 1 ,
0 commit comments