Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 4 additions & 14 deletions src/maxtext/kernels/ragged/ragged_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,6 @@ def _ring_ragged_sort_fwd(hidden_states_local, topk_indices_local):
shard_output_end,
)

valid_mask = (jnp.arange(x.shape[0]) >= shard_output_start) & (jnp.arange(x.shape[0]) < shard_output_end)
x = jnp.where(valid_mask[:, None], x, 0.0)

out = (x, group_sizes_local, topk_argsort_revert_indices)

res = (
Expand Down Expand Up @@ -125,8 +122,9 @@ def _ring_ragged_sort_bwd(res, g_out):
# only iterates over the populated prefix, so we hand it the mask directly
# rather than materializing a (mostly-zero) dense buffer ourselves.
n = topk_argsort_revert_indices.shape[0]
pos = jnp.arange(n)
valid_rows_mask = (pos >= shard_output_start) & (pos < shard_output_end)
valid_rows_mask = (topk_argsort_revert_indices >= shard_output_start) & (
topk_argsort_revert_indices < shard_output_end
)
# The forward scatter-add over `token_indices_sorted` is equivalent to a
# gather-reduce: each input token has exactly `topk` contributions located
# at sorted positions `topk_argsort_revert_indices[t*topk:(t+1)*topk]`.
Expand Down Expand Up @@ -209,7 +207,6 @@ def _ring_ragged_unsort_fwd(sorted_tokens_local, group_sizes_local, topk_argsort
topk_argsort_revert_indices,
shard_output_start,
shard_output_end,
sorted_tokens_local.shape,
)

return out, res
Expand All @@ -236,9 +233,8 @@ def _ring_ragged_unsort_bwd(res, g_out):
range of ``j``. The simpler equivalent: gather of g_hidden_states_local
using the inverse permutation, masked.
"""
topk_argsort_revert_indices, shard_output_start, shard_output_end, sorted_tokens_local_shape = res
topk_argsort_revert_indices, shard_output_start, shard_output_end = res
g_hidden_states_local = g_out
num_rows = sorted_tokens_local_shape[0]

# We want: g_sorted_tokens[j] = g_hidden_states_local[i] where revert[i]=j.
# Build the inverse permutation idx_inv such that idx_inv[j] = i.
Expand All @@ -250,12 +246,6 @@ def _ring_ragged_unsort_bwd(res, g_out):
shard_output_start,
shard_output_end,
)
# Outside [start, end), positions must be zero — which the ragged_gather
# already guarantees because untouched output rows are uninitialized; we
# explicitly zero them.
pos = jnp.arange(num_rows)
valid = (pos >= shard_output_start) & (pos < shard_output_end)
grad_sorted_tokens = jnp.where(valid[:, None], grad_sorted_tokens, jnp.zeros_like(grad_sorted_tokens))
return grad_sorted_tokens, None, None

_ring_ragged_unsort.defvjp(_ring_ragged_unsort_fwd, _ring_ragged_unsort_bwd)
Expand Down
28 changes: 16 additions & 12 deletions src/maxtext/layers/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1171,9 +1171,10 @@ def get_tokamax_group_sizes(group_sizes, inputs, kernel):
elif self.config.attention == "vllm_rpa":
return group_sizes
else:
num_groups = group_sizes.shape[0]
return tokamax.RaggedDotGroupSizes(
group_sizes,
(inputs.shape[0] // kernel.shape[0],) * kernel.shape[0],
(inputs.shape[0] // num_groups,) * num_groups,
)

def get_quantization_dtypes():
Expand All @@ -1184,7 +1185,7 @@ def get_quantization_dtypes():
rhs_quantize_dtype = quant_dg.fwd.dg_quantizer.rhs.numerics.get_dtype()
return lhs_quantize_dtype, rhs_quantize_dtype

def gmm(inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_axes):
def gmm(inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_axes, group_offset):
def extract_vma(tensor):
# Parses the varying mesh axes from JAX's type string for a tensor inside shard_map.
# jax.typeof(t) renders as e.g. 'f32[128,256]{V:(expert, fsdp)}'; this extracts
Expand Down Expand Up @@ -1219,6 +1220,7 @@ def extract_vma(tensor):
group_sizes=group_sizes,
preferred_element_type=self.dtype,
tiling=tiling,
group_offset=group_offset,
lhs_quantize_dtype=lhs_quantize_dtype,
rhs_quantize_dtype=rhs_quantize_dtype,
use_qwix_quantization=self.config.use_qwix_quantization,
Expand All @@ -1235,6 +1237,7 @@ def extract_vma(tensor):
precision=jax.lax.Precision.DEFAULT,
preferred_element_type=self.dtype,
implementation="mosaic",
group_offset=group_offset,
)
elif self.config.megablox: # Older forked megablox
output = mblx.gmm(
Expand All @@ -1243,6 +1246,7 @@ def extract_vma(tensor):
group_sizes=group_sizes,
preferred_element_type=self.dtype,
tiling=tiling,
group_offset=group_offset,
lhs_quantize_dtype=lhs_quantize_dtype,
rhs_quantize_dtype=rhs_quantize_dtype,
use_qwix_quantization=self.config.use_qwix_quantization,
Expand Down Expand Up @@ -1382,11 +1386,6 @@ def route(x, logits, pre_bias_logits, rngs):
rngs=rngs,
)

# Filter down to the group sizes that apply to only the experts in the
# current shard.
group_sizes = group_sizes[:num_experts_per_shard]
mask = jnp.arange(x.shape[0]) < jnp.sum(group_sizes)
x = jnp.where(mask[:, None], x, 0)
else:
x, sorted_selected_experts, weights, group_sizes, selected_experts, lb_loss, bias_updates = self.permute(
x, logits, pre_bias_logits, self.config.use_custom_sort_vjp, rngs
Expand Down Expand Up @@ -1559,7 +1558,16 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
if self.config.mlp_bias:
w0_bias, w1_bias, wo_bias = self.transform_bias(routing.selected_experts, w0_bias, w1_bias, wo_bias)

gmm_fn = functools.partial(gmm, group_sizes=routing.group_sizes, expert_assignments=routing.selected_experts)
num_ep = self.get_expert_parallelism_size()
num_experts_per_shard = self.config.num_experts // num_ep
if self.config.use_ragged_sort and self.config.use_ring_of_experts:
experts_start = route_metadata.expert_shard_id * num_experts_per_shard
else:
experts_start = 0

gmm_fn = functools.partial(
gmm, group_sizes=routing.group_sizes, expert_assignments=routing.selected_experts, group_offset=experts_start
)
intermediate_layer = gmm_up(x, w0, w1, w0_bias, w1_bias, gmm_fn, weight_gather)

wo_gather_axes, wo_tile_size = get_wo_gmm_params()
Expand All @@ -1578,10 +1586,6 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
intermediate_output = adc.checkpoint_name(adc.checkpoint_name(intermediate_output, "mlpwo"), "moe_mlpwo")

if self.config.use_ring_of_experts:
# Set the outputs of tokens which were not processed to 0.
mask = jnp.arange(intermediate_output.shape[0]) < jnp.sum(routing.group_sizes)
intermediate_output = jnp.where(mask[:, None], intermediate_output, 0)

# Unsort and deduplicate the outputs locally.
output = self.unpermute(
intermediate_output,
Expand Down
Loading
Loading