@@ -852,7 +852,15 @@ def _quantize_weights(self):
852852 requires_grad = False ,
853853 )
854854
855- def forward (self , x : torch .Tensor , expert_offsets : torch .Tensor ) -> torch .Tensor :
855+ def forward (
856+ self ,
857+ x : torch .Tensor ,
858+ expert_offsets : torch .Tensor ,
859+ * ,
860+ token_ids : Optional [torch .Tensor ] = None ,
861+ gating_weights : Optional [torch .Tensor ] = None ,
862+ num_dest_tokens : Optional [int ] = None ,
863+ ) -> torch .Tensor :
856864 """Run NVFP4 GEMM across all experts.
857865
858866 Uses batched GEMM on SM_100 (datacenter Blackwell) or grouped GEMM
@@ -864,17 +872,30 @@ def forward(self, x: torch.Tensor, expert_offsets: torch.Tensor) -> torch.Tensor
864872 expert_offsets: Cumulative token offsets [num_experts + 1], int32.
865873 expert_offsets[i] is the starting token index for expert i.
866874 expert_offsets[-1] = total_tokens.
875+ token_ids: Optional mapping from assignment index to output token index
876+ [total_tokens] (int32). Required for weighted gather.
877+ gating_weights: Optional per-assignment gating weights [total_tokens] (float32).
878+ Required for weighted gather.
879+ num_dest_tokens: Number of unique destination tokens in the output.
880+ Required for weighted gather; if omitted, the unweighted path is used.
867881
868882 Returns:
869- Output tensor [total_tokens, N] with expert results in the same token order.
883+ If token_ids, gating_weights and num_dest_tokens are provided (SM_100
884+ batched path): weighted output [num_dest_tokens, N], fused gather+weight+sum.
885+ Otherwise:
886+ Output tensor [total_tokens, N] with per-assignment expert results.
870887 """
871888 if not self ._quantized :
872889 self ._quantize_weights ()
873890
874891 major , _ = torch .cuda .get_device_capability (x .device )
875892 from bitsandbytes .cextension import lib
876893 if major == 10 and hasattr (lib , "cgemm_nvfp4_moe_sm100_init" ):
877- return self ._forward_batched (x , expert_offsets )
894+ return self ._forward_batched (
895+ x , expert_offsets ,
896+ token_ids = token_ids , gating_weights = gating_weights ,
897+ num_dest_tokens = num_dest_tokens ,
898+ )
878899 return self ._forward_grouped (x , expert_offsets )
879900
880901 def _forward_grouped (self , x : torch .Tensor , expert_offsets : torch .Tensor ) -> torch .Tensor :
@@ -899,27 +920,35 @@ def _forward_grouped(self, x: torch.Tensor, expert_offsets: torch.Tensor) -> tor
899920 )
900921
901922 if self .bias is not None :
902- for i in range (self .num_experts ):
903- start = expert_offsets [i ].item ()
904- end = expert_offsets [i + 1 ].item ()
905- if end > start :
906- out [start :end ] = out [start :end ] + self .bias [i ].to (out .dtype )
923+ expert_offsets_i32 = expert_offsets .to (torch .int32 )
924+ tokens_per_expert = expert_offsets_i32 [1 :] - expert_offsets_i32 [:- 1 ]
925+ bias_expanded = torch .repeat_interleave (self .bias , tokens_per_expert , dim = 0 )
926+ out = out + bias_expanded .to (out .dtype )
907927
908928 return out .to (inp_dtype )
909929
910- def _forward_batched (self , x : torch .Tensor , expert_offsets : torch .Tensor ) -> torch .Tensor :
930+ def _forward_batched (
931+ self ,
932+ x : torch .Tensor ,
933+ expert_offsets : torch .Tensor ,
934+ * ,
935+ token_ids : Optional [torch .Tensor ] = None ,
936+ gating_weights : Optional [torch .Tensor ] = None ,
937+ num_dest_tokens : Optional [int ] = None ,
938+ ) -> torch .Tensor :
911939 """Batched GEMM path (SM_100 datacenter Blackwell).
912940
913941 Pipeline with init/run split for CUDA graph compatibility:
914- 1. abs().max() — compute tensor scale (device-side)
915- 2. quantize_nvfp4_raw — quantize all tokens in one launch
916- 3. cmoe_scatter_nvfp4 — FP4 data → persistent padded buffer
942+ 1. abs().max() — compute tensor scale (device-side)
943+ 2. quantize_nvfp4_raw — quantize all tokens in one launch
944+ 3. cmoe_scatter_nvfp4 — FP4 data → persistent padded buffer
917945 4. scale_to_blocked_batched — scales → persistent swizzled buffer
918- 5. batched GEMM run() — init-if-needed, then just run(stream)
919- 6. moe_gather_bf16 — padded per-expert → concatenated output
946+ 5. batched GEMM run() — init-if-needed, then just run(stream)
947+ 6. gather — weighted or unweighted depending on args
920948
921- All persistent buffers (A, SFA, D, alpha) are cached in the module
922- so their addresses are stable for the CUTLASS init/run split.
949+ All persistent buffers (A, SFA, D, alpha, gather workspace) are cached
950+ in the module so their addresses are stable for the CUTLASS init/run split.
951+ No .item() GPU-CPU sync on the common (decode) path.
923952 """
924953 import ctypes as ct
925954
@@ -928,21 +957,29 @@ def _forward_batched(self, x: torch.Tensor, expert_offsets: torch.Tensor) -> tor
928957 from bitsandbytes .functional import (
929958 _get_tensor_stream ,
930959 get_ptr ,
931- moe_gather_bf16 ,
932960 quantize_nvfp4_raw ,
933- scale_to_blocked_batched ,
934961 )
935962
936963 inp_dtype = x .dtype
937964 N , K = self .output_features , self .input_features
938965 num_experts = self .num_experts
966+ total_tokens = x .shape [0 ] # CPU int, no GPU sync
967+ use_weighted = token_ids is not None and gating_weights is not None
968+ dev = x .device
939969
940970 expert_offsets_i32 = expert_offsets .to (torch .int32 )
941971 tokens_per_expert = expert_offsets_i32 [1 :] - expert_offsets_i32 [:- 1 ]
942- # .item() for shape computation — needed for tensor allocation
943- raw_max_M = tokens_per_expert .max ().item ()
944- max_M = ((raw_max_M + 127 ) // 128 ) * 128
945- total_tokens = expert_offsets_i32 [- 1 ].item ()
972+
973+ # Determine max_M without GPU sync on common path.
974+ # If cache exists and allocated_max_M >= total_tokens (upper bound on
975+ # any single expert's count), the buffers are guaranteed sufficient.
976+ if (hasattr (self , "_batched_cache" )
977+ and total_tokens <= self ._batched_cache .get ("allocated_max_M" , 0 )):
978+ max_M = self ._batched_cache ["allocated_max_M" ]
979+ else :
980+ # First call or total_tokens exceeds allocation: sync once
981+ raw_max_M = tokens_per_expert .max ().item ()
982+ max_M = ((raw_max_M + 127 ) // 128 ) * 128
946983
947984 x_2d = x .reshape (- 1 , K ).to (torch .bfloat16 ).contiguous ()
948985
@@ -956,7 +993,6 @@ def _forward_batched(self, x: torch.Tensor, expert_offsets: torch.Tensor) -> tor
956993 # 3. Ensure persistent cached buffers exist (stable pointers for init/run)
957994 cache_key = (max_M , N , K , num_experts )
958995 if not hasattr (self , "_batched_cache" ) or self ._batched_cache .get ("key" ) != cache_key :
959- dev = x .device
960996 W = K // 16
961997 n_col_blocks = (W + 3 ) // 4
962998 n_row_blocks = (max_M + 127 ) // 128
@@ -965,13 +1001,32 @@ def _forward_batched(self, x: torch.Tensor, expert_offsets: torch.Tensor) -> tor
9651001
9661002 self ._batched_cache = {
9671003 "key" : cache_key ,
1004+ "allocated_max_M" : max_M ,
9681005 "A_batched" : torch .empty (num_experts * max_M * (K // 2 ), dtype = torch .uint8 , device = dev ),
9691006 "SFA_batched" : torch .zeros (sfa_total , dtype = torch .uint8 , device = dev ),
9701007 "D_out" : torch .empty (num_experts * max_M , N , dtype = torch .bfloat16 , device = dev ),
9711008 "alpha_dev" : torch .empty (1 , dtype = torch .float32 , device = dev ),
1009+ # Pre-computed constants for scale swizzle
1010+ "sfa_per_expert" : sfa_per_expert ,
1011+ "n_row_blocks" : n_row_blocks ,
1012+ "W" : W ,
1013+ "expert_out_offsets" : torch .arange (
1014+ num_experts , dtype = torch .int32 , device = dev ,
1015+ ) * sfa_per_expert ,
9721016 }
9731017 cache = self ._batched_cache
9741018
1019+ # Ensure weighted gather buffers exist if needed
1020+ if use_weighted and num_dest_tokens is not None :
1021+ if cache .get ("gather_num_dest" ) != num_dest_tokens :
1022+ cache ["gather_workspace" ] = torch .empty (
1023+ num_dest_tokens * N , dtype = torch .float32 , device = dev ,
1024+ )
1025+ cache ["gather_output" ] = torch .empty (
1026+ num_dest_tokens , N , dtype = torch .bfloat16 , device = dev ,
1027+ )
1028+ cache ["gather_num_dest" ] = num_dest_tokens
1029+
9751030 stream = _get_tensor_stream (x_2d )
9761031
9771032 # 4. Scatter FP4 data into persistent buffer
@@ -986,29 +1041,16 @@ def _forward_batched(self, x: torch.Tensor, expert_offsets: torch.Tensor) -> tor
9861041 )
9871042
9881043 # 5. Swizzle scales per-expert into persistent buffer
989- W = K // 16
990- n_col_blocks = (W + 3 ) // 4
991- n_row_blocks = (max_M + 127 ) // 128
992- sfa_per_expert = n_row_blocks * n_col_blocks * 512
993- sfa_total = num_experts * sfa_per_expert
994-
995- expert_row_offsets = expert_offsets_i32 [:- 1 ]
996- expert_M_dev = tokens_per_expert .to (torch .int32 )
997- expert_out_offsets = torch .arange (
998- num_experts , dtype = torch .int32 , device = x .device ,
999- ) * sfa_per_expert
1000-
1001- # Zero persistent SFA buffer, then swizzle into it
10021044 cache ["SFA_batched" ].zero_ ()
10031045 lib .cscale_to_blocked_batched (
10041046 get_ptr (scales_all ),
10051047 get_ptr (cache ["SFA_batched" ]),
1006- get_ptr (expert_row_offsets ),
1007- get_ptr (expert_M_dev ),
1008- get_ptr (expert_out_offsets ),
1009- ct .c_int (W ),
1048+ get_ptr (expert_offsets_i32 [: - 1 ] ),
1049+ get_ptr (tokens_per_expert ),
1050+ get_ptr (cache ["expert_out_offsets" ] ),
1051+ ct .c_int (cache [ "W" ] ),
10101052 ct .c_int (num_experts ),
1011- ct .c_int (n_row_blocks ),
1053+ ct .c_int (cache ["n_row_blocks" ] ),
10121054 stream ,
10131055 )
10141056
@@ -1028,18 +1070,50 @@ def _forward_batched(self, x: torch.Tensor, expert_offsets: torch.Tensor) -> tor
10281070 max_M , N , K , num_experts ,
10291071 )
10301072
1031- # 8. Gather: padded per-expert BF16 → concatenated output
1032- out = moe_gather_bf16 (
1033- cache ["D_out" ].view (- 1 ), expert_offsets_i32 , max_M , N , num_experts , total_tokens ,
1034- )
1035- out = out .view (total_tokens , N )
1036-
1073+ # 8. Add bias to GEMM output (before gather, included in weighted sum)
10371074 if self .bias is not None :
1038- for i in range (num_experts ):
1039- start = expert_offsets_i32 [i ].item ()
1040- end = expert_offsets_i32 [i + 1 ].item ()
1041- if end > start :
1042- out [start :end ] = out [start :end ] + self .bias [i ].to (out .dtype )
1075+ D_out_3d = cache ["D_out" ].view (num_experts , max_M , N )
1076+ D_out_3d += self .bias .unsqueeze (1 ).to (D_out_3d .dtype )
1077+
1078+ # 9. Gather: padded per-expert → output
1079+ if use_weighted and num_dest_tokens is not None :
1080+ # Derive expert_ids and slot_ids from expert_offsets (all on GPU)
1081+ expert_ids = torch .repeat_interleave (
1082+ torch .arange (num_experts , device = dev , dtype = torch .int32 ),
1083+ tokens_per_expert ,
1084+ )
1085+ starts_expanded = torch .repeat_interleave (
1086+ expert_offsets_i32 [:- 1 ], tokens_per_expert ,
1087+ )
1088+ slot_ids = (
1089+ torch .arange (total_tokens , device = dev , dtype = torch .int32 )
1090+ - starts_expanded
1091+ )
1092+
1093+ # Fused weighted gather: gather + weight + FP32 accumulate + BF16 convert
1094+ lib .cmoe_weighted_gather_bf16 (
1095+ get_ptr (cache ["D_out" ]),
1096+ get_ptr (cache ["gather_output" ]),
1097+ get_ptr (cache ["gather_workspace" ]),
1098+ get_ptr (token_ids .to (torch .int32 )),
1099+ get_ptr (expert_ids ),
1100+ get_ptr (slot_ids ),
1101+ get_ptr (gating_weights .to (torch .float32 )),
1102+ ct .c_int (total_tokens ),
1103+ ct .c_int (num_dest_tokens ),
1104+ ct .c_int (max_M ),
1105+ ct .c_int (N ),
1106+ stream ,
1107+ )
1108+ out = cache ["gather_output" ]
1109+ else :
1110+ # Unweighted gather (backwards compatible path)
1111+ from bitsandbytes .functional import moe_gather_bf16
1112+ out = moe_gather_bf16 (
1113+ cache ["D_out" ].view (- 1 ), expert_offsets_i32 ,
1114+ max_M , N , num_experts , total_tokens ,
1115+ )
1116+ out = out .view (total_tokens , N )
10431117
10441118 return out .to (inp_dtype )
10451119
0 commit comments