File tree Expand file tree Collapse file tree
src/maxtext/kernels/ragged Expand file tree Collapse file tree Original file line number Diff line number Diff line change 2929if Version (jax .__version__ ) <= Version ("0.10.0" ):
3030 _OUT_KW = "out_shape"
3131 _SCRATCH_KW = "scratch_shapes"
32+ _COMPILER_PARAMS = {
33+ "use_tc_tiling_on_sc" : True ,
34+ "disable_bounds_checks" : True ,
35+ }
3236else :
3337 _OUT_KW = "out_type"
3438 _SCRATCH_KW = "scratch_types"
39+ _COMPILER_PARAMS = {
40+ "use_tc_tiling_on_sc" : True ,
41+ "disable_bounds_checks" : True ,
42+ "needs_layout_passes" : False ,
43+ }
3544
3645
3746# ceil up to the nearest multiple of b.
@@ -488,9 +497,8 @@ def ragged_gather_reduce(
488497 num_row_partitions = num_rows_partitions ,
489498 num_column_partitions = num_column_partitions ,
490499 ),
491- compiler_params = pltpu .CompilerParams (
492- use_tc_tiling_on_sc = True ,
493- disable_bounds_checks = True ,
500+ compiler_params = pltpu .CompilerParams ( # pytype: disable=wrong-keyword-args
501+ ** _COMPILER_PARAMS ,
494502 ),
495503 mesh = vector_mesh ,
496504 name = "sc_ragged_gather_reduce" ,
You can’t perform that action at this time.
0 commit comments