We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent f47c3f4 commit 7709ab5 Copy full SHA for 7709ab5
1 file changed
src/maxtext/kernels/ragged/ragged_sort.py
@@ -125,8 +125,9 @@ def _ring_ragged_sort_bwd(res, g_out):
125
# only iterates over the populated prefix, so we hand it the mask directly
126
# rather than materializing a (mostly-zero) dense buffer ourselves.
127
n = topk_argsort_revert_indices.shape[0]
128
- pos = jnp.arange(n)
129
- valid_rows_mask = (pos >= shard_output_start) & (pos < shard_output_end)
+ valid_rows_mask = (topk_argsort_revert_indices >= shard_output_start) & (
+ topk_argsort_revert_indices < shard_output_end
130
+ )
131
# The forward scatter-add over `token_indices_sorted` is equivalent to a
132
# gather-reduce: each input token has exactly `topk` contributions located
133
# at sorted positions `topk_argsort_revert_indices[t*topk:(t+1)*topk]`.
0 commit comments