Skip to content

Commit 39b5736

Browse files
committed
remove masks and update gmm function in moe
1 parent 334f936 commit 39b5736

2 files changed

Lines changed: 17 additions & 24 deletions

File tree

src/maxtext/kernels/ragged/ragged_sort.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,6 @@ def _ring_ragged_sort_fwd(hidden_states_local, topk_indices_local):
9292
shard_output_end,
9393
)
9494

95-
valid_mask = (jnp.arange(x.shape[0]) >= shard_output_start) & (jnp.arange(x.shape[0]) < shard_output_end)
96-
x = jnp.where(valid_mask[:, None], x, 0.0)
97-
9895
out = (x, group_sizes_local, topk_argsort_revert_indices)
9996

10097
res = (
@@ -210,7 +207,6 @@ def _ring_ragged_unsort_fwd(sorted_tokens_local, group_sizes_local, topk_argsort
210207
topk_argsort_revert_indices,
211208
shard_output_start,
212209
shard_output_end,
213-
sorted_tokens_local.shape,
214210
)
215211

216212
return out, res
@@ -237,9 +233,8 @@ def _ring_ragged_unsort_bwd(res, g_out):
237233
range of ``j``. The simpler equivalent: gather of g_hidden_states_local
238234
using the inverse permutation, masked.
239235
"""
240-
topk_argsort_revert_indices, shard_output_start, shard_output_end, sorted_tokens_local_shape = res
236+
topk_argsort_revert_indices, shard_output_start, shard_output_end = res
241237
g_hidden_states_local = g_out
242-
num_rows = sorted_tokens_local_shape[0]
243238

244239
# We want: g_sorted_tokens[j] = g_hidden_states_local[i] where revert[i]=j.
245240
# Build the inverse permutation idx_inv such that idx_inv[j] = i.
@@ -251,12 +246,6 @@ def _ring_ragged_unsort_bwd(res, g_out):
251246
shard_output_start,
252247
shard_output_end,
253248
)
254-
# Outside [start, end), positions must be zero — which the ragged_gather
255-
# already guarantees because untouched output rows are uninitialized; we
256-
# explicitly zero them.
257-
pos = jnp.arange(num_rows)
258-
valid = (pos >= shard_output_start) & (pos < shard_output_end)
259-
grad_sorted_tokens = jnp.where(valid[:, None], grad_sorted_tokens, jnp.zeros_like(grad_sorted_tokens))
260249
return grad_sorted_tokens, None, None
261250

262251
_ring_ragged_unsort.defvjp(_ring_ragged_unsort_fwd, _ring_ragged_unsort_bwd)

src/maxtext/layers/moe.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1171,9 +1171,10 @@ def get_tokamax_group_sizes(group_sizes, inputs, kernel):
11711171
elif self.config.attention == "vllm_rpa":
11721172
return group_sizes
11731173
else:
1174+
num_groups = group_sizes.shape[0]
11741175
return tokamax.RaggedDotGroupSizes(
11751176
group_sizes,
1176-
(inputs.shape[0] // kernel.shape[0],) * kernel.shape[0],
1177+
(inputs.shape[0] // num_groups,) * num_groups,
11771178
)
11781179

11791180
def get_quantization_dtypes():
@@ -1184,7 +1185,7 @@ def get_quantization_dtypes():
11841185
rhs_quantize_dtype = quant_dg.fwd.dg_quantizer.rhs.numerics.get_dtype()
11851186
return lhs_quantize_dtype, rhs_quantize_dtype
11861187

1187-
def gmm(inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_axes):
1188+
def gmm(inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_axes, group_offset):
11881189
def extract_vma(tensor):
11891190
# Parses the varying mesh axes from JAX's type string for a tensor inside shard_map.
11901191
# jax.typeof(t) renders as e.g. 'f32[128,256]{V:(expert, fsdp)}'; this extracts
@@ -1219,6 +1220,7 @@ def extract_vma(tensor):
12191220
group_sizes=group_sizes,
12201221
preferred_element_type=self.dtype,
12211222
tiling=tiling,
1223+
group_offset=group_offset,
12221224
lhs_quantize_dtype=lhs_quantize_dtype,
12231225
rhs_quantize_dtype=rhs_quantize_dtype,
12241226
use_qwix_quantization=self.config.use_qwix_quantization,
@@ -1235,6 +1237,7 @@ def extract_vma(tensor):
12351237
precision=jax.lax.Precision.DEFAULT,
12361238
preferred_element_type=self.dtype,
12371239
implementation="mosaic",
1240+
group_offset=group_offset,
12381241
)
12391242
elif self.config.megablox: # Older forked megablox
12401243
output = mblx.gmm(
@@ -1243,6 +1246,7 @@ def extract_vma(tensor):
12431246
group_sizes=group_sizes,
12441247
preferred_element_type=self.dtype,
12451248
tiling=tiling,
1249+
group_offset=group_offset,
12461250
lhs_quantize_dtype=lhs_quantize_dtype,
12471251
rhs_quantize_dtype=rhs_quantize_dtype,
12481252
use_qwix_quantization=self.config.use_qwix_quantization,
@@ -1382,11 +1386,6 @@ def route(x, logits, pre_bias_logits, rngs):
13821386
rngs=rngs,
13831387
)
13841388

1385-
# Filter down to the group sizes that apply to only the experts in the
1386-
# current shard.
1387-
group_sizes = group_sizes[:num_experts_per_shard]
1388-
mask = jnp.arange(x.shape[0]) < jnp.sum(group_sizes)
1389-
x = jnp.where(mask[:, None], x, 0)
13901389
else:
13911390
x, sorted_selected_experts, weights, group_sizes, selected_experts, lb_loss, bias_updates = self.permute(
13921391
x, logits, pre_bias_logits, self.config.use_custom_sort_vjp, rngs
@@ -1559,7 +1558,16 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
15591558
if self.config.mlp_bias:
15601559
w0_bias, w1_bias, wo_bias = self.transform_bias(routing.selected_experts, w0_bias, w1_bias, wo_bias)
15611560

1562-
gmm_fn = functools.partial(gmm, group_sizes=routing.group_sizes, expert_assignments=routing.selected_experts)
1561+
num_ep = self.get_expert_parallelism_size()
1562+
num_experts_per_shard = self.config.num_experts // num_ep
1563+
if self.config.use_ragged_sort and self.config.use_ring_of_experts:
1564+
experts_start = route_metadata.expert_shard_id * num_experts_per_shard
1565+
else:
1566+
experts_start = 0
1567+
1568+
gmm_fn = functools.partial(
1569+
gmm, group_sizes=routing.group_sizes, expert_assignments=routing.selected_experts, group_offset=experts_start
1570+
)
15631571
intermediate_layer = gmm_up(x, w0, w1, w0_bias, w1_bias, gmm_fn, weight_gather)
15641572

15651573
wo_gather_axes, wo_tile_size = get_wo_gmm_params()
@@ -1578,10 +1586,6 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
15781586
intermediate_output = adc.checkpoint_name(adc.checkpoint_name(intermediate_output, "mlpwo"), "moe_mlpwo")
15791587

15801588
if self.config.use_ring_of_experts:
1581-
# Set the outputs of tokens which were not processed to 0.
1582-
mask = jnp.arange(intermediate_output.shape[0]) < jnp.sum(routing.group_sizes)
1583-
intermediate_output = jnp.where(mask[:, None], intermediate_output, 0)
1584-
15851589
# Unsort and deduplicate the outputs locally.
15861590
output = self.unpermute(
15871591
intermediate_output,

0 commit comments

Comments
 (0)