@@ -910,18 +910,25 @@ def _forward_grouped(self, x: torch.Tensor, expert_offsets: torch.Tensor) -> tor
910910 def _forward_batched (self , x : torch .Tensor , expert_offsets : torch .Tensor ) -> torch .Tensor :
911911 """Batched GEMM path (SM_100 datacenter Blackwell).
912912
913- 6-kernel pipeline with zero .item() in the compute path:
914- 1. abs().max() — PyTorch reduction, stays on GPU
915- 2. quantize_nvfp4_raw — quantize all tokens in one launch
916- 3. moe_scatter_nvfp4 — FP4 concat → padded per-expert
917- 4. scale_to_blocked_batched — row-major → swizzled per-expert
918- 5. gemm_nvfp4_moe — batched GEMM with device-side alpha
919- 6. moe_gather_bf16 — padded per-expert → concat
913+ Pipeline with init/run split for CUDA graph compatibility:
914+ 1. abs().max() — compute tensor scale (device-side)
915+ 2. quantize_nvfp4_raw — quantize all tokens in one launch
916+ 3. cmoe_scatter_nvfp4 — FP4 data → persistent padded buffer
917+ 4. cscale_to_blocked_batched — scales → persistent swizzled buffer
918+ 5. batched GEMM run() — init-if-needed, then just run(stream)
919+ 6. moe_gather_bf16 — padded per-expert → concatenated output
920+
921+ All persistent buffers (A, SFA, D, alpha) are cached in the module
922+ so their addresses are stable for the CUTLASS init/run split.
920923 """
924+ import ctypes as ct
925+
926+ from bitsandbytes .backends .cuda .ops import _gemm_nvfp4_batched_moe_sm100_raw
927+ from bitsandbytes .cextension import lib
921928 from bitsandbytes .functional import (
922- gemm_nvfp4_moe ,
929+ _get_tensor_stream ,
930+ get_ptr ,
923931 moe_gather_bf16 ,
924- moe_scatter_nvfp4 ,
925932 quantize_nvfp4_raw ,
926933 scale_to_blocked_batched ,
927934 )
@@ -946,28 +953,84 @@ def _forward_batched(self, x: torch.Tensor, expert_offsets: torch.Tensor) -> tor
946953 # 2. Quantize ALL concatenated tokens in one launch
947954 packed_all , scales_all = quantize_nvfp4_raw (x_2d , global_scale_dev )
948955
949- # 3. Scatter: FP4 data from concatenated to padded per-expert layout
950- packed_batched = moe_scatter_nvfp4 (
951- packed_all , expert_offsets_i32 , max_M , K , num_experts ,
956+ # 3. Ensure persistent cached buffers exist (stable pointers for init/run)
957+ cache_key = (max_M , N , K , num_experts )
958+ if not hasattr (self , "_batched_cache" ) or self ._batched_cache .get ("key" ) != cache_key :
959+ dev = x .device
960+ W = K // 16
961+ n_col_blocks = (W + 3 ) // 4
962+ n_row_blocks = (max_M + 127 ) // 128
963+ sfa_per_expert = n_row_blocks * n_col_blocks * 512
964+ sfa_total = num_experts * sfa_per_expert
965+
966+ self ._batched_cache = {
967+ "key" : cache_key ,
968+ "A_batched" : torch .empty (num_experts * max_M * (K // 2 ), dtype = torch .uint8 , device = dev ),
969+ "SFA_batched" : torch .zeros (sfa_total , dtype = torch .uint8 , device = dev ),
970+ "D_out" : torch .empty (num_experts * max_M , N , dtype = torch .bfloat16 , device = dev ),
971+ "alpha_dev" : torch .empty (1 , dtype = torch .float32 , device = dev ),
972+ }
973+ cache = self ._batched_cache
974+
975+ stream = _get_tensor_stream (x_2d )
976+
977+ # 4. Scatter FP4 data into persistent buffer
978+ lib .cmoe_scatter_nvfp4 (
979+ get_ptr (packed_all ),
980+ get_ptr (cache ["A_batched" ]),
981+ get_ptr (expert_offsets_i32 ),
982+ ct .c_int (max_M ),
983+ ct .c_int (K ),
984+ ct .c_int (num_experts ),
985+ stream ,
952986 )
953987
954- # 4. Swizzle scales: row-major → per-expert batched CUTLASS layout
955- sfa_batched = scale_to_blocked_batched (
956- scales_all , expert_offsets_i32 , max_M , K , num_experts ,
988+ # 5. Swizzle scales per-expert into persistent buffer
989+ W = K // 16
990+ n_col_blocks = (W + 3 ) // 4
991+ n_row_blocks = (max_M + 127 ) // 128
992+ sfa_per_expert = n_row_blocks * n_col_blocks * 512
993+ sfa_total = num_experts * sfa_per_expert
994+
995+ expert_row_offsets = expert_offsets_i32 [:- 1 ]
996+ expert_M_dev = tokens_per_expert .to (torch .int32 )
997+ expert_out_offsets = torch .arange (
998+ num_experts , dtype = torch .int32 , device = x .device ,
999+ ) * sfa_per_expert
1000+
1001+ # Zero persistent SFA buffer, then swizzle into it
1002+ cache ["SFA_batched" ].zero_ ()
1003+ lib .cscale_to_blocked_batched (
1004+ get_ptr (scales_all ),
1005+ get_ptr (cache ["SFA_batched" ]),
1006+ get_ptr (expert_row_offsets ),
1007+ get_ptr (expert_M_dev ),
1008+ get_ptr (expert_out_offsets ),
1009+ ct .c_int (W ),
1010+ ct .c_int (num_experts ),
1011+ ct .c_int (n_row_blocks ),
1012+ stream ,
9571013 )
9581014
959- # 5. Batched GEMM with device-side alpha (no .item() sync)
960- alpha_dev = (act_tensor_scale_dev * self .weight_tensor_scale ).to (torch .float32 )
961- D = gemm_nvfp4_moe (
962- packed_batched , sfa_batched , alpha_dev ,
963- self .weight_packed , self .weight_scales_batched ,
1015+ # 6. Set alpha (device-side, no .item() sync)
1016+ cache ["alpha_dev" ].copy_ (
1017+ (act_tensor_scale_dev * self .weight_tensor_scale ).to (torch .float32 ).reshape (1 )
1018+ )
1019+
1020+ # 7. Batched GEMM (init-if-needed, then just run(stream))
1021+ _gemm_nvfp4_batched_moe_sm100_raw (
1022+ cache ["A_batched" ],
1023+ self .weight_packed ,
1024+ cache ["SFA_batched" ],
1025+ self .weight_scales_batched ,
1026+ cache ["D_out" ],
1027+ cache ["alpha_dev" ],
9641028 max_M , N , K , num_experts ,
9651029 )
9661030
967- # 6. Gather: padded per-expert BF16 → concatenated output
968- D_flat = D .view (- 1 ).contiguous ()
1031+ # 8. Gather: padded per-expert BF16 → concatenated output
9691032 out = moe_gather_bf16 (
970- D_flat , expert_offsets_i32 , max_M , N , num_experts , total_tokens ,
1033+ cache [ "D_out" ]. view ( - 1 ) , expert_offsets_i32 , max_M , N , num_experts , total_tokens ,
9711034 )
9721035 out = out .view (total_tokens , N )
9731036
0 commit comments