@@ -1111,43 +1111,100 @@ def _(
11111111 )
11121112
11131113
1114- # Grouped NVFP4 GEMM for MoE inference (SM_120+ )
1114+ # Batched NVFP4 GEMM for MoE inference (SM_120, CUDA-graph-compatible )
11151115#
1116- # Fuses all expert GEMMs into a single kernel launch using expert-offset
1117- # work decomposition with binary search. Uses swizzled (block-scaled) scales.
1118- # CUDA-graph-safe: no dynamic allocations.
1119- def _gemm_nvfp4_grouped_raw (
1120- A_concat : torch .Tensor ,
# Fixed-padding scheme: every expert computes max_M rows. Rows beyond an
# expert's real token count are padding; their output is ignored and the
# caller discards it.
# The CUTLASS batched GEMM is driven through separate init and run calls so
# that it can be used under CUDA graphs.
#
# _batched_moe_restype_set: True once the ctypes return types of the
#     cgemm_nvfp4_moe_sm120_* entry points have been declared.
# _batched_moe_sm120_cache: None until the first successful init; afterwards
#     a dict holding the "key" of the most recent init call and the
#     "workspace" tensor handed to it (kept referenced so it is not freed).
#     A call whose key matches the stored one skips init and only runs.
_batched_moe_restype_set = False
_batched_moe_sm120_cache: Optional[dict] = None
1124+
1125+
def _ensure_batched_moe_restypes():
    """Declare ctypes return types for the SM_120 batched-MoE entry points.

    ctypes assumes a C ``int`` return unless told otherwise, so the
    size-query functions are declared as returning ``size_t`` and the
    init/run functions as returning ``int``. The declarations are made
    once; later calls return immediately.
    """
    global _batched_moe_restype_set
    if _batched_moe_restype_set:
        return

    # (symbol name, return type), applied in this order.
    declarations = (
        ("cgemm_nvfp4_moe_sm120_sfa_size", ct.c_size_t),
        ("cgemm_nvfp4_moe_sm120_sfb_size", ct.c_size_t),
        ("cgemm_nvfp4_moe_sm120_workspace_size", ct.c_size_t),
        ("cgemm_nvfp4_moe_sm120_init", ct.c_int),
        ("cgemm_nvfp4_moe_sm120_run", ct.c_int),
        ("cgemm_nvfp4_moe_sm120_sfa_size_per_expert", ct.c_size_t),
        ("cgemm_nvfp4_moe_sm120_sfb_size_per_expert", ct.c_size_t),
    )
    for symbol, return_type in declarations:
        getattr(lib, symbol).restype = return_type

    _batched_moe_restype_set = True
1137+
1138+
def _batched_moe_sm120_init_if_needed(
    A_batched: torch.Tensor,
    B_all: torch.Tensor,
    SFA_batched: torch.Tensor,
    SFB_all: torch.Tensor,
    D_out: torch.Tensor,
    alpha: torch.Tensor,
    max_M: int,
    N: int,
    K: int,
    num_experts: int,
    stream: int,
) -> None:
    """Call cgemm_nvfp4_moe_sm120_init unless the same setup is already active.

    The init call is the only place in this function where the problem
    sizes and the addresses of the six buffers are handed to the native
    library (presumably it retains them for the subsequent run call —
    confirm against the C API). The cache key therefore covers both the
    sizes and the buffer addresses: reusing the same pre-allocated buffers
    (the CUDA-graph use case) is a cache hit and skips init, while passing
    a different tensor forces a re-init instead of silently reusing the
    addresses from an earlier call.

    Args:
        A_batched, B_all, SFA_batched, SFB_all, D_out, alpha: device
            tensors whose addresses are passed to init.
        max_M, N, K, num_experts: problem sizes passed to init and to the
            workspace-size query.
        stream: stream handle forwarded unchanged to init.

    Raises:
        RuntimeError: if init returns a non-zero status code.
    """
    global _batched_moe_sm120_cache
    _ensure_batched_moe_restypes()

    buffers = (A_batched, B_all, SFA_batched, SFB_all, D_out, alpha)
    cache_key = (
        N, K, max_M, num_experts,
        # Addresses are only meaningful per device, so the device is part of the key.
        A_batched.device,
        tuple(t.data_ptr() for t in buffers),
    )

    cached = _batched_moe_sm120_cache
    if cached is not None and cached["key"] == cache_key:
        return

    ws_size = lib.cgemm_nvfp4_moe_sm120_workspace_size(
        ct.c_int(N), ct.c_int(max_M), ct.c_int(K), ct.c_int(num_experts),
    )
    # Allocate at least one byte so there is always a valid pointer to pass.
    workspace = torch.empty(max(ws_size, 1), dtype=torch.uint8, device=A_batched.device)

    ret = lib.cgemm_nvfp4_moe_sm120_init(
        ct.c_int(N), ct.c_int(max_M), ct.c_int(K), ct.c_int(num_experts),
        get_ptr(A_batched), get_ptr(B_all),
        get_ptr(SFA_batched), get_ptr(SFB_all),
        get_ptr(D_out), get_ptr(alpha),
        get_ptr(workspace), ct.c_size_t(ws_size), stream,
    )
    if ret != 0:
        # The native state is unknown after a failed init; forget the old
        # entry so the next call re-initialises instead of hitting the cache.
        _batched_moe_sm120_cache = None
        raise RuntimeError(f"cgemm_nvfp4_moe_sm120_init failed with code {ret}")

    _batched_moe_sm120_cache = {
        "key": cache_key,
        "workspace": workspace,  # keep the workspace alive while this init is active
    }
1181+
1182+
def _gemm_nvfp4_batched_moe_raw(
    A_batched: torch.Tensor,
    B_all: torch.Tensor,
    SFA_batched: torch.Tensor,
    SFB_all: torch.Tensor,
    D_out: torch.Tensor,
    alpha: torch.Tensor,
    max_M: int,
    N: int,
    K: int,
    num_experts: int,
) -> None:
    """Launch the batched MoE NVFP4 GEMM: initialise if required, then run.

    Every buffer must already be allocated by the caller; the result is
    written into D_out. D_out must be BF16 of shape
    (num_experts * max_M, N). alpha must be a float32 device tensor of
    shape (1,) holding A_scale * B_scale.

    Raises:
        RuntimeError: if the run call returns a non-zero status code
            (init failures surface from the init helper).
    """
    # Init and run are both issued on the stream associated with A_batched.
    launch_stream = _get_tensor_stream(A_batched)

    _batched_moe_sm120_init_if_needed(
        A_batched, B_all, SFA_batched, SFB_all, D_out, alpha,
        max_M=max_M, N=N, K=K, num_experts=num_experts, stream=launch_stream,
    )

    status = lib.cgemm_nvfp4_moe_sm120_run(launch_stream)
    if status == 0:
        return
    raise RuntimeError(f"cgemm_nvfp4_moe_sm120_run failed with code {status}")
11511208
11521209
11531210# Cached state for grouped SM_100 GEMM
@@ -1304,23 +1361,38 @@ def _(
13041361 A_tensor_scale , B_tensor_scale , N , K , num_experts ,
13051362 )
13061363
1307- # SM_120 (consumer Blackwell): use hand-written grouped kernel
1308- # SM_120 expects globally-swizzled SFA, so swizzle the row-major input
1309- total_tokens = A_concat .numel () // (K // 2 )
1310- scale_W = K // 16
1311- SFA_blocked = torch .ops .bitsandbytes .scale_to_blocked (SFA_rowmajor , total_tokens , scale_W )
1364+ # SM_120 (consumer Blackwell): deprecated grouped path.
1365+ # Use gemm_nvfp4_batched_moe (fixed-padding) instead.
1366+ raise NotImplementedError (
1367+ "SM_120 grouped (variable-offset) NVFP4 MoE GEMM has been removed. "
1368+ "Use bitsandbytes::gemm_nvfp4_batched_moe with fixed-padding instead."
1369+ )
13121370
1313- num_n_tiles = (N + 127 ) // 128
13141371
1315- with _cuda_device_of (A_concat ):
1316- D_concat = torch .empty (total_tokens , N , dtype = torch .bfloat16 , device = A_concat .device )
1317- total_tiles = cumul_m_tiles [- 1 ].item () * num_n_tiles
@register_kernel("bitsandbytes::gemm_nvfp4_batched_moe", "cuda")
def _(
    A_batched: torch.Tensor,
    B_all: torch.Tensor,
    SFA_batched: torch.Tensor,
    SFB_all: torch.Tensor,
    alpha: torch.Tensor,
    max_M: int,
    N: int,
    K: int,
    num_experts: int,
) -> torch.Tensor:
    """CUDA kernel for the fixed-padding batched NVFP4 MoE GEMM.

    Each expert computes max_M rows.

    Args:
        A_batched: flat packed FP4 activations,
            (num_experts * max_M * K/2) bytes.
        B_all: flat packed FP4 weights, (num_experts * N * K/2) bytes.
        SFA_batched: pre-swizzled activation scales (CUTLASS block-scaled
            layout).
        SFB_all: pre-swizzled weight scales (CUTLASS block-scaled layout).
        alpha: float32 device tensor [1], equal to
            A_tensor_scale * B_tensor_scale.
        max_M: number of rows computed for every expert.
        N, K: GEMM dimensions forwarded to the launcher.
        num_experts: number of experts in the batch.

    Returns:
        A newly allocated BF16 tensor of shape (num_experts * max_M, N).
    """
    out_rows = num_experts * max_M
    # Allocate the output and launch under the device context of A_batched.
    with _cuda_device_of(A_batched):
        D_out = torch.empty((out_rows, N), dtype=torch.bfloat16, device=A_batched.device)
        _gemm_nvfp4_batched_moe_raw(
            A_batched, B_all, SFA_batched, SFB_all, D_out, alpha,
            max_M, N, K, num_experts,
        )
    return D_out
0 commit comments