Skip to content

Commit da5a5c3

Browse files
committed
Support specifying custom tile sizes for forward and backward passes of Tokamax GMM in MaxText
1 parent 58ffd43 commit da5a5c3

3 files changed

Lines changed: 65 additions & 16 deletions

File tree

src/maxtext/configs/types.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -557,6 +557,9 @@ class Attention(BaseModel):
557557
False,
558558
description="Whether to use the Tokamax library for GMM kernel implementation.",
559559
)
560+
tokamax_gmm_autotune: bool = Field(
561+
False, description="Whether to use tokamax auto-tuner for GMM."
562+
)
560563
ragged_block_size: int = Field(256, description="Block size for ragged attention.")
561564
enable_padding_causal_mask: bool = Field(True, description="Temporary flag for TE padding.")
562565
use_tokamax_splash: bool = Field(False, description="Whether to use tokamax splash attention.")

src/maxtext/layers/moe.py

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@
4343
from qwix.contrib.sparsity import sparsity_module
4444
import qwix.pallas as qpl
4545
import tokamax
46+
from tokamax import config as tokamax_config
47+
from tokamax._src.ops.ragged_dot.pallas_mosaic_tpu import PallasMosaicTpuRaggedDot, Config
4648

4749
set_xla_metadata = xla_metadata.set_xla_metadata
4850

@@ -1104,14 +1106,35 @@ def gmm(inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_a
11041106
weight_gather_axes=weight_gather_axes,
11051107
)
11061108
else: # tokamax (unquantized)
1107-
output = tokamax.ragged_dot(
1108-
lhs=inputs,
1109-
rhs=kernel,
1110-
group_sizes=tokamax_group_sizes,
1111-
precision=jax.lax.Precision.DEFAULT,
1112-
preferred_element_type=self.dtype,
1113-
implementation="mosaic",
1114-
)
1109+
if self.config.tokamax_gmm_autotune:
1110+
with tokamax_config.autotuning_cache_miss_fallback("heuristics"):
1111+
output = tokamax.ragged_dot(
1112+
lhs=inputs,
1113+
rhs=kernel,
1114+
group_sizes=tokamax_group_sizes,
1115+
precision=jax.lax.Precision.DEFAULT,
1116+
preferred_element_type=self.dtype,
1117+
implementation="mosaic",
1118+
)
1119+
else:
1120+
custom_impl = PallasMosaicTpuRaggedDot(
1121+
config=Config(
1122+
# Forward Pass
1123+
gmm_tiling=(tiling[0], tiling[1], tiling[2]),
1124+
# Backward DLHS Pass
1125+
gmm_rhs_transpose_tiling=(tiling[3], tiling[4], tiling[5]),
1126+
# Backward DRHS Pass
1127+
tgmm_tiling=(tiling[6], tiling[7], tiling[8]),
1128+
)
1129+
)
1130+
output = tokamax.ragged_dot(
1131+
lhs=inputs,
1132+
rhs=kernel,
1133+
group_sizes=tokamax_group_sizes,
1134+
precision=jax.lax.Precision.DEFAULT,
1135+
preferred_element_type=self.dtype,
1136+
implementation=custom_impl,
1137+
)
11151138
elif self.config.megablox: # Older forked megablox
11161139
output = mblx.gmm(
11171140
lhs=inputs,

src/maxtext/models/deepseek_batchsplit_fp8.py

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
from maxtext.layers import quantizations
3030
import qwix.pallas as qpl
3131
import tokamax
32+
from tokamax import config as tokamax_config
33+
from tokamax._src.ops.ragged_dot.pallas_mosaic_tpu import PallasMosaicTpuRaggedDot, Config
3234

3335

3436
@functools.partial(
@@ -962,14 +964,35 @@ def gmm(
962964
qwix_rule=quantizations.get_fp8_full_qwix_rule_w_sparsity(config),
963965
)
964966
else:
965-
output = tokamax.ragged_dot(
966-
lhs=inputs,
967-
rhs=kernel,
968-
group_sizes=tokamax.RaggedDotGroupSizes(group_sizes, len(inputs)),
969-
precision=jax.lax.Precision.DEFAULT,
970-
preferred_element_type=preferred_element_type,
971-
implementation="mosaic",
972-
)
967+
if config.tokamax_gmm_autotune:
968+
with tokamax_config.autotuning_cache_miss_fallback("heuristics"):
969+
output = tokamax.ragged_dot(
970+
lhs=inputs,
971+
rhs=kernel,
972+
group_sizes=tokamax.RaggedDotGroupSizes(group_sizes, len(inputs)),
973+
precision=jax.lax.Precision.DEFAULT,
974+
preferred_element_type=preferred_element_type,
975+
implementation="mosaic",
976+
)
977+
else:
978+
custom_impl = PallasMosaicTpuRaggedDot(
979+
config=Config(
980+
# Forward Pass
981+
gmm_tiling=(tiling[0], tiling[1], tiling[2]),
982+
# Backward DLHS Pass
983+
gmm_rhs_transpose_tiling=(tiling[3], tiling[4], tiling[5]),
984+
# Backward DRHS Pass
985+
tgmm_tiling=(tiling[6], tiling[7], tiling[8]),
986+
)
987+
)
988+
output = tokamax.ragged_dot(
989+
lhs=inputs,
990+
rhs=kernel,
991+
group_sizes=tokamax.RaggedDotGroupSizes(group_sizes, len(inputs)),
992+
precision=jax.lax.Precision.DEFAULT,
993+
preferred_element_type=preferred_element_type,
994+
implementation=custom_impl,
995+
)
973996
return output
974997

975998
gmm_fn = functools.partial(gmm, group_sizes=group_sizes, preferred_element_type=config.dtype)

0 commit comments

Comments (0)