@@ -1330,7 +1330,7 @@ def _(

# Cached state for batched SM_100 MoE GEMM.
# _moe_batched_restype_set: one-shot flag — set once the ctypes restypes for the
# MoE entry points have been configured (see _ensure_moe_batched_restype).
_moe_batched_restype_set: bool = False
# _moe_batched_sm100_cache: holds the last init configuration ("key") and the
# device workspace tensor (kept alive here so it is not garbage collected while
# the native plan still references it). None until the first init.
_moe_batched_sm100_cache: Optional[dict] = None

13361336def _ensure_moe_batched_restype ():
@@ -1346,50 +1346,130 @@ def _ensure_moe_batched_restype():
13461346 _moe_batched_restype_set = True
13471347
13481348
def _batched_moe_sm100_init_if_needed(
    A_batched: torch.Tensor,
    B_all: torch.Tensor,
    SFA_batched: torch.Tensor,
    SFB_all: torch.Tensor,
    D_out: torch.Tensor,
    alpha: torch.Tensor,
    max_M: int,
    N: int,
    K: int,
    num_experts: int,
    stream: int,
) -> None:
    """Call cgemm_nvfp4_moe_sm100_init if the configuration changed, else skip.

    The native init captures the device pointers of every buffer passed here
    (the subsequent cgemm_nvfp4_moe_sm100_run takes only a stream), so the
    cache key must include those pointers: a shape-only key would silently
    reuse stale device pointers if a caller re-allocates a buffer of the same
    shape. For persistent pre-allocated buffers (e.g. CUDA-graph workflows)
    the pointers never change and init still runs exactly once.

    Raises:
        RuntimeError: if the native init returns a nonzero status code.
    """
    global _moe_batched_sm100_cache
    _ensure_moe_batched_restype()

    cache_key = (
        N, K, max_M, num_experts,
        # Buffer addresses are baked into the native plan by _init.
        A_batched.data_ptr(), B_all.data_ptr(),
        SFA_batched.data_ptr(), SFB_all.data_ptr(),
        D_out.data_ptr(), alpha.data_ptr(),
    )

    if (_moe_batched_sm100_cache is not None
            and _moe_batched_sm100_cache["key"] == cache_key):
        return

    ws_size = lib.cgemm_nvfp4_moe_sm100_workspace_size(
        ct.c_int(N), ct.c_int(max_M), ct.c_int(K), ct.c_int(num_experts),
    )
    # max(ws_size, 1): torch.empty(0) would yield a null data pointer.
    workspace = torch.empty(max(ws_size, 1), dtype=torch.uint8, device=A_batched.device)

    ret = lib.cgemm_nvfp4_moe_sm100_init(
        ct.c_int(N), ct.c_int(max_M), ct.c_int(K), ct.c_int(num_experts),
        get_ptr(A_batched), get_ptr(B_all),
        get_ptr(SFA_batched), get_ptr(SFB_all),
        get_ptr(D_out), get_ptr(alpha),
        get_ptr(workspace), ct.c_size_t(ws_size), stream,
    )
    if ret != 0:
        raise RuntimeError(f"cgemm_nvfp4_moe_sm100_init failed with code {ret}")

    _moe_batched_sm100_cache = {
        "key": cache_key,
        "workspace": workspace,  # prevent GC
    }
def _gemm_nvfp4_batched_moe_sm100_raw(
    A_batched: torch.Tensor,
    B_all: torch.Tensor,
    SFA_batched: torch.Tensor,
    SFB_all: torch.Tensor,
    D_out: torch.Tensor,
    alpha: torch.Tensor,
    max_M: int,
    N: int,
    K: int,
    num_experts: int,
) -> None:
    """Raw batched MoE NVFP4 GEMM: (re)initialize the plan if needed, then run.

    All buffers must be pre-allocated by the caller. D_out must be BF16 of
    shape (num_experts * max_M, N). alpha must be a float32 device tensor of
    shape (1,) containing A_scale * B_scale.

    Raises:
        RuntimeError: if the native run returns a nonzero status code.
    """
    cuda_stream = _get_tensor_stream(A_batched)
    _batched_moe_sm100_init_if_needed(
        A_batched, B_all, SFA_batched, SFB_all, D_out, alpha,
        max_M, N, K, num_experts, cuda_stream,
    )
    status = lib.cgemm_nvfp4_moe_sm100_run(cuda_stream)
    if status != 0:
        raise RuntimeError(f"cgemm_nvfp4_moe_sm100_run failed with code {status}")
@register_kernel("bitsandbytes::gemm_nvfp4_moe", "cuda")
def _(
    A_batched: torch.Tensor,
    B_batched: torch.Tensor,
    SFA: torch.Tensor,
    SFB: torch.Tensor,
    alpha: torch.Tensor,
    max_M: int,
    N: int,
    K: int,
    num_experts: int,
) -> torch.Tensor:
    """Allocate the BF16 output, dispatch the raw batched MoE NVFP4 GEMM, and
    return the result viewed as (num_experts, max_M, N)."""
    with _cuda_device_of(A_batched):
        out = torch.empty(
            num_experts * max_M, N, dtype=torch.bfloat16, device=A_batched.device
        )
        _gemm_nvfp4_batched_moe_sm100_raw(
            A_batched, B_batched, SFA, SFB, out, alpha, max_M, N, K, num_experts
        )
    return out.view(num_experts, max_M, N)
1440+
@register_kernel("bitsandbytes::moe_weighted_gather_bf16", "cuda")
def _(
    D_batched: torch.Tensor,
    output_bf16: torch.Tensor,
    workspace_fp32: torch.Tensor,
    token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    slot_ids: torch.Tensor,
    weights: torch.Tensor,
    num_tokens: int,
    max_M: int,
    N: int,
) -> torch.Tensor:
    """Fused gather + weight + FP32 accumulate + BF16 convert.

    Internally launches: memset(workspace) -> atomicAdd gather -> FP32->BF16 convert.
    All three operations on the same stream, capturable in a CUDA graph.
    Returns output_bf16 (written in place).
    """
    with _cuda_device_of(D_batched):
        lib.cmoe_weighted_gather_bf16(
            get_ptr(D_batched),
            get_ptr(output_bf16),
            get_ptr(workspace_fp32),
            get_ptr(token_ids),
            get_ptr(expert_ids),
            get_ptr(slot_ids),
            get_ptr(weights),
            # One entry per (token, expert) routing assignment.
            ct.c_int(token_ids.shape[0]),
            ct.c_int(num_tokens),
            ct.c_int(max_M),
            ct.c_int(N),
            _get_tensor_stream(D_batched),
        )
    return output_bf16
0 commit comments