Skip to content

Commit cfeea9c

Browse files
committed
refactor(moe): Remove tokamax_gmm_autotune and unconditionally use custom tile sizes with Pallas subclasses
This change addresses review comments on PR #3779: 1. Replaces the brittle global monkey-patch with clean Pallas subclasses (PallasMosaicTpuRaggedDotWI and PallasMosaicTpuRaggedDotWO) in moe.py and deepseek_batchsplit_fp8.py. 2. Implements custom __post_init__ in these subclasses to ensure JAX backward passes (VJP) use the correct subclass instead of reverting to the base class. 3. Removes the configuration flag `tokamax_gmm_autotune` entirely (unconditionally enabling custom tile sizes when Tokamax GMM is active). 4. Updates unit tests in moe_test.py to verify the new subclass-based tiling overrides. CONV=2c6843af-dcf7-403b-b67e-2fedd5f81b95
1 parent 504dcb0 commit cfeea9c

3 files changed

Lines changed: 249 additions & 99 deletions

File tree

src/maxtext/layers/moe.py

Lines changed: 93 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
DLHS_RAGGED_DOT_DIM_NUMS,
5656
DRHS_RAGGED_DOT_DIM_NUMS,
5757
)
58+
from tokamax._src.ops.ragged_dot import base
5859
from tokamax._src.ops import op
5960

6061
set_xla_metadata = xla_metadata.set_xla_metadata
@@ -557,8 +558,7 @@ def __init__(
557558
):
558559
self.wo.value = self.wo.value * self.per_expert_scale.value[:, None, None]
559560

560-
# Monkey-patch Tokamax heuristics globally once
561-
_monkey_patch_tokamax_heuristics(self.config)
561+
562562

563563
def _maybe_shard_with_logical(self, inputs, logical_name):
564564
return maybe_shard_with_logical(
@@ -1095,6 +1095,86 @@ def sparse_matmul(
10951095
wo_bias,
10961096
):
10971097
"""Perform sparse matrix multiplication of inputs and Experts."""
1098+
config = self.config
1099+
1100+
class PallasMosaicTpuRaggedDotWI(PallasMosaicTpuRaggedDot):
1101+
1102+
def __post_init__(self):
1103+
from tokamax._src.ops.ragged_dot import base
1104+
qdtype = self.qdtype if self.qdtype is None else jnp.dtype(self.qdtype).name
1105+
if self.vjp is None:
1106+
fn = lambda *args, **kw: PallasMosaicTpuRaggedDotWI(
1107+
qdtype=qdtype,
1108+
interpret=self.interpret,
1109+
)(*args, **kw)
1110+
object.__setattr__(
1111+
self,
1112+
"vjp",
1113+
functools.partial(base.vjp, dlhs_ragged_dot=fn, drhs_ragged_dot=fn),
1114+
)
1115+
1116+
def _get_heuristics_config(self, ba) -> Config:
1117+
dims = ba.arguments.get("ragged_dot_dimension_numbers", DEFAULT_RAGGED_DOT_DIM_NUMS)
1118+
if dims == DEFAULT_RAGGED_DOT_DIM_NUMS:
1119+
return Config(
1120+
tile_m=config.wi_tile_fwd_batch_seq,
1121+
tile_k=config.wi_tile_fwd_embed_dim,
1122+
tile_n=config.wi_tile_fwd_mlp_dim,
1123+
)
1124+
elif dims == DLHS_RAGGED_DOT_DIM_NUMS:
1125+
return Config(
1126+
tile_m=config.wi_tile_dlhs_batch_seq,
1127+
tile_k=config.wi_tile_dlhs_mlp_dim,
1128+
tile_n=config.wi_tile_dlhs_embed_dim,
1129+
)
1130+
elif dims == DRHS_RAGGED_DOT_DIM_NUMS:
1131+
return Config(
1132+
tile_m=config.wi_tile_drhs_batch_seq,
1133+
tile_k=config.wi_tile_drhs_embed_dim,
1134+
tile_n=config.wi_tile_drhs_mlp_dim,
1135+
)
1136+
return Config()
1137+
1138+
class PallasMosaicTpuRaggedDotWO(PallasMosaicTpuRaggedDot):
1139+
1140+
def __post_init__(self):
1141+
from tokamax._src.ops.ragged_dot import base
1142+
qdtype = self.qdtype if self.qdtype is None else jnp.dtype(self.qdtype).name
1143+
if self.vjp is None:
1144+
fn = lambda *args, **kw: PallasMosaicTpuRaggedDotWO(
1145+
qdtype=qdtype,
1146+
interpret=self.interpret,
1147+
)(*args, **kw)
1148+
object.__setattr__(
1149+
self,
1150+
"vjp",
1151+
functools.partial(base.vjp, dlhs_ragged_dot=fn, drhs_ragged_dot=fn),
1152+
)
1153+
1154+
def _get_heuristics_config(self, ba) -> Config:
1155+
dims = ba.arguments.get("ragged_dot_dimension_numbers", DEFAULT_RAGGED_DOT_DIM_NUMS)
1156+
if dims == DEFAULT_RAGGED_DOT_DIM_NUMS:
1157+
return Config(
1158+
tile_m=config.wo_tile_fwd_batch_seq,
1159+
tile_k=config.wo_tile_fwd_mlp_dim,
1160+
tile_n=config.wo_tile_fwd_embed_dim,
1161+
)
1162+
elif dims == DLHS_RAGGED_DOT_DIM_NUMS:
1163+
return Config(
1164+
tile_m=config.wo_tile_dlhs_batch_seq,
1165+
tile_k=config.wo_tile_dlhs_embed_dim,
1166+
tile_n=config.wo_tile_dlhs_mlp_dim,
1167+
)
1168+
elif dims == DRHS_RAGGED_DOT_DIM_NUMS:
1169+
return Config(
1170+
tile_m=config.wo_tile_drhs_batch_seq,
1171+
tile_k=config.wo_tile_drhs_mlp_dim,
1172+
tile_n=config.wo_tile_drhs_embed_dim,
1173+
)
1174+
return Config()
1175+
1176+
gmm_impl_wi = PallasMosaicTpuRaggedDotWI(qdtype=None, interpret=False)
1177+
gmm_impl_wo = PallasMosaicTpuRaggedDotWO(qdtype=None, interpret=False)
10981178

10991179
def jax_ragged_dot_gmm(inputs, kernel, tiling, group_sizes, expert_assignments, padding_amount):
11001180
"""Execute jax.lax.ragged_dot, with potential quantization"""
@@ -1139,6 +1219,11 @@ def jax_ragged_dot_gmm(inputs, kernel, tiling, group_sizes, expert_assignments,
11391219
output *= scales
11401220
return output
11411221

1222+
def get_gmm_group_sizes(inputs, kernel, ep):
1223+
# Calculates perfectly balanced group sizes where each local expert receives an equal
1224+
# share of local tokens, adjusted for expert parallelism.
1225+
return (inputs.shape[0] // kernel.shape[0] // ep,) * kernel.shape[0]
1226+
11421227
def get_tokamax_group_sizes(group_sizes, inputs, kernel):
11431228
# TODO (b/491979205) pipeline fsdp ag per repeat fails tokamax gmm
11441229
if self.config.use_qwix_quantization or (
@@ -1151,7 +1236,7 @@ def get_tokamax_group_sizes(group_sizes, inputs, kernel):
11511236
ep = self.get_expert_parallelism_size()
11521237
return tokamax.RaggedDotGroupSizes(
11531238
group_sizes,
1154-
(inputs.shape[0] // kernel.shape[0] // ep,) * kernel.shape[0],
1239+
get_gmm_group_sizes(inputs, kernel, ep),
11551240
)
11561241

11571242
def get_quantization_dtypes():
@@ -1162,7 +1247,7 @@ def get_quantization_dtypes():
11621247
rhs_quantize_dtype = quant_dg.fwd.dg_quantizer.rhs.numerics.get_dtype()
11631248
return lhs_quantize_dtype, rhs_quantize_dtype
11641249

1165-
def gmm(inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_axes):
1250+
def gmm(inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_axes, gmm_impl=None):
11661251
if inputs.shape[0] != expert_assignments.shape[0]:
11671252
raise ValueError("The number of input tokens must match the number of expert assignments!")
11681253

@@ -1196,7 +1281,7 @@ def gmm(inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_a
11961281
group_sizes=tokamax_group_sizes,
11971282
precision=jax.lax.Precision.DEFAULT,
11981283
preferred_element_type=self.dtype,
1199-
implementation="mosaic",
1284+
implementation="mosaic" if gmm_impl is None else [gmm_impl],
12001285
)
12011286
elif self.config.megablox: # Older forked megablox
12021287
output = mblx.gmm(
@@ -1485,6 +1570,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
14851570
w0,
14861571
tiling=wi_tile_size,
14871572
weight_gather_axes=wi_gather_axes,
1573+
gmm_impl=gmm_impl_wi,
14881574
)
14891575
if self.get_tensor_transpose_parallelism_size() > 1:
14901576
layer_w0 = jax.lax.psum(layer_w0, "tensor_transpose")
@@ -1497,6 +1583,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
14971583
w1,
14981584
tiling=wi_tile_size,
14991585
weight_gather_axes=wi_gather_axes,
1586+
gmm_impl=gmm_impl_wi,
15001587
)
15011588
if self.get_tensor_transpose_parallelism_size() > 1:
15021589
layer_w1 = jax.lax.psum(layer_w1, "tensor_transpose")
@@ -1510,6 +1597,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
15101597
wo,
15111598
tiling=wo_tile_size,
15121599
weight_gather_axes=wo_gather_axes,
1600+
gmm_impl=gmm_impl_wo,
15131601
)
15141602
if self.get_tensor_parallelism_size() > 1:
15151603
intermediate_output = jax.lax.psum_scatter(
@@ -2553,74 +2641,3 @@ def get_routed_and_shared_moe(
25532641
abstract_init=False,
25542642
)
25552643
return module
2556-
2557-
2558-
_heuristics_patched = False
2559-
2560-
2561-
def _monkey_patch_tokamax_heuristics(config, force=False):
2562-
"""Globally monkey-patches Tokamax GMM heuristics with manual tiling overrides."""
2563-
global _heuristics_patched
2564-
if _heuristics_patched and not force:
2565-
return
2566-
2567-
def custom_heuristics(self, ba: op.BoundArguments) -> Config:
2568-
lhs, rhs = ba.arguments["lhs"], ba.arguments["rhs"]
2569-
dims = ba.arguments.get("ragged_dot_dimension_numbers", DEFAULT_RAGGED_DOT_DIM_NUMS)
2570-
2571-
is_wo = False
2572-
if dims == DEFAULT_RAGGED_DOT_DIM_NUMS:
2573-
is_wo = rhs.shape[1] == config.base_mlp_dim
2574-
elif dims == DLHS_RAGGED_DOT_DIM_NUMS:
2575-
is_wo = rhs.shape[2] == config.base_emb_dim
2576-
elif dims == DRHS_RAGGED_DOT_DIM_NUMS:
2577-
is_wo = lhs.shape[1] == config.base_mlp_dim
2578-
2579-
if is_wo:
2580-
# Return wo tile sizes
2581-
if dims == DEFAULT_RAGGED_DOT_DIM_NUMS:
2582-
return Config(
2583-
tile_m=config.wo_tile_fwd_batch_seq,
2584-
tile_k=config.wo_tile_fwd_mlp_dim,
2585-
tile_n=config.wo_tile_fwd_embed_dim,
2586-
)
2587-
elif dims == DLHS_RAGGED_DOT_DIM_NUMS:
2588-
return Config(
2589-
tile_m=config.wo_tile_dlhs_batch_seq,
2590-
tile_k=config.wo_tile_dlhs_embed_dim,
2591-
tile_n=config.wo_tile_dlhs_mlp_dim,
2592-
)
2593-
elif dims == DRHS_RAGGED_DOT_DIM_NUMS:
2594-
return Config(
2595-
tile_m=config.wo_tile_drhs_batch_seq,
2596-
tile_k=config.wo_tile_drhs_mlp_dim,
2597-
tile_n=config.wo_tile_drhs_embed_dim,
2598-
)
2599-
else:
2600-
# Return wi tile sizes
2601-
if dims == DEFAULT_RAGGED_DOT_DIM_NUMS:
2602-
return Config(
2603-
tile_m=config.wi_tile_fwd_batch_seq,
2604-
tile_k=config.wi_tile_fwd_embed_dim,
2605-
tile_n=config.wi_tile_fwd_mlp_dim,
2606-
)
2607-
elif dims == DLHS_RAGGED_DOT_DIM_NUMS:
2608-
return Config(
2609-
tile_m=config.wi_tile_dlhs_batch_seq,
2610-
tile_k=config.wi_tile_dlhs_mlp_dim,
2611-
tile_n=config.wi_tile_dlhs_embed_dim,
2612-
)
2613-
elif dims == DRHS_RAGGED_DOT_DIM_NUMS:
2614-
return Config(
2615-
tile_m=config.wi_tile_drhs_batch_seq,
2616-
tile_k=config.wi_tile_drhs_embed_dim,
2617-
tile_n=config.wi_tile_drhs_mlp_dim,
2618-
)
2619-
2620-
return Config()
2621-
2622-
# Apply class-level monkey patch!
2623-
# pylint: disable=protected-access
2624-
PallasMosaicTpuRaggedDot._get_heuristics_config = custom_heuristics
2625-
_heuristics_patched = True
2626-
print("[TOKAMAX_PATCH] Successfully monkey-patched Tokamax GMM heuristics globally!")

0 commit comments

Comments
 (0)