Skip to content

Commit 1fbbdbb

Browse files
committed
update gather reduce kernel aligning with tpu inference
1 parent 493fba6 commit 1fbbdbb

3 files changed

Lines changed: 194 additions & 84 deletions

File tree

src/maxtext/kernels/ragged/ragged_gather.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,20 @@ def ragged_gather(x: jax.Array, indices: jax.Array, start: jax.Array, end: jax.A
267267

268268
aligned_hidden_size = pl.cdiv(hidden_size, col_size) * col_size
269269

270+
# Cost estimate for the ragged gather kernel.
271+
# This is a pure data-movement kernel (gather/scatter), so flops = 0.
272+
# bytes_accessed accounts for reading input rows and writing output rows,
273+
# plus reading the indices array and start/end scalars.
274+
dtype_bytes = jax.dtypes.itemsize_bits(dtype) // 8
275+
padded_out_size = out_size + (pl.cdiv(out_size, block_size) * block_size - out_size)
276+
bytes_accessed = (
277+
padded_out_size * aligned_hidden_size * dtype_bytes # read from x
278+
+ padded_out_size * aligned_hidden_size * dtype_bytes # write to output
279+
+ indices.size * jnp.dtype(jnp.int32).itemsize # read indices
280+
+ 2 * jnp.dtype(jnp.int32).itemsize # read start and end
281+
)
282+
cost_estimate = pl.CostEstimate(flops=0, transcendentals=0, bytes_accessed=bytes_accessed)
283+
270284
vector_mesh = plsc.VectorSubcoreMesh(
271285
num_cores=sc_info.num_cores,
272286
num_subcores=sc_info.num_subcores,
@@ -285,6 +299,7 @@ def ragged_gather(x: jax.Array, indices: jax.Array, start: jax.Array, end: jax.A
285299
),
286300
mesh=vector_mesh,
287301
name="sc_ragged_gather",
302+
cost_estimate=cost_estimate,
288303
**{
289304
_OUT_KW: jax.ShapeDtypeStruct((out_size + out_pad_size, aligned_hidden_size), dtype),
290305
_SCRATCH_KW: [

0 commit comments

Comments
 (0)