@@ -910,66 +910,66 @@ def _forward_grouped(self, x: torch.Tensor, expert_offsets: torch.Tensor) -> tor
910910 def _forward_batched (self , x : torch .Tensor , expert_offsets : torch .Tensor ) -> torch .Tensor :
911911 """Batched GEMM path (SM_100 datacenter Blackwell).
912912
913- Scatters tokens into padded (num_experts, max_M, K) layout, quantizes
914- per-expert, runs a single batched GEMM kernel, then gathers results.
913+ 6-kernel pipeline with zero .item() in the compute path:
914+ 1. abs().max() — PyTorch reduction, stays on GPU
915+ 2. quantize_nvfp4_raw — quantize all tokens in one launch
916+ 3. moe_scatter_nvfp4 — FP4 concat → padded per-expert
917+ 4. scale_to_blocked_batched — row-major → swizzled per-expert
918+ 5. gemm_nvfp4_moe — batched GEMM with device-side alpha
919+ 6. moe_gather_bf16 — padded per-expert → concat
915920 """
916- from bitsandbytes .functional import gemm_nvfp4_moe , quantize_nvfp4
921+ from bitsandbytes .functional import (
922+ gemm_nvfp4_moe ,
923+ moe_gather_bf16 ,
924+ moe_scatter_nvfp4 ,
925+ quantize_nvfp4_raw ,
926+ scale_to_blocked_batched ,
927+ )
917928
918929 inp_dtype = x .dtype
919930 N , K = self .output_features , self .input_features
920931 num_experts = self .num_experts
921932
922933 expert_offsets_i32 = expert_offsets .to (torch .int32 )
923934 tokens_per_expert = expert_offsets_i32 [1 :] - expert_offsets_i32 [:- 1 ]
935+ # .item() for shape computation — needed for tensor allocation
924936 raw_max_M = tokens_per_expert .max ().item ()
925- # Pad to multiple of 128 for CUTLASS tile efficiency
926937 max_M = ((raw_max_M + 127 ) // 128 ) * 128
938+ total_tokens = expert_offsets_i32 [- 1 ].item ()
927939
928940 x_2d = x .reshape (- 1 , K ).to (torch .bfloat16 ).contiguous ()
929941
930- # Shared tensor scale across all experts (matches grouped GEMM behavior)
931- act_tensor_scale = x_2d .abs ().max ().item ()
932-
933- # Quantize per-expert with shared tensor scale
934- all_act_packed = []
935- all_act_scales = []
936-
937- for i in range (num_experts ):
938- start = expert_offsets_i32 [i ].item ()
939- end = expert_offsets_i32 [i + 1 ].item ()
940- n_tok = end - start
941-
942- act_padded = torch .zeros (max_M , K , dtype = torch .bfloat16 , device = x .device )
943- if n_tok > 0 :
944- act_padded [:n_tok ] = x_2d [start :end ]
942+ # 1. Compute tensor scale on GPU (no .item(), stays as device tensor)
943+ act_tensor_scale_dev = x_2d .abs ().max ()
944+ global_scale_dev = (1.0 / act_tensor_scale_dev ).to (torch .float32 )
945945
946- act_packed , act_state = quantize_nvfp4 (act_padded , tensor_scale = act_tensor_scale )
947- all_act_packed .append (act_packed )
948- all_act_scales .append (act_state .block_scales_blocked )
946+ # 2. Quantize ALL concatenated tokens in one launch
947+ packed_all , scales_all = quantize_nvfp4_raw (x_2d , global_scale_dev )
949948
950- A_batched = torch .cat (all_act_packed )
951- SFA_batched = torch .cat (all_act_scales )
949+ # 3. Scatter: FP4 data from concatenated to padded per-expert layout
950+ packed_batched = moe_scatter_nvfp4 (
951+ packed_all , expert_offsets_i32 , max_M , K , num_experts ,
952+ )
952953
953- # Run batched GEMM (alpha is a device tensor for graph safety)
954- alpha_dev = torch .tensor (
955- [act_tensor_scale * self .weight_tensor_scale ],
956- dtype = torch .float32 , device = x .device ,
954+ # 4. Swizzle scales: row-major → per-expert batched CUTLASS layout
955+ sfa_batched = scale_to_blocked_batched (
956+ scales_all , expert_offsets_i32 , max_M , K , num_experts ,
957957 )
958+
959+ # 5. Batched GEMM with device-side alpha (no .item() sync)
960+ alpha_dev = (act_tensor_scale_dev * self .weight_tensor_scale ).to (torch .float32 )
958961 D = gemm_nvfp4_moe (
959- A_batched , SFA_batched , alpha_dev ,
962+ packed_batched , sfa_batched , alpha_dev ,
960963 self .weight_packed , self .weight_scales_batched ,
961964 max_M , N , K , num_experts ,
962965 )
963966
964- # Gather results: D is (num_experts, max_M, N)
965- total_tokens = expert_offsets_i32 [- 1 ].item ()
966- out = torch .empty (total_tokens , N , dtype = torch .bfloat16 , device = x .device )
967- for i in range (num_experts ):
968- start = expert_offsets_i32 [i ].item ()
969- end = expert_offsets_i32 [i + 1 ].item ()
970- n_tok = end - start
971- if n_tok > 0 :
972- out [start :end ] = D [i , :n_tok ]
967+ # 6. Gather: padded per-expert BF16 → concatenated output
968+ D_flat = D .view (- 1 ).contiguous ()
969+ out = moe_gather_bf16 (
970+ D_flat , expert_offsets_i32 , max_M , N , num_experts , total_tokens ,
971+ )
972+ out = out .view (total_tokens , N )
973973
974974 if self .bias is not None :
975975 for i in range (num_experts ):
0 commit comments