@@ -1171,9 +1171,10 @@ def get_tokamax_group_sizes(group_sizes, inputs, kernel):
11711171 elif self .config .attention == "vllm_rpa" :
11721172 return group_sizes
11731173 else :
1174+ num_groups = group_sizes .shape [0 ]
11741175 return tokamax .RaggedDotGroupSizes (
11751176 group_sizes ,
1176- (inputs .shape [0 ] // kernel . shape [ 0 ] ,) * kernel . shape [ 0 ] ,
1177+ (inputs .shape [0 ] // num_groups ,) * num_groups ,
11771178 )
11781179
11791180 def get_quantization_dtypes ():
@@ -1184,7 +1185,7 @@ def get_quantization_dtypes():
11841185 rhs_quantize_dtype = quant_dg .fwd .dg_quantizer .rhs .numerics .get_dtype ()
11851186 return lhs_quantize_dtype , rhs_quantize_dtype
11861187
1187- def gmm (inputs , kernel , tiling , group_sizes , expert_assignments , weight_gather_axes ):
1188+ def gmm (inputs , kernel , tiling , group_sizes , expert_assignments , weight_gather_axes , group_offset ):
11881189 def extract_vma (tensor ):
11891190 # Parses the varying mesh axes from JAX's type string for a tensor inside shard_map.
11901191 # jax.typeof(t) renders as e.g. 'f32[128,256]{V:(expert, fsdp)}'; this extracts
@@ -1219,6 +1220,7 @@ def extract_vma(tensor):
12191220 group_sizes = group_sizes ,
12201221 preferred_element_type = self .dtype ,
12211222 tiling = tiling ,
1223+ group_offset = group_offset ,
12221224 lhs_quantize_dtype = lhs_quantize_dtype ,
12231225 rhs_quantize_dtype = rhs_quantize_dtype ,
12241226 use_qwix_quantization = self .config .use_qwix_quantization ,
@@ -1235,6 +1237,7 @@ def extract_vma(tensor):
12351237 precision = jax .lax .Precision .DEFAULT ,
12361238 preferred_element_type = self .dtype ,
12371239 implementation = "mosaic" ,
1240+ group_offset = group_offset ,
12381241 )
12391242 elif self .config .megablox : # Older forked megablox
12401243 output = mblx .gmm (
@@ -1243,6 +1246,7 @@ def extract_vma(tensor):
12431246 group_sizes = group_sizes ,
12441247 preferred_element_type = self .dtype ,
12451248 tiling = tiling ,
1249+ group_offset = group_offset ,
12461250 lhs_quantize_dtype = lhs_quantize_dtype ,
12471251 rhs_quantize_dtype = rhs_quantize_dtype ,
12481252 use_qwix_quantization = self .config .use_qwix_quantization ,
@@ -1382,11 +1386,6 @@ def route(x, logits, pre_bias_logits, rngs):
13821386 rngs = rngs ,
13831387 )
13841388
1385- # Filter down to the group sizes that apply to only the experts in the
1386- # current shard.
1387- group_sizes = group_sizes [:num_experts_per_shard ]
1388- mask = jnp .arange (x .shape [0 ]) < jnp .sum (group_sizes )
1389- x = jnp .where (mask [:, None ], x , 0 )
13901389 else :
13911390 x , sorted_selected_experts , weights , group_sizes , selected_experts , lb_loss , bias_updates = self .permute (
13921391 x , logits , pre_bias_logits , self .config .use_custom_sort_vjp , rngs
@@ -1559,7 +1558,16 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
15591558 if self .config .mlp_bias :
15601559 w0_bias , w1_bias , wo_bias = self .transform_bias (routing .selected_experts , w0_bias , w1_bias , wo_bias )
15611560
1562- gmm_fn = functools .partial (gmm , group_sizes = routing .group_sizes , expert_assignments = routing .selected_experts )
1561+ num_ep = self .get_expert_parallelism_size ()
1562+ num_experts_per_shard = self .config .num_experts // num_ep
1563+ if self .config .use_ragged_sort and self .config .use_ring_of_experts :
1564+ experts_start = route_metadata .expert_shard_id * num_experts_per_shard
1565+ else :
1566+ experts_start = 0
1567+
1568+ gmm_fn = functools .partial (
1569+ gmm , group_sizes = routing .group_sizes , expert_assignments = routing .selected_experts , group_offset = experts_start
1570+ )
15631571 intermediate_layer = gmm_up (x , w0 , w1 , w0_bias , w1_bias , gmm_fn , weight_gather )
15641572
15651573 wo_gather_axes , wo_tile_size = get_wo_gmm_params ()
@@ -1578,10 +1586,6 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
15781586 intermediate_output = adc .checkpoint_name (adc .checkpoint_name (intermediate_output , "mlpwo" ), "moe_mlpwo" )
15791587
15801588 if self .config .use_ring_of_experts :
1581- # Set the outputs of tokens which were not processed to 0.
1582- mask = jnp .arange (intermediate_output .shape [0 ]) < jnp .sum (routing .group_sizes )
1583- intermediate_output = jnp .where (mask [:, None ], intermediate_output , 0 )
1584-
15851589 # Unsort and deduplicate the outputs locally.
15861590 output = self .unpermute (
15871591 intermediate_output ,
0 commit comments