@@ -1150,129 +1150,6 @@ def _gemm_nvfp4_grouped_raw(
11501150 )
11511151
11521152
1153- # Cached state for grouped SM_100 GEMM
1154- _grouped_restype_set = False
1155-
1156- # Cached buffers for the fused C dispatch (keyed by (N, K, num_experts),
1157- # sized for worst-case routing so the cache always hits after first call)
1158- _grouped_fused_cache : Optional [dict ] = None
1159-
1160-
1161- def _get_fused_buffers (
1162- total_tokens : int , N : int , K : int , num_experts : int , device : torch .device ,
1163- ) -> dict :
1164- """Get or grow cached device buffers for the fused C dispatch.
1165-
1166- Buffers are sized for worst-case token routing (all tokens to one expert),
1167- keyed on (N, K, num_experts). Grows if total_tokens exceeds the cached size.
1168- """
1169- global _grouped_fused_cache , _grouped_restype_set
1170-
1171- if not _grouped_restype_set :
1172- lib .cgemm_nvfp4_grouped_sm100_meta_size .restype = ct .c_size_t
1173- lib .cgemm_nvfp4_grouped_sm100_workspace_size .restype = ct .c_size_t
1174- _grouped_restype_set = True
1175-
1176- if (_grouped_fused_cache is not None
1177- and _grouped_fused_cache ["N" ] == N
1178- and _grouped_fused_cache ["K" ] == K
1179- and _grouped_fused_cache ["num_experts" ] == num_experts
1180- and _grouped_fused_cache ["max_tokens" ] >= total_tokens ):
1181- return _grouped_fused_cache
1182-
1183- scale_W = K // 16
1184- n_col_blocks = (scale_W + 3 ) // 4
1185-
1186- # Worst-case SFA output: each expert adds at most 1 extra 128-row block
1187- max_row_blocks = (total_tokens + 127 ) // 128 + num_experts
1188- sfa_out_bytes = max_row_blocks * n_col_blocks * 512
1189-
1190- sfa_swizzle_out = torch .empty (max (sfa_out_bytes , 1 ), dtype = torch .uint8 , device = device )
1191- sfa_swizzle_meta = torch .empty (3 * num_experts * 4 , dtype = torch .uint8 , device = device )
1192-
1193- meta_size = lib .cgemm_nvfp4_grouped_sm100_meta_size (ct .c_int (num_experts ))
1194- gemm_meta_buf = torch .empty (meta_size , dtype = torch .uint8 , device = device )
1195-
1196- # Worst-case workspace: all tokens routed to a single expert
1197- M_arr = (ct .c_int * num_experts )(* ([0 ] * num_experts ))
1198- M_arr [0 ] = total_tokens
1199- ws_size = lib .cgemm_nvfp4_grouped_sm100_workspace_size (
1200- M_arr , ct .c_int (N ), ct .c_int (K ), ct .c_int (num_experts ),
1201- )
1202- workspace_buf = torch .empty (max (ws_size , 1 ), dtype = torch .uint8 , device = device )
1203-
1204- _grouped_fused_cache = {
1205- "N" : N , "K" : K , "num_experts" : num_experts , "max_tokens" : total_tokens ,
1206- "sfa_swizzle_out" : sfa_swizzle_out ,
1207- "sfa_swizzle_meta" : sfa_swizzle_meta ,
1208- "gemm_meta_buf" : gemm_meta_buf ,
1209- "workspace_buf" : workspace_buf ,
1210- "ws_size" : ws_size ,
1211- }
1212- return _grouped_fused_cache
1213-
1214-
1215- def _gemm_nvfp4_grouped_sm100 (
1216- A_concat : torch .Tensor ,
1217- B_all : torch .Tensor ,
1218- SFA_rowmajor : torch .Tensor ,
1219- SFB_all : torch .Tensor ,
1220- offsets_host : tuple [int , ...],
1221- A_tensor_scale : float ,
1222- B_tensor_scale : float ,
1223- N : int ,
1224- K : int ,
1225- num_experts : int ,
1226- ) -> torch .Tensor :
1227- """SM_100 grouped NVFP4 GEMM using fused C dispatch.
1228-
1229- Single ctypes call handles SFA swizzle + CUTLASS grouped GEMM.
1230- All metadata computation and pointer building happens in C.
1231- Python only allocates output and passes pre-cached buffers.
1232-
1233- SFB_all is already per-expert swizzled (each expert was independently
1234- quantized by quantize_nvfp4, which swizzles each expert's scales
1235- separately). No conversion needed.
1236-
1237- offsets_host: host-side tuple of cumulative token offsets (num_experts + 1 ints).
1238- """
1239- device = A_concat .device
1240- total_tokens = offsets_host [- 1 ]
1241-
1242- # Get or grow cached buffers (keyed on N, K, num_experts — always hits
1243- # after first call unless total_tokens grows)
1244- buf = _get_fused_buffers (total_tokens , N , K , num_experts , device )
1245-
1246- # Output (BF16 — CUTLASS accumulates in FP32, epilogue outputs BF16)
1247- D_concat = torch .empty (total_tokens , N , dtype = torch .bfloat16 , device = device )
1248-
1249- # Build host offsets ctypes array (per-call, ~1μs for 9 ints)
1250- host_offsets_arr = (ct .c_int * (num_experts + 1 ))(* offsets_host )
1251-
1252- # Single fused C call: SFA swizzle + metadata build + GEMM launch
1253- # SFB_all is passed directly — already per-expert swizzled from quantize_nvfp4
1254- lib .cgemm_nvfp4_grouped_sm100_fused (
1255- get_ptr (A_concat ),
1256- get_ptr (B_all ),
1257- get_ptr (SFA_rowmajor ),
1258- get_ptr (SFB_all ),
1259- get_ptr (D_concat ),
1260- host_offsets_arr ,
1261- ct .c_int (N ),
1262- ct .c_int (K ),
1263- ct .c_int (num_experts ),
1264- ct .c_float (A_tensor_scale * B_tensor_scale ),
1265- get_ptr (buf ["sfa_swizzle_out" ]),
1266- get_ptr (buf ["sfa_swizzle_meta" ]),
1267- get_ptr (buf ["gemm_meta_buf" ]),
1268- get_ptr (buf ["workspace_buf" ]),
1269- ct .c_size_t (buf ["ws_size" ]),
1270- _get_tensor_stream (A_concat ),
1271- )
1272-
1273- return D_concat
1274-
1275-
12761153@register_kernel ("bitsandbytes::gemm_nvfp4_grouped" , "cuda" )
12771154def _ (
12781155 A_concat : torch .Tensor ,
@@ -1293,17 +1170,6 @@ def _(
12931170 SFB_all: per-expert swizzled weight scales (each expert independently swizzled
12941171 by quantize_nvfp4, then concatenated).
12951172 """
1296- # SM_100 (datacenter Blackwell): use CUTLASS grouped GEMM
1297- major , _ = torch .cuda .get_device_capability (A_concat .device )
1298- if major == 10 and hasattr (lib , "cgemm_nvfp4_grouped_cutlass_sm100" ):
1299- # Convert device offsets to host tuple (cheap for small arrays,
1300- # but callers should migrate to passing host offsets directly)
1301- offsets_host = tuple (expert_offsets .tolist ())
1302- return _gemm_nvfp4_grouped_sm100 (
1303- A_concat , B_all , SFA_rowmajor , SFB_all , offsets_host ,
1304- A_tensor_scale , B_tensor_scale , N , K , num_experts ,
1305- )
1306-
13071173 # SM_120 (consumer Blackwell): use hand-written grouped kernel
13081174 # SM_120 expects globally-swizzled SFA, so swizzle the row-major input
13091175 total_tokens = A_concat .numel () // (K // 2 )
0 commit comments