@@ -99,9 +99,10 @@ def main_kernel(
9999 # Inputs.
100100 num_rows_per_row_partition_ref : jax .Ref ,
101101 in_hbm_ref : jax .Ref ,
102- src_indices_hbm_ref : jax .Ref ,
102+ indices_hbm_ref : jax .Ref ,
103103 dst_indices_hbm_ref : jax .Ref ,
104104 topk_weights_hbm_ref : jax .Ref ,
105+ sorted_by_validity_hbm_ref : jax .Ref ,
105106 # Outputs.
106107 out_hbm_ref : jax .Ref ,
107108 # Scratch.
@@ -111,6 +112,7 @@ def main_kernel(
111112 src_indices_vmem_ref : jax .Ref ,
112113 dst_indices_vmem_ref : jax .Ref ,
113114 topk_weights_vmem_ref : jax .Ref ,
115+ sorted_by_validity_vmem_ref : jax .Ref ,
114116 sem_ref : jax .Ref ,
115117 * ,
116118 core_axis_name : str ,
@@ -176,9 +178,9 @@ def row_loop(row_block_id):
176178 dma_list = []
177179 dma_list .append (
178180 pltpu .make_async_copy (
179- src_indices_hbm_ref .at [pl .ds (row_tile_start ,
180- num_simd_lanes )],
181- src_indices_vmem_ref ,
181+ sorted_by_validity_hbm_ref .at [pl .ds (
182+ row_tile_start , num_simd_lanes )],
183+ sorted_by_validity_vmem_ref ,
182184 recv_sem ,
183185 ))
184186 dma_list .append (
@@ -188,13 +190,22 @@ def row_loop(row_block_id):
188190 dst_indices_vmem_ref ,
189191 recv_sem ,
190192 ))
193+ jax .tree .map (lambda x : x .start (), dma_list )
194+ jax .tree .map (lambda x : x .wait (), dma_list )
195+
196+ dma_list = []
191197 dma_list .append (
192198 pltpu .make_async_copy (
193- topk_weights_hbm_ref .at [pl .ds (row_tile_start ,
194- num_simd_lanes )],
199+ topk_weights_hbm_ref .at [sorted_by_validity_vmem_ref ],
195200 topk_weights_vmem_ref ,
196201 recv_sem ,
197202 ))
203+ dma_list .append (
204+ pltpu .make_async_copy (
205+ indices_hbm_ref .at [sorted_by_validity_vmem_ref ],
206+ src_indices_vmem_ref ,
207+ recv_sem ,
208+ ))
198209 jax .tree .map (lambda x : x .start (), dma_list )
199210 jax .tree .map (lambda x : x .wait (), dma_list )
200211
@@ -227,9 +238,12 @@ def row_loop(row_block_id):
227238
228239 # VMEM to HBM transfer.
229240 # Use dynamic loop to minimize register spills.
241+ @pl .loop (0 ,
242+ col_size ,
243+ step = num_lanes ,
244+ init_carry = (prev_dst_row_hbm , ))
230245 @jax .named_scope ("dma_write_loop" )
231- def dma_write_loop (i , carry ):
232- col_vmem_start = i * num_lanes
246+ def dma_write_loop (col_vmem_start , carry ):
233247 col_hbm_start = col_start + col_vmem_start
234248
235249 for _ in range (num_simd_lanes ):
@@ -359,12 +373,6 @@ def dma_write_loop(i, carry):
359373
360374 return carry
361375
362- jax .lax .fori_loop (
363- 0 ,
364- pl .cdiv (col_size , num_lanes ),
365- dma_write_loop ,
366- init_val = (prev_dst_row_hbm , ),
367- )
368376 # Wait for dma write to finish.
369377 for _ in range (0 , col_size , num_lanes ):
370378 for _ in range (num_simd_lanes ):
@@ -380,12 +388,11 @@ def dma_write_loop(i, carry):
380388# TODO(gxd): investigate if we can make the preprocessing more efficient.
381389def _preprocess (
382390 indices : jax .Array ,
383- topk_weights : jax .Array ,
384391 valid_rows_mask : jax .Array ,
385392 reduce_group_size : int ,
386393 num_row_partitions : int ,
387394 num_simd_lanes : int ,
388- ) -> tuple [jax .Array , jax .Array , jax .Array , jax .Array , jax . Array ]:
395+ ) -> tuple [jax .Array , jax .Array , jax .Array , jax .Array ]:
389396 """Preprocesses indices for ragged gather reduce."""
390397 assert indices .ndim == 1 , "Ragged scatter only supports 1d indices."
391398
@@ -403,12 +410,10 @@ def _preprocess(
403410 ) * row_partition_size )
404411 sorted_by_validity = sorted_by_validity .reshape (- 1 )
405412
406- src_indices = indices [sorted_by_validity ]
407413 # `reduce_group_size` source rows are mapped (and reduced) to the same output
408414 # row.
409415 dst_indices = sorted_by_validity // reduce_group_size
410- topk_weights = topk_weights [sorted_by_validity ]
411- topk_weights = topk_weights .astype (jnp .float32 )
416+ sorted_by_validity = sorted_by_validity .astype (jnp .int32 )
412417
413418 num_src_rows_per_row_partition = jnp .sum (valid_rows_mask , axis = - 1 )
414419 assert num_row_partitions <= num_simd_lanes
@@ -421,9 +426,8 @@ def _preprocess(
421426 mask = jnp .any (valid_rows_mask .reshape (- 1 , reduce_group_size ), axis = - 1 )
422427
423428 return (
424- src_indices ,
425429 dst_indices ,
426- topk_weights ,
430+ sorted_by_validity ,
427431 num_src_rows_per_row_partition ,
428432 mask ,
429433 )
@@ -521,14 +525,12 @@ def ragged_gather_reduce(
521525 col_size = x .shape [- 1 ] // num_column_partitions
522526
523527 (
524- src_indices ,
525528 dst_indices ,
526- topk_weights ,
529+ sorted_by_validity ,
527530 num_src_rows_per_row_partition ,
528531 mask ,
529532 ) = _preprocess (
530533 indices ,
531- topk_weights ,
532534 valid_rows_mask ,
533535 reduce_group_size ,
534536 num_row_partitions ,
@@ -566,12 +568,19 @@ def ragged_gather_reduce(
566568 pltpu .VMEM ((num_simd_lanes , ), jnp .int32 ),
567569 pltpu .VMEM ((num_simd_lanes , ), jnp .int32 ),
568570 pltpu .VMEM ((num_simd_lanes , ), jnp .float32 ),
571+ pltpu .VMEM ((num_simd_lanes , ), jnp .int32 ),
569572 pltpu .SemaphoreType .DMA ((2 , )),
570573 ],
571574 mesh = vector_mesh ,
572575 name = "sc_ragged_gather_reduce" ,
573- )(num_src_rows_per_row_partition , x , src_indices , dst_indices ,
574- topk_weights )
576+ )(
577+ num_src_rows_per_row_partition ,
578+ x ,
579+ indices ,
580+ dst_indices ,
581+ topk_weights .astype (jnp .float32 ),
582+ sorted_by_validity ,
583+ )
575584
576585 # If there is no valid source row in a reduce group, set that group's output
577586 # to zero.
0 commit comments