Skip to content

Commit 796ed9d

Browse files
committed
refactor(moe): Remove tokamax_gmm_autotune and unconditionally use custom GMM tile sizes with Pallas subclass
This change addresses review comments on PR #3779: 1. Replaces the brittle global monkey-patch and the separate WI/WO subclasses with a single, reusable `PallasMosaicTpuRaggedDotCustom` class defined in `moe.py`. 2. Implements a custom `__post_init__` in this subclass to ensure that JAX backward passes (VJP) preserve the tile configurations correctly. 3. Removes the configuration flag `tokamax_gmm_autotune` entirely (unconditionally enabling custom tile sizes when Tokamax GMM is active). 4. Updates the unit tests in `moe_test.py` to verify the new subclass-based tiling overrides. CONV=2c6843af-dcf7-403b-b67e-2fedd5f81b95
1 parent 504dcb0 commit 796ed9d

3 files changed

Lines changed: 108 additions & 101 deletions

File tree

src/maxtext/layers/moe.py

Lines changed: 67 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,46 @@
5555
DLHS_RAGGED_DOT_DIM_NUMS,
5656
DRHS_RAGGED_DOT_DIM_NUMS,
5757
)
58+
from tokamax._src.ops.ragged_dot import base
5859
from tokamax._src.ops import op
60+
import dataclasses
61+
62+
63+
@jax.tree_util.register_dataclass
@dataclasses.dataclass(frozen=True, kw_only=True, slots=True)
class PallasMosaicTpuRaggedDotCustom(PallasMosaicTpuRaggedDot):
  """`PallasMosaicTpuRaggedDot` variant with per-instance heuristic tile sizes.

  The heuristic config is built from one of three tile tuples, selected by the
  `ragged_dot_dimension_numbers` of the call: `fwd_tile` for the default
  dimension numbers, `dlhs_tile` for the DLHS ones and `drhs_tile` for the
  DRHS ones.

  Attributes:
    config: Optional explicit `Config`; defaults to `None`.
    fwd_tile: `(tile_m, tile_k, tile_n)` used for `DEFAULT_RAGGED_DOT_DIM_NUMS`.
    dlhs_tile: `(tile_m, tile_k, tile_n)` used for `DLHS_RAGGED_DOT_DIM_NUMS`.
    drhs_tile: `(tile_m, tile_k, tile_n)` used for `DRHS_RAGGED_DOT_DIM_NUMS`.
  """

  config: Config | None = None
  fwd_tile: tuple[int, int, int] = (128, 128, 128)
  dlhs_tile: tuple[int, int, int] = (128, 128, 128)
  drhs_tile: tuple[int, int, int] = (128, 128, 128)

  def __post_init__(self):
    """Install a default `vjp` whose ragged dots reuse this instance's tiles.

    Nothing is installed when `vjp` was already provided. Otherwise `vjp` is
    set to `base.vjp` with both `dlhs_ragged_dot` and `drhs_ragged_dot` bound
    to a callable that builds a new `PallasMosaicTpuRaggedDotCustom` carrying
    the same `qdtype`, `interpret` and tile tuples, so the backward ragged
    dots see the same tile settings as the forward one.
    """
    # Normalise the quantisation dtype to its name; `None` stays `None`.
    qdtype = self.qdtype if self.qdtype is None else jnp.dtype(self.qdtype).name
    if self.vjp is not None:
      return

    def backward_ragged_dot(*args, **kw):
      # A fresh instance is created per call and `vjp` is left unset on it,
      # so it goes through this same `__post_init__` path.
      return PallasMosaicTpuRaggedDotCustom(
          qdtype=qdtype,
          interpret=self.interpret,
          fwd_tile=self.fwd_tile,
          dlhs_tile=self.dlhs_tile,
          drhs_tile=self.drhs_tile,
      )(*args, **kw)

    # The dataclass is frozen, so the field has to be set via object.__setattr__.
    object.__setattr__(
        self,
        "vjp",
        functools.partial(
            base.vjp,
            dlhs_ragged_dot=backward_ragged_dot,
            drhs_ragged_dot=backward_ragged_dot,
        ),
    )

  def _get_heuristics_config(self, ba) -> Config:
    """Return the tile `Config` matching the call's ragged-dot dimension numbers.

    Args:
      ba: Bound arguments of the call; only `ba.arguments` is read, looking up
        `"ragged_dot_dimension_numbers"` and falling back to
        `DEFAULT_RAGGED_DOT_DIM_NUMS` when the key is absent.

    Returns:
      A `Config` built from `fwd_tile`, `dlhs_tile` or `drhs_tile`, or a
      default-constructed `Config()` when the dimension numbers match none of
      the three known layouts.
    """
    dims = ba.arguments.get("ragged_dot_dimension_numbers", DEFAULT_RAGGED_DOT_DIM_NUMS)
    if dims == DEFAULT_RAGGED_DOT_DIM_NUMS:
      tile = self.fwd_tile
    elif dims == DLHS_RAGGED_DOT_DIM_NUMS:
      tile = self.dlhs_tile
    elif dims == DRHS_RAGGED_DOT_DIM_NUMS:
      tile = self.drhs_tile
    else:
      return Config()
    # Tile tuples are ordered (tile_m, tile_k, tile_n).
    return Config(tile_m=tile[0], tile_k=tile[1], tile_n=tile[2])
97+
5998

6099
set_xla_metadata = xla_metadata.set_xla_metadata
61100

@@ -557,8 +596,7 @@ def __init__(
557596
):
558597
self.wo.value = self.wo.value * self.per_expert_scale.value[:, None, None]
559598

560-
# Monkey-patch Tokamax heuristics globally once
561-
_monkey_patch_tokamax_heuristics(self.config)
599+
562600

563601
def _maybe_shard_with_logical(self, inputs, logical_name):
564602
return maybe_shard_with_logical(
@@ -1095,6 +1133,18 @@ def sparse_matmul(
10951133
wo_bias,
10961134
):
10971135
"""Perform sparse matrix multiplication of inputs and Experts."""
1136+
config = self.config
1137+
1138+
gmm_impl_wi = PallasMosaicTpuRaggedDotCustom(
1139+
fwd_tile=(config.wi_tile_fwd_batch_seq, config.wi_tile_fwd_embed_dim, config.wi_tile_fwd_mlp_dim),
1140+
dlhs_tile=(config.wi_tile_dlhs_batch_seq, config.wi_tile_dlhs_mlp_dim, config.wi_tile_dlhs_embed_dim),
1141+
drhs_tile=(config.wi_tile_drhs_batch_seq, config.wi_tile_drhs_embed_dim, config.wi_tile_drhs_mlp_dim),
1142+
)
1143+
gmm_impl_wo = PallasMosaicTpuRaggedDotCustom(
1144+
fwd_tile=(config.wo_tile_fwd_batch_seq, config.wo_tile_fwd_mlp_dim, config.wo_tile_fwd_embed_dim),
1145+
dlhs_tile=(config.wo_tile_dlhs_batch_seq, config.wo_tile_dlhs_embed_dim, config.wo_tile_dlhs_mlp_dim),
1146+
drhs_tile=(config.wo_tile_drhs_batch_seq, config.wo_tile_drhs_mlp_dim, config.wo_tile_drhs_embed_dim),
1147+
)
10981148

10991149
def jax_ragged_dot_gmm(inputs, kernel, tiling, group_sizes, expert_assignments, padding_amount):
11001150
"""Execute jax.lax.ragged_dot, with potential quantization"""
@@ -1139,6 +1189,15 @@ def jax_ragged_dot_gmm(inputs, kernel, tiling, group_sizes, expert_assignments,
11391189
output *= scales
11401190
return output
11411191

1192+
def get_gmm_group_sizes(inputs, kernel, ep):
  """Return perfectly balanced group sizes, one entry per local expert.

  Each local expert is given an equal share of the local tokens, adjusted for
  expert parallelism.

  Note: this assumes the inputs are ragged and padded to the worst-case size
  (which is generally a factor of EP larger than perfectly balanced). This is
  why the per-expert share is additionally divided by `ep`.

  Args:
    inputs: Object whose `shape[0]` is the (padded) number of input rows.
    kernel: Object whose `shape[0]` is the number of groups to produce.
    ep: Expert-parallelism factor used as the extra integer divisor.

  Returns:
    A tuple of length `kernel.shape[0]` in which every entry equals
    `inputs.shape[0] // kernel.shape[0] // ep`.
  """
  padded_rows = inputs.shape[0]
  num_groups = kernel.shape[0]
  rows_per_group = padded_rows // num_groups // ep
  return (rows_per_group,) * num_groups
1200+
11421201
def get_tokamax_group_sizes(group_sizes, inputs, kernel):
11431202
# TODO (b/491979205) pipeline fsdp ag per repeat fails tokamax gmm
11441203
if self.config.use_qwix_quantization or (
@@ -1151,7 +1210,7 @@ def get_tokamax_group_sizes(group_sizes, inputs, kernel):
11511210
ep = self.get_expert_parallelism_size()
11521211
return tokamax.RaggedDotGroupSizes(
11531212
group_sizes,
1154-
(inputs.shape[0] // kernel.shape[0] // ep,) * kernel.shape[0],
1213+
get_gmm_group_sizes(inputs, kernel, ep),
11551214
)
11561215

11571216
def get_quantization_dtypes():
@@ -1162,7 +1221,7 @@ def get_quantization_dtypes():
11621221
rhs_quantize_dtype = quant_dg.fwd.dg_quantizer.rhs.numerics.get_dtype()
11631222
return lhs_quantize_dtype, rhs_quantize_dtype
11641223

1165-
def gmm(inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_axes):
1224+
def gmm(inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_axes, gmm_impl=None):
11661225
if inputs.shape[0] != expert_assignments.shape[0]:
11671226
raise ValueError("The number of input tokens must match the number of expert assignments!")
11681227

@@ -1196,7 +1255,7 @@ def gmm(inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_a
11961255
group_sizes=tokamax_group_sizes,
11971256
precision=jax.lax.Precision.DEFAULT,
11981257
preferred_element_type=self.dtype,
1199-
implementation="mosaic",
1258+
implementation="mosaic" if gmm_impl is None else [gmm_impl],
12001259
)
12011260
elif self.config.megablox: # Older forked megablox
12021261
output = mblx.gmm(
@@ -1485,6 +1544,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
14851544
w0,
14861545
tiling=wi_tile_size,
14871546
weight_gather_axes=wi_gather_axes,
1547+
gmm_impl=gmm_impl_wi,
14881548
)
14891549
if self.get_tensor_transpose_parallelism_size() > 1:
14901550
layer_w0 = jax.lax.psum(layer_w0, "tensor_transpose")
@@ -1497,6 +1557,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
14971557
w1,
14981558
tiling=wi_tile_size,
14991559
weight_gather_axes=wi_gather_axes,
1560+
gmm_impl=gmm_impl_wi,
15001561
)
15011562
if self.get_tensor_transpose_parallelism_size() > 1:
15021563
layer_w1 = jax.lax.psum(layer_w1, "tensor_transpose")
@@ -1510,6 +1571,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
15101571
wo,
15111572
tiling=wo_tile_size,
15121573
weight_gather_axes=wo_gather_axes,
1574+
gmm_impl=gmm_impl_wo,
15131575
)
15141576
if self.get_tensor_parallelism_size() > 1:
15151577
intermediate_output = jax.lax.psum_scatter(
@@ -2553,74 +2615,3 @@ def get_routed_and_shared_moe(
25532615
abstract_init=False,
25542616
)
25552617
return module
2556-
2557-
2558-
_heuristics_patched = False
2559-
2560-
2561-
def _monkey_patch_tokamax_heuristics(config, force=False):
2562-
"""Globally monkey-patches Tokamax GMM heuristics with manual tiling overrides."""
2563-
global _heuristics_patched
2564-
if _heuristics_patched and not force:
2565-
return
2566-
2567-
def custom_heuristics(self, ba: op.BoundArguments) -> Config:
2568-
lhs, rhs = ba.arguments["lhs"], ba.arguments["rhs"]
2569-
dims = ba.arguments.get("ragged_dot_dimension_numbers", DEFAULT_RAGGED_DOT_DIM_NUMS)
2570-
2571-
is_wo = False
2572-
if dims == DEFAULT_RAGGED_DOT_DIM_NUMS:
2573-
is_wo = rhs.shape[1] == config.base_mlp_dim
2574-
elif dims == DLHS_RAGGED_DOT_DIM_NUMS:
2575-
is_wo = rhs.shape[2] == config.base_emb_dim
2576-
elif dims == DRHS_RAGGED_DOT_DIM_NUMS:
2577-
is_wo = lhs.shape[1] == config.base_mlp_dim
2578-
2579-
if is_wo:
2580-
# Return wo tile sizes
2581-
if dims == DEFAULT_RAGGED_DOT_DIM_NUMS:
2582-
return Config(
2583-
tile_m=config.wo_tile_fwd_batch_seq,
2584-
tile_k=config.wo_tile_fwd_mlp_dim,
2585-
tile_n=config.wo_tile_fwd_embed_dim,
2586-
)
2587-
elif dims == DLHS_RAGGED_DOT_DIM_NUMS:
2588-
return Config(
2589-
tile_m=config.wo_tile_dlhs_batch_seq,
2590-
tile_k=config.wo_tile_dlhs_embed_dim,
2591-
tile_n=config.wo_tile_dlhs_mlp_dim,
2592-
)
2593-
elif dims == DRHS_RAGGED_DOT_DIM_NUMS:
2594-
return Config(
2595-
tile_m=config.wo_tile_drhs_batch_seq,
2596-
tile_k=config.wo_tile_drhs_mlp_dim,
2597-
tile_n=config.wo_tile_drhs_embed_dim,
2598-
)
2599-
else:
2600-
# Return wi tile sizes
2601-
if dims == DEFAULT_RAGGED_DOT_DIM_NUMS:
2602-
return Config(
2603-
tile_m=config.wi_tile_fwd_batch_seq,
2604-
tile_k=config.wi_tile_fwd_embed_dim,
2605-
tile_n=config.wi_tile_fwd_mlp_dim,
2606-
)
2607-
elif dims == DLHS_RAGGED_DOT_DIM_NUMS:
2608-
return Config(
2609-
tile_m=config.wi_tile_dlhs_batch_seq,
2610-
tile_k=config.wi_tile_dlhs_mlp_dim,
2611-
tile_n=config.wi_tile_dlhs_embed_dim,
2612-
)
2613-
elif dims == DRHS_RAGGED_DOT_DIM_NUMS:
2614-
return Config(
2615-
tile_m=config.wi_tile_drhs_batch_seq,
2616-
tile_k=config.wi_tile_drhs_embed_dim,
2617-
tile_n=config.wi_tile_drhs_mlp_dim,
2618-
)
2619-
2620-
return Config()
2621-
2622-
# Apply class-level monkey patch!
2623-
# pylint: disable=protected-access
2624-
PallasMosaicTpuRaggedDot._get_heuristics_config = custom_heuristics
2625-
_heuristics_patched = True
2626-
print("[TOKAMAX_PATCH] Successfully monkey-patched Tokamax GMM heuristics globally!")

src/maxtext/models/deepseek_batchsplit_fp8.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from maxtext.layers import quantizations
3030
import qwix.pallas as qpl
3131
import tokamax
32-
from maxtext.layers.moe import _monkey_patch_tokamax_heuristics
32+
3333

3434

3535
@functools.partial(
@@ -833,9 +833,6 @@ def moe(
833833
config,
834834
quant,
835835
):
836-
"""Performs dropless MoE with tensor/expert parallelism."""
837-
# Monkey-patch Tokamax heuristics globally once
838-
_monkey_patch_tokamax_heuristics(config)
839836
xs, ys = list(zip(*inputs))
840837
ys = with_data_parallel_constraint(
841838
process_activations(
@@ -943,6 +940,16 @@ def unroute(
943940

944941
def compute(x, w0, w1, wo, group_sizes, weights, *, config, mesh):
945942
"""Processes routed tokens through the MLP."""
943+
gmm_impl_wi = moe_lib.PallasMosaicTpuRaggedDotCustom(
944+
fwd_tile=(config.wi_tile_fwd_batch_seq, config.wi_tile_fwd_embed_dim, config.wi_tile_fwd_mlp_dim),
945+
dlhs_tile=(config.wi_tile_dlhs_batch_seq, config.wi_tile_dlhs_mlp_dim, config.wi_tile_dlhs_embed_dim),
946+
drhs_tile=(config.wi_tile_drhs_batch_seq, config.wi_tile_drhs_embed_dim, config.wi_tile_drhs_mlp_dim),
947+
)
948+
gmm_impl_wo = moe_lib.PallasMosaicTpuRaggedDotCustom(
949+
fwd_tile=(config.wo_tile_fwd_batch_seq, config.wo_tile_fwd_mlp_dim, config.wo_tile_fwd_embed_dim),
950+
dlhs_tile=(config.wo_tile_dlhs_batch_seq, config.wo_tile_dlhs_embed_dim, config.wo_tile_dlhs_mlp_dim),
951+
drhs_tile=(config.wo_tile_drhs_batch_seq, config.wo_tile_drhs_mlp_dim, config.wo_tile_drhs_embed_dim),
952+
)
946953

947954
def gmm(
948955
inputs,
@@ -951,6 +958,7 @@ def gmm(
951958
group_sizes,
952959
preferred_element_type,
953960
weight_gather_axes,
961+
gmm_impl=None,
954962
):
955963
if config.use_qwix_quantization:
956964
output = megablox.gmm(
@@ -971,7 +979,7 @@ def gmm(
971979
group_sizes=tokamax.RaggedDotGroupSizes(group_sizes, len(inputs)),
972980
precision=jax.lax.Precision.DEFAULT,
973981
preferred_element_type=preferred_element_type,
974-
implementation="mosaic",
982+
implementation="mosaic" if gmm_impl is None else [gmm_impl],
975983
)
976984
return output
977985

@@ -1031,6 +1039,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
10311039
w01,
10321040
tiling=wi_tile_size,
10331041
weight_gather_axes=wi_gather_axes,
1042+
gmm_impl=gmm_impl_wi,
10341043
)
10351044
layer_w0, layer_w1 = jnp.split(layer_w01, 2, axis=-1)
10361045
else:
@@ -1039,12 +1048,14 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
10391048
w0,
10401049
tiling=wi_tile_size,
10411050
weight_gather_axes=wi_gather_axes,
1051+
gmm_impl=gmm_impl_wi,
10421052
)
10431053
layer_w1 = gmm_fn(
10441054
x,
10451055
w1,
10461056
tiling=wi_tile_size,
10471057
weight_gather_axes=wi_gather_axes,
1058+
gmm_impl=gmm_impl_wi,
10481059
)
10491060
layer_w0 = jax.ad_checkpoint.checkpoint_name(layer_w0, "mlpwi_0")
10501061
layer_w1 = jax.ad_checkpoint.checkpoint_name(layer_w1, "mlpwi_1")
@@ -1055,6 +1066,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
10551066
wo,
10561067
tiling=wo_tile_size,
10571068
weight_gather_axes=wo_gather_axes,
1069+
gmm_impl=gmm_impl_wo,
10581070
)
10591071
return layer_wo
10601072

0 commit comments

Comments
 (0)