Skip to content

Commit 9172c06

Browse files
committed
add fan out for scatter bwd
1 parent 5f3dc2b commit 9172c06

3 files changed

Lines changed: 124 additions & 31 deletions

File tree

src/maxtext/kernels/ragged/ragged_gather.py

Lines changed: 93 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,19 +41,22 @@ def main_kernel(
4141
end_ref: jax.Ref,
4242
in_hbm_ref: jax.Ref,
4343
indices_hbm_ref: jax.Ref,
44+
weights_hbm_ref: jax.Ref,
4445
# Outputs.
4546
out_hbm_ref: jax.Ref,
4647
# Scratch.
4748
start_vmem_ref: jax.Ref,
4849
end_vmem_ref: jax.Ref,
4950
out_vmem_ref: jax.Ref,
5051
indices_vmem_ref: jax.Ref,
52+
weights_vmem_ref: jax.Ref,
5153
sem_ref: jax.Ref,
5254
*,
5355
core_axis_name: str,
5456
subcore_axis_name: str,
57+
has_weights: bool,
5558
):
56-
"""Core ragged gather operation"""
59+
"""Core ragged gather operation with per-row weighting."""
5760
tpu_info = pltpu.get_tpu_info()
5861
sc_info = tpu_info.sparse_core
5962
assert sc_info is not None
@@ -109,9 +112,16 @@ def _():
109112
indices_hbm_ref.at[pl.ds(row_tile_start, num_simd_lanes)],
110113
indices_vmem_ref,
111114
)
115+
if has_weights:
116+
pltpu.sync_copy(
117+
weights_hbm_ref.at[pl.ds(row_tile_start, num_simd_lanes)],
118+
weights_vmem_ref,
119+
)
112120

113121
# HBM to VMEM transfer.
114122
indices = indices_vmem_ref[...]
123+
if has_weights:
124+
weights = weights_vmem_ref[...]
115125

116126
dtype = out_hbm_ref.dtype
117127
dtype_bits = jax.dtypes.itemsize_bits(dtype)
@@ -189,6 +199,50 @@ def dma_write_loop(col_vmem_start):
189199
row_dst = row_src // packing
190200
out_vmem_ref[row_dst, col_slice] = out
191201

202+
# Apply per-row weights to the tile staged in out_vmem_ref, unpacking first when rows are packed.
203+
# For packing == 1 (float32), we can apply directly.
204+
# For packing > 1 (bf16), the data is already packed; each packed word is unpacked, scaled, and repacked below.
205+
if has_weights:
206+
if packing == 1:
207+
# float32 path: data is already in float32 layout, one row per sublane.
208+
for col_compute_offset in range(0, num_lanes, num_simd_lanes):
209+
col_slice = pl.ds(col_vmem_start + col_compute_offset, num_simd_lanes)
210+
for row_vmem in range(num_simd_lanes):
211+
data = out_vmem_ref[row_vmem, col_slice]
212+
data_f32 = jax.lax.bitcast_convert_type(data, jnp.float32)
213+
data_f32 = data_f32 * weights[row_vmem]
214+
out_vmem_ref[row_vmem, col_slice] = jax.lax.bitcast_convert_type(data_f32, jnp.uint32)
215+
else:
216+
# bf16 path: data is packed, packing=2. Each packed row contains 2
217+
# bf16 values from consecutive source rows. We need to unpack each
218+
# bf16 to float32, multiply by its weight, then repack.
219+
for col_compute_offset in range(0, num_lanes, num_simd_lanes):
220+
col_slice = pl.ds(col_vmem_start + col_compute_offset, num_simd_lanes)
221+
for row_dst in range(num_simd_lanes // packing):
222+
packed_data = out_vmem_ref[row_dst, col_slice]
223+
result = jnp.zeros_like(packed_data)
224+
for sub in range(packing):
225+
row_src = row_dst * packing + sub
226+
# Extract the sub-element.
227+
shift_right = sub * dtype_bits
228+
shift_left = sub * dtype_bits
229+
elem = jnp.bitwise_right_shift(packed_data, shift_right)
230+
elem = jnp.bitwise_and(elem, 2**dtype_bits - 1)
231+
# Convert bf16 bits to float32 for weighting.
232+
# bf16 is stored in the lower 16 bits; shift to upper 16 for
233+
# bitcast to float32.
234+
elem_f32 = jnp.bitwise_left_shift(elem, 16)
235+
elem_f32 = jax.lax.bitcast_convert_type(elem_f32, jnp.float32)
236+
elem_f32 = elem_f32 * weights[row_src]
237+
# Convert back: bitcast float32 -> uint32, shift right 16 to
238+
# get bf16 bits, then shift left to target position.
239+
elem_u32 = jax.lax.bitcast_convert_type(elem_f32, jnp.uint32)
240+
elem_bf16 = jnp.bitwise_right_shift(elem_u32, 16)
241+
elem_bf16 = jnp.bitwise_and(elem_bf16, 2**dtype_bits - 1)
242+
elem_bf16 = jnp.bitwise_left_shift(elem_bf16, shift_left)
243+
result = jnp.bitwise_or(result, elem_bf16)
244+
out_vmem_ref[row_dst, col_slice] = result
245+
192246
# Start dma write.
193247
for row_vmem in range(num_simd_lanes // packing):
194248
row_hbm = row_tile_start // packing + row_vmem
@@ -234,9 +288,31 @@ def calculate_col_size(hidden_size: int) -> int:
234288
return pl.cdiv(hidden_size, (num_cols * num_lanes)) * num_lanes
235289

236290

237-
@jax.jit
238-
def ragged_gather(x: jax.Array, indices: jax.Array, start: jax.Array, end: jax.Array) -> jax.Array:
239-
"""Perform gather on indices within dynamic array start and end."""
291+
@functools.partial(jax.jit, static_argnames=("has_weights",))
292+
def ragged_gather(
293+
x: jax.Array,
294+
indices: jax.Array,
295+
start: jax.Array,
296+
end: jax.Array,
297+
weights: jax.Array | None = None,
298+
has_weights: bool = False,
299+
) -> jax.Array:
300+
"""Perform gather on indices within dynamic array start and end.
301+
302+
Args:
303+
x: 2D input array of shape ``(input_size, hidden_size)``.
304+
indices: 1D array of gather indices.
305+
start: Scalar or 1D array indicating the start of the valid range.
306+
end: Scalar or 1D array indicating the end of the valid range.
307+
weights: Optional 1D array of per-row weights. When provided, each
308+
gathered row is multiplied by its corresponding weight inside the
309+
kernel, avoiding an extra HBM read-write pass.
310+
has_weights: Static bool flag indicating whether ``weights`` should be
311+
applied. Must be ``True`` for ``weights`` to take effect, and ``weights`` must not be ``None`` when this is ``True``.
312+
313+
Returns:
314+
Gathered output of shape ``(indices_size, hidden_size)``.
315+
"""
240316

241317
assert x.ndim == 2, "Ragged gather only supports 2d inputs."
242318
assert indices.ndim == 1, "Ragged gather only supports 1d indices."
@@ -257,7 +333,10 @@ def ragged_gather(x: jax.Array, indices: jax.Array, start: jax.Array, end: jax.A
257333
sc_info = pltpu.get_tpu_info().sparse_core
258334
if sc_info is None:
259335
# Sparse core is not available. Fallback to regular gather.
260-
return x[indices]
336+
out = x[indices]
337+
if has_weights:
338+
out = out * weights[:, None]
339+
return out
261340

262341
hidden_size = x.shape[-1]
263342
out_size = indices.size
@@ -271,6 +350,12 @@ def ragged_gather(x: jax.Array, indices: jax.Array, start: jax.Array, end: jax.A
271350
out_pad_size = pl.cdiv(out_size, block_size) * block_size - out_size
272351
indices = jnp.pad(indices, ((0, out_pad_size)))
273352

353+
if has_weights:
354+
weights = jnp.pad(weights, ((0, out_pad_size)), constant_values=1.0)
355+
else:
356+
# Provide a dummy weights array; the kernel won't use it.
357+
weights = jnp.ones((out_size + out_pad_size,), dtype=jnp.float32)
358+
274359
aligned_hidden_size = pl.cdiv(hidden_size, col_size) * col_size
275360

276361
vector_mesh = plsc.VectorSubcoreMesh(
@@ -284,6 +369,7 @@ def ragged_gather(x: jax.Array, indices: jax.Array, start: jax.Array, end: jax.A
284369
main_kernel,
285370
core_axis_name=vector_mesh.core_axis_name,
286371
subcore_axis_name=vector_mesh.subcore_axis_name,
372+
has_weights=has_weights,
287373
),
288374
compiler_params=pltpu.CompilerParams(
289375
use_tc_tiling_on_sc=True,
@@ -298,7 +384,8 @@ def ragged_gather(x: jax.Array, indices: jax.Array, start: jax.Array, end: jax.A
298384
pltpu.VMEM((num_simd_lanes,), jnp.int32),
299385
pltpu.VMEM((num_simd_lanes, col_size), jnp.uint32),
300386
pltpu.VMEM((num_simd_lanes,), jnp.int32),
387+
pltpu.VMEM((num_simd_lanes,), jnp.float32),
301388
pltpu.SemaphoreType.DMA((2,)),
302389
],
303390
},
304-
)(start, end, x, indices)[:out_size, :hidden_size]
391+
)(start, end, x, indices, weights)[:out_size, :hidden_size]

src/maxtext/kernels/ragged/ragged_sort.py

Lines changed: 30 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def _ring_ragged_sort(hidden_states_local, topk_indices_local):
7171
"""Sort and gather activations to different EP shards."""
7272
return _ring_ragged_sort_fwd(hidden_states_local, topk_indices_local)[0]
7373

74-
@jax.named_scope("ragged-gather-fwd")
74+
@jax.named_scope("ragged-sort-fwd")
7575
def _ring_ragged_sort_fwd(hidden_states_local, topk_indices_local):
7676
"""Sort and gather activations forward pass."""
7777

@@ -136,7 +136,7 @@ def _ring_ragged_sort_fwd(hidden_states_local, topk_indices_local):
136136

137137
return out, res
138138

139-
@jax.named_scope("ragged-gather-bwd")
139+
@jax.named_scope("ragged-sort-bwd")
140140
def _ring_ragged_sort_bwd(res, g_out):
141141
"""Backward pass for the gather: a Pallas SC ragged gather reduce.
142142
The forward gathers ``hidden_states_local[token_indices_sorted[i]]`` into
@@ -248,7 +248,7 @@ def _ring_ragged_unsort(sorted_tokens_local, group_sizes_local, topk_argsort_rev
248248
sorted_tokens_local, group_sizes_local, topk_argsort_revert_indices, topk_weights_flat
249249
)[0]
250250

251-
@jax.named_scope("ragged-scatter-fwd")
251+
@jax.named_scope("ragged-unsort-fwd")
252252
def _ring_ragged_unsort_fwd(sorted_tokens_local, group_sizes_local, topk_argsort_revert_indices, topk_weights_flat):
253253
"""Executes unsorting sending tokens back."""
254254
group_offsets = jnp.cumulative_sum(group_sizes_local, include_initial=True)
@@ -309,7 +309,7 @@ def _ring_ragged_unsort_fwd(sorted_tokens_local, group_sizes_local, topk_argsort
309309

310310
return out, res
311311

312-
@jax.named_scope("ragged-scatter-bwd")
312+
@jax.named_scope("ragged-unsort-bwd")
313313
def _ring_ragged_unsort_bwd(res, g_out):
314314
"""Backward pass for the scatter with routing weights.
315315
@@ -330,38 +330,44 @@ def _ring_ragged_unsort_bwd(res, g_out):
330330
) = res
331331
g_hidden_states_local = g_out
332332

333-
# Expand g_out from [num_tokens] to [num_tokens * topk] by repeating each
334-
# row topk times, so that g_expanded[i] = g_out[i // topk].
335-
g_expanded = jnp.repeat(g_hidden_states_local, topk, axis=0)
336-
337-
# Apply per-slot routing weights: g_weighted[i] = w[i] * g_out[i // topk]
338-
g_weighted = g_expanded * topk_weights_flat[:, None]
339-
340333
n = topk_argsort_revert_indices.shape[0]
334+
# Build the inverse permutation idx_inv such that idx_inv[j] = i
335+
# where revert[i] = j.
336+
idx_inv = jnp.argsort(topk_argsort_revert_indices)
341337

342338
# Handle the same two buffering modes for backward pass.
339+
# We let ragged_gather do both the fan-out (by indexing into the
340+
# un-expanded g_hidden_states_local via idx_inv // topk) and the
341+
# per-slot weight application (via the fused weights parameter),
342+
# avoiding an extra HBM read-write pass.
343343
if buffer_size >= n:
344-
# We want: g_sorted_tokens[j] = g_weighted[i] where revert[i]=j.
345-
# Build the inverse permutation idx_inv such that idx_inv[j] = i.
346-
idx_inv = jnp.argsort(topk_argsort_revert_indices)
347-
# Because revert is a permutation, gathering with idx_inv reorders correctly.
344+
# ragged_gather fans out g_hidden_states_local by reading the same row
345+
# multiple times when idx_inv // topk maps multiple positions to it.
346+
# Per-slot routing weights are applied inside the kernel.
347+
weight_for_sorted = topk_weights_flat[idx_inv]
348348
grad_sorted_tokens = ragged_gather(
349-
g_weighted,
350-
idx_inv,
349+
g_hidden_states_local,
350+
idx_inv // topk,
351351
shard_output_start[None],
352352
shard_output_end[None],
353+
weights=weight_for_sorted,
354+
has_weights=True,
353355
)
354356
else:
355357
# Slice the inverse permutation to match the packed local buffer.
356-
idx_inv = jnp.argsort(topk_argsort_revert_indices)
357358
padded_idx_inv = jnp.pad(idx_inv, (0, buffer_size))
358359
sliced_idx_inv = jax.lax.dynamic_slice_in_dim(padded_idx_inv, shard_output_start, buffer_size, axis=0)
359360
gather_end = jnp.minimum(shard_output_end - shard_output_start, buffer_size)
361+
# Slice the per-slot routing weights to match the packed local buffer.
362+
padded_weights = jnp.pad(topk_weights_flat[idx_inv], (0, buffer_size))
363+
sliced_weights = jax.lax.dynamic_slice_in_dim(padded_weights, shard_output_start, buffer_size, axis=0)
360364
grad_sorted_tokens = ragged_gather(
361-
g_weighted,
362-
sliced_idx_inv,
365+
g_hidden_states_local,
366+
sliced_idx_inv // topk,
363367
jnp.int32(0)[None],
364368
gather_end[None],
369+
weights=sliced_weights,
370+
has_weights=True,
365371
)
366372
return grad_sorted_tokens, None, None, None
367373

@@ -409,7 +415,7 @@ def a2a_ragged_sort(inputs, sort_indices, valid_end):
409415
def _a2a_ragged_sort(inputs, sort_indices, valid_end):
410416
return _a2a_ragged_sort_fwd(inputs, sort_indices, valid_end)[0]
411417

412-
@jax.named_scope("local-ragged-gather-fwd")
418+
@jax.named_scope("local-ragged-sort-fwd")
413419
def _a2a_ragged_sort_fwd(inputs, sort_indices, valid_end):
414420
start = jnp.int32(0)
415421
end = valid_end.astype(jnp.int32) if hasattr(valid_end, "astype") else jnp.int32(valid_end)
@@ -420,7 +426,7 @@ def _a2a_ragged_sort_fwd(inputs, sort_indices, valid_end):
420426
res = (sort_indices, end, inputs.shape)
421427
return out, res
422428

423-
@jax.named_scope("local-ragged-gather-bwd")
429+
@jax.named_scope("local-ragged-sort-bwd")
424430
def _a2a_ragged_sort_bwd(res, g_out):
425431
sort_indices, end, _ = res
426432
n = sort_indices.shape[0]
@@ -474,7 +480,7 @@ def a2a_ragged_unsort(sorted_tokens, revert_indices, valid_end):
474480
def _a2a_ragged_unsort(sorted_tokens, revert_indices, valid_end):
475481
return _a2a_ragged_unsort_fwd(sorted_tokens, revert_indices, valid_end)[0]
476482

477-
@jax.named_scope("local-ragged-scatter-fwd")
483+
@jax.named_scope("local-ragged-unsort-fwd")
478484
def _a2a_ragged_unsort_fwd(sorted_tokens, revert_indices, valid_end):
479485
start = jnp.int32(0)
480486
end = valid_end.astype(jnp.int32) if hasattr(valid_end, "astype") else jnp.int32(valid_end)
@@ -490,7 +496,7 @@ def _a2a_ragged_unsort_fwd(sorted_tokens, revert_indices, valid_end):
490496
res = (revert_indices, end, sorted_tokens.shape, start)
491497
return out, res
492498

493-
@jax.named_scope("local-ragged-scatter-bwd")
499+
@jax.named_scope("local-ragged-unsort-bwd")
494500
def _a2a_ragged_unsort_bwd(res, g_out):
495501
revert_indices, end, sorted_tokens_shape, start = res
496502
# g_sorted_tokens[revert_indices[i]] = g_out[i] for i in [0, end).

tests/unit/moe_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -618,7 +618,7 @@ def _build_cfg(use_ragged_sort: bool):
618618
enable_checkpointing=False,
619619
model_name="mixtral-8x7b",
620620
override_model_config=True,
621-
base_emb_dim=256,
621+
base_emb_dim=2048, # emb dim should be a multiple of 1024 to fully exercise the kernel
622622
base_mlp_dim=256,
623623
base_moe_mlp_dim=256,
624624
dtype="bfloat16",

0 commit comments

Comments
 (0)