Skip to content

Commit 9c9e095

Browse files
Merge pull request #4080 from AI-Hypercomputer:chengnuojin-rs-g3
PiperOrigin-RevId: 927429580
2 parents 5d45f2e + e74d7ac commit 9c9e095

1 file changed

Lines changed: 11 additions & 3 deletions

File tree

src/maxtext/kernels/ragged/ragged_gather_reduce.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,18 @@
2929
if Version(jax.__version__) <= Version("0.10.0"):
3030
_OUT_KW = "out_shape"
3131
_SCRATCH_KW = "scratch_shapes"
32+
_COMPILER_PARAMS = {
33+
"use_tc_tiling_on_sc": True,
34+
"disable_bounds_checks": True,
35+
}
3236
else:
3337
_OUT_KW = "out_type"
3438
_SCRATCH_KW = "scratch_types"
39+
_COMPILER_PARAMS = {
40+
"use_tc_tiling_on_sc": True,
41+
"disable_bounds_checks": True,
42+
"needs_layout_passes": False,
43+
}
3544

3645

3746
# ceil up to the nearest multiple of b.
@@ -488,9 +497,8 @@ def ragged_gather_reduce(
488497
num_row_partitions=num_rows_partitions,
489498
num_column_partitions=num_column_partitions,
490499
),
491-
compiler_params=pltpu.CompilerParams(
492-
use_tc_tiling_on_sc=True,
493-
disable_bounds_checks=True,
500+
compiler_params=pltpu.CompilerParams( # pytype: disable=wrong-keyword-args
501+
**_COMPILER_PARAMS,
494502
),
495503
mesh=vector_mesh,
496504
name="sc_ragged_gather_reduce",

0 commit comments

Comments
 (0)