@@ -267,6 +267,20 @@ def ragged_gather(x: jax.Array, indices: jax.Array, start: jax.Array, end: jax.A
267267
268268 aligned_hidden_size = pl .cdiv (hidden_size , col_size ) * col_size
269269
270+ # Cost estimate for the ragged gather kernel.
271+ # This is a pure data-movement kernel (gather/scatter), so flops = 0.
272+ # bytes_accessed accounts for reading input rows and writing output rows,
273+ # plus reading the indices array and start/end scalars.
274+ dtype_bytes = jax .dtypes .itemsize_bits (dtype ) // 8
275+ padded_out_size = out_size + (pl .cdiv (out_size , block_size ) * block_size - out_size )
276+ bytes_accessed = (
277+ padded_out_size * aligned_hidden_size * dtype_bytes # read from x
278+ + padded_out_size * aligned_hidden_size * dtype_bytes # write to output
279+ + indices .size * jnp .dtype (jnp .int32 ).itemsize # read indices
280+ + 2 * jnp .dtype (jnp .int32 ).itemsize # read start and end
281+ )
282+ cost_estimate = pl .CostEstimate (flops = 0 , transcendentals = 0 , bytes_accessed = bytes_accessed )
283+
270284 vector_mesh = plsc .VectorSubcoreMesh (
271285 num_cores = sc_info .num_cores ,
272286 num_subcores = sc_info .num_subcores ,
@@ -285,6 +299,7 @@ def ragged_gather(x: jax.Array, indices: jax.Array, start: jax.Array, end: jax.A
285299 ),
286300 mesh = vector_mesh ,
287301 name = "sc_ragged_gather" ,
302+ cost_estimate = cost_estimate ,
288303 ** {
289304 _OUT_KW : jax .ShapeDtypeStruct ((out_size + out_pad_size , aligned_hidden_size ), dtype ),
290305 _SCRATCH_KW : [
0 commit comments