@@ -41,19 +41,22 @@ def main_kernel(
4141 end_ref : jax .Ref ,
4242 in_hbm_ref : jax .Ref ,
4343 indices_hbm_ref : jax .Ref ,
44+ weights_hbm_ref : jax .Ref ,
4445 # Outputs.
4546 out_hbm_ref : jax .Ref ,
4647 # Scratch.
4748 start_vmem_ref : jax .Ref ,
4849 end_vmem_ref : jax .Ref ,
4950 out_vmem_ref : jax .Ref ,
5051 indices_vmem_ref : jax .Ref ,
52+ weights_vmem_ref : jax .Ref ,
5153 sem_ref : jax .Ref ,
5254 * ,
5355 core_axis_name : str ,
5456 subcore_axis_name : str ,
57+ has_weights : bool ,
5558):
56- """Core ragged gather operation"""
59+ """Core ragged gather operation with per-row weighting. """
5760 tpu_info = pltpu .get_tpu_info ()
5861 sc_info = tpu_info .sparse_core
5962 assert sc_info is not None
@@ -109,9 +112,16 @@ def _():
109112 indices_hbm_ref .at [pl .ds (row_tile_start , num_simd_lanes )],
110113 indices_vmem_ref ,
111114 )
115+ if has_weights :
116+ pltpu .sync_copy (
117+ weights_hbm_ref .at [pl .ds (row_tile_start , num_simd_lanes )],
118+ weights_vmem_ref ,
119+ )
112120
113121 # HBM to VMEM transfer.
114122 indices = indices_vmem_ref [...]
123+ if has_weights :
124+ weights = weights_vmem_ref [...]
115125
116126 dtype = out_hbm_ref .dtype
117127 dtype_bits = jax .dtypes .itemsize_bits (dtype )
@@ -189,6 +199,50 @@ def dma_write_loop(col_vmem_start):
189199 row_dst = row_src // packing
190200 out_vmem_ref [row_dst , col_slice ] = out
191201
202+ # Apply per-row weights after unpacking if needed.
203+ # For packing == 1 (float32), we can apply directly.
204+ # For packing > 1 (bf16), the data is already packed; we apply below.
205+ if has_weights :
206+ if packing == 1 :
207+ # float32 path: data is already in float32 layout, one row per sublane.
208+ for col_compute_offset in range (0 , num_lanes , num_simd_lanes ):
209+ col_slice = pl .ds (col_vmem_start + col_compute_offset , num_simd_lanes )
210+ for row_vmem in range (num_simd_lanes ):
211+ data = out_vmem_ref [row_vmem , col_slice ]
212+ data_f32 = jax .lax .bitcast_convert_type (data , jnp .float32 )
213+ data_f32 = data_f32 * weights [row_vmem ]
214+ out_vmem_ref [row_vmem , col_slice ] = jax .lax .bitcast_convert_type (data_f32 , jnp .uint32 )
215+ else :
216+ # bf16 path: data is packed, packing=2. Each packed row contains 2
217+ # bf16 values from consecutive source rows. We need to unpack each
218+ # bf16 to float32, multiply by its weight, then repack.
219+ for col_compute_offset in range (0 , num_lanes , num_simd_lanes ):
220+ col_slice = pl .ds (col_vmem_start + col_compute_offset , num_simd_lanes )
221+ for row_dst in range (num_simd_lanes // packing ):
222+ packed_data = out_vmem_ref [row_dst , col_slice ]
223+ result = jnp .zeros_like (packed_data )
224+ for sub in range (packing ):
225+ row_src = row_dst * packing + sub
226+ # Extract the sub-element.
227+ shift_right = sub * dtype_bits
228+ shift_left = sub * dtype_bits
229+ elem = jnp .bitwise_right_shift (packed_data , shift_right )
230+ elem = jnp .bitwise_and (elem , 2 ** dtype_bits - 1 )
231+ # Convert bf16 bits to float32 for weighting.
232+ # bf16 is stored in the lower 16 bits; shift to upper 16 for
233+ # bitcast to float32.
234+ elem_f32 = jnp .bitwise_left_shift (elem , 16 )
235+ elem_f32 = jax .lax .bitcast_convert_type (elem_f32 , jnp .float32 )
236+ elem_f32 = elem_f32 * weights [row_src ]
237+ # Convert back: bitcast float32 -> uint32, shift right 16 to
238+ # get bf16 bits, then shift left to target position.
239+ elem_u32 = jax .lax .bitcast_convert_type (elem_f32 , jnp .uint32 )
240+ elem_bf16 = jnp .bitwise_right_shift (elem_u32 , 16 )
241+ elem_bf16 = jnp .bitwise_and (elem_bf16 , 2 ** dtype_bits - 1 )
242+ elem_bf16 = jnp .bitwise_left_shift (elem_bf16 , shift_left )
243+ result = jnp .bitwise_or (result , elem_bf16 )
244+ out_vmem_ref [row_dst , col_slice ] = result
245+
192246 # Start dma write.
193247 for row_vmem in range (num_simd_lanes // packing ):
194248 row_hbm = row_tile_start // packing + row_vmem
@@ -234,9 +288,31 @@ def calculate_col_size(hidden_size: int) -> int:
234288 return pl .cdiv (hidden_size , (num_cols * num_lanes )) * num_lanes
235289
236290
237- @jax .jit
238- def ragged_gather (x : jax .Array , indices : jax .Array , start : jax .Array , end : jax .Array ) -> jax .Array :
239- """Perform gather on indices within dynamic array start and end."""
291+ @functools .partial (jax .jit , static_argnames = ("has_weights" ,))
292+ def ragged_gather (
293+ x : jax .Array ,
294+ indices : jax .Array ,
295+ start : jax .Array ,
296+ end : jax .Array ,
297+ weights : jax .Array | None = None ,
298+ has_weights : bool = False ,
299+ ) -> jax .Array :
300+ """Perform gather on indices within dynamic array start and end.
301+
302+ Args:
303+ x: 2D input array of shape ``(input_size, hidden_size)``.
304+ indices: 1D array of gather indices.
305+ start: Scalar or 1D array indicating the start of the valid range.
306+ end: Scalar or 1D array indicating the end of the valid range.
307+ weights: Optional 1D array of per-row weights. When provided, each
308+ gathered row is multiplied by its corresponding weight inside the
309+ kernel, avoiding an extra HBM read-write pass.
310+ has_weights: Static bool flag indicating whether ``weights`` should be
311+ applied. Must be ``True`` when ``weights`` is not ``None``.
312+
313+ Returns:
314+ Gathered output of shape ``(indices_size, hidden_size)``.
315+ """
240316
241317 assert x .ndim == 2 , "Ragged gather only supports 2d inputs."
242318 assert indices .ndim == 1 , "Ragged gather only supports 1d indices."
@@ -257,7 +333,10 @@ def ragged_gather(x: jax.Array, indices: jax.Array, start: jax.Array, end: jax.A
257333 sc_info = pltpu .get_tpu_info ().sparse_core
258334 if sc_info is None :
259335 # Sparse core is not available. Fallback to regular gather.
260- return x [indices ]
336+ out = x [indices ]
337+ if has_weights :
338+ out = out * weights [:, None ]
339+ return out
261340
262341 hidden_size = x .shape [- 1 ]
263342 out_size = indices .size
@@ -271,6 +350,12 @@ def ragged_gather(x: jax.Array, indices: jax.Array, start: jax.Array, end: jax.A
271350 out_pad_size = pl .cdiv (out_size , block_size ) * block_size - out_size
272351 indices = jnp .pad (indices , ((0 , out_pad_size )))
273352
353+ if has_weights :
354+ weights = jnp .pad (weights , ((0 , out_pad_size )), constant_values = 1.0 )
355+ else :
356+ # Provide a dummy weights array; the kernel won't use it.
357+ weights = jnp .ones ((out_size + out_pad_size ,), dtype = jnp .float32 )
358+
274359 aligned_hidden_size = pl .cdiv (hidden_size , col_size ) * col_size
275360
276361 vector_mesh = plsc .VectorSubcoreMesh (
@@ -284,6 +369,7 @@ def ragged_gather(x: jax.Array, indices: jax.Array, start: jax.Array, end: jax.A
284369 main_kernel ,
285370 core_axis_name = vector_mesh .core_axis_name ,
286371 subcore_axis_name = vector_mesh .subcore_axis_name ,
372+ has_weights = has_weights ,
287373 ),
288374 compiler_params = pltpu .CompilerParams (
289375 use_tc_tiling_on_sc = True ,
@@ -298,7 +384,8 @@ def ragged_gather(x: jax.Array, indices: jax.Array, start: jax.Array, end: jax.A
298384 pltpu .VMEM ((num_simd_lanes ,), jnp .int32 ),
299385 pltpu .VMEM ((num_simd_lanes , col_size ), jnp .uint32 ),
300386 pltpu .VMEM ((num_simd_lanes ,), jnp .int32 ),
387+ pltpu .VMEM ((num_simd_lanes ,), jnp .float32 ),
301388 pltpu .SemaphoreType .DMA ((2 ,)),
302389 ],
303390 },
304- )(start , end , x , indices )[:out_size , :hidden_size ]
391+ )(start , end , x , indices , weights )[:out_size , :hidden_size ]
0 commit comments