Skip to content

Commit 7709ab5

Browse files
RissyRanNuojCheng
authored andcommitted
Update ragged gather reduce to offload two sparce cores
1 parent f47c3f4 commit 7709ab5

1 file changed

Lines changed: 3 additions & 2 deletions

File tree

src/maxtext/kernels/ragged/ragged_sort.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,8 +125,9 @@ def _ring_ragged_sort_bwd(res, g_out):
125125
# only iterates over the populated prefix, so we hand it the mask directly
126126
# rather than materializing a (mostly-zero) dense buffer ourselves.
127127
n = topk_argsort_revert_indices.shape[0]
128-
pos = jnp.arange(n)
129-
valid_rows_mask = (pos >= shard_output_start) & (pos < shard_output_end)
128+
valid_rows_mask = (topk_argsort_revert_indices >= shard_output_start) & (
129+
topk_argsort_revert_indices < shard_output_end
130+
)
130131
# The forward scatter-add over `token_indices_sorted` is equivalent to a
131132
# gather-reduce: each input token has exactly `topk` contributions located
132133
# at sorted positions `topk_argsort_revert_indices[t*topk:(t+1)*topk]`.

0 commit comments

Comments
 (0)