
Commit 682db7f

Enable Tokamax GMM with autotuning fallback in MaxText
1 parent 661a153 commit 682db7f

3 files changed

Lines changed: 43 additions & 16 deletions


src/maxtext/configs/types.py

Lines changed: 1 addition & 0 deletions
@@ -561,6 +561,7 @@ class Attention(BaseModel):
       False,
       description="Whether to use the Tokamax library for GMM kernel implementation.",
   )
+  tokamax_gmm_autotune: bool = Field(False, description="Whether to use tokamax auto-tuner for GMM.")
   ragged_block_size: int = Field(256, description="Block size for ragged attention.")
   enable_padding_causal_mask: bool = Field(True, description="Temporary flag for TE padding.")
   use_tokamax_splash: bool = Field(False, description="Whether to use tokamax splash attention.")
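
As a quick illustration of how the new flag is consumed, here is a minimal, self-contained pydantic sketch; AttentionSketch is a hypothetical stand-in for the real Attention model (which has many more fields), and only the two fields shown are taken from the diff above.

from pydantic import BaseModel, Field

# Stand-in for maxtext.configs.types.Attention, reduced to the fields relevant here.
class AttentionSketch(BaseModel):
  tokamax_gmm_autotune: bool = Field(False, description="Whether to use tokamax auto-tuner for GMM.")
  ragged_block_size: int = Field(256, description="Block size for ragged attention.")

# Enabling the new flag, e.g. from an overrides dict built from a config file or CLI.
cfg = AttentionSketch(**{"tokamax_gmm_autotune": True})
assert cfg.tokamax_gmm_autotune and cfg.ragged_block_size == 256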

src/maxtext/layers/moe.py

Lines changed: 21 additions & 8 deletions
@@ -43,6 +43,8 @@
 from qwix.contrib.sparsity import sparsity_module
 import qwix.pallas as qpl
 import tokamax
+from tokamax import config as tokamax_config
+from tokamax._src.ops.ragged_dot.pallas_mosaic_tpu import PallasMosaicTpuRaggedDot, Config
 
 set_xla_metadata = xla_metadata.set_xla_metadata
 
@@ -1121,14 +1123,25 @@ def gmm(inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_a
           weight_gather_axes=weight_gather_axes,
         )
       else:  # tokamax (unquantized)
-        output = tokamax.ragged_dot(
-            lhs=inputs,
-            rhs=kernel,
-            group_sizes=tokamax_group_sizes,
-            precision=jax.lax.Precision.DEFAULT,
-            preferred_element_type=self.dtype,
-            implementation="mosaic",
-        )
+        if self.config.tokamax_gmm_autotune:
+          with tokamax_config.autotuning_cache_miss_fallback("autotune"):
+            output = tokamax.ragged_dot(
+                lhs=inputs,
+                rhs=kernel,
+                group_sizes=tokamax_group_sizes,
+                precision=jax.lax.Precision.DEFAULT,
+                preferred_element_type=self.dtype,
+                implementation="mosaic",
+            )
+        else:
+          output = tokamax.ragged_dot(
+              lhs=inputs,
+              rhs=kernel,
+              group_sizes=tokamax_group_sizes,
+              precision=jax.lax.Precision.DEFAULT,
+              preferred_element_type=self.dtype,
+              implementation="mosaic",
+          )
     elif self.config.megablox:  # Older forked megablox
       output = mblx.gmm(
           lhs=inputs,
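
To make the new code path concrete, here is a minimal sketch of the call pattern with illustrative shapes. The real inputs, kernel, and tokamax_group_sizes come from the surrounding MoE layer (and in the fp8 model below the group sizes are wrapped in tokamax.RaggedDotGroupSizes); the plain int32 array used here and the presumed effect of the "autotune" fallback policy are assumptions for illustration, and implementation="mosaic" targets TPU, so this is a shape-level sketch rather than something to run as-is.

import jax
import jax.numpy as jnp
import tokamax
from tokamax import config as tokamax_config

# Illustrative shapes: 8 tokens of width 4 routed across 2 expert groups, outputs of width 4.
inputs = jnp.ones((8, 4), dtype=jnp.bfloat16)        # [tokens, hidden]
kernel = jnp.ones((2, 4, 4), dtype=jnp.bfloat16)     # [groups, hidden, out]
group_sizes = jnp.array([5, 3], dtype=jnp.int32)     # tokens assigned to each group

# When tokamax_gmm_autotune is set, the call runs under the cache-miss fallback,
# so a missing autotuning entry presumably triggers the "autotune" policy instead
# of the library default; otherwise ragged_dot is called exactly as before.
with tokamax_config.autotuning_cache_miss_fallback("autotune"):
  output = tokamax.ragged_dot(
      lhs=inputs,
      rhs=kernel,
      group_sizes=group_sizes,
      precision=jax.lax.Precision.DEFAULT,
      preferred_element_type=jnp.bfloat16,
      implementation="mosaic",
  )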

src/maxtext/models/deepseek_batchsplit_fp8.py

Lines changed: 21 additions & 8 deletions
@@ -29,6 +29,8 @@
 from maxtext.layers import quantizations
 import qwix.pallas as qpl
 import tokamax
+from tokamax import config as tokamax_config
+from tokamax._src.ops.ragged_dot.pallas_mosaic_tpu import PallasMosaicTpuRaggedDot, Config
 
 
 @functools.partial(
@@ -962,14 +964,25 @@ def gmm(
         qwix_rule=quantizations.get_fp8_full_qwix_rule_w_sparsity(config)[0],
       )
     else:
-      output = tokamax.ragged_dot(
-          lhs=inputs,
-          rhs=kernel,
-          group_sizes=tokamax.RaggedDotGroupSizes(group_sizes, len(inputs)),
-          precision=jax.lax.Precision.DEFAULT,
-          preferred_element_type=preferred_element_type,
-          implementation="mosaic",
-      )
+      if config.tokamax_gmm_autotune:
+        with tokamax_config.autotuning_cache_miss_fallback("autotune"):
+          output = tokamax.ragged_dot(
+              lhs=inputs,
+              rhs=kernel,
+              group_sizes=tokamax.RaggedDotGroupSizes(group_sizes, len(inputs)),
+              precision=jax.lax.Precision.DEFAULT,
+              preferred_element_type=preferred_element_type,
+              implementation="mosaic",
+          )
+      else:
+        output = tokamax.ragged_dot(
+            lhs=inputs,
+            rhs=kernel,
+            group_sizes=tokamax.RaggedDotGroupSizes(group_sizes, len(inputs)),
+            precision=jax.lax.Precision.DEFAULT,
+            preferred_element_type=preferred_element_type,
+            implementation="mosaic",
+        )
     return output
 
   gmm_fn = functools.partial(gmm, group_sizes=group_sizes, preferred_element_type=config.dtype)
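
In both files the two branches differ only in whether the ragged_dot call runs under the autotuning-fallback context manager. A compact way to read that control flow, reusing the names from the gmm function above (an illustration only, not the committed code), is:

import contextlib

# The flag only toggles the surrounding context; the ragged_dot call is identical.
ctx = (
    tokamax_config.autotuning_cache_miss_fallback("autotune")
    if config.tokamax_gmm_autotune
    else contextlib.nullcontext()
)
with ctx:
  output = tokamax.ragged_dot(
      lhs=inputs,
      rhs=kernel,
      group_sizes=tokamax.RaggedDotGroupSizes(group_sizes, len(inputs)),
      precision=jax.lax.Precision.DEFAULT,
      preferred_element_type=preferred_element_type,
      implementation="mosaic",
  )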
