Skip to content

Commit 64a7746

Browse files
CaptainO5Google-ML-Automation
authored andcommitted
Refactor Megablox Ops to use public Tokamax API
PiperOrigin-RevId: 900763560
1 parent 0ba93e2 commit 64a7746

5 files changed

Lines changed: 55 additions & 102 deletions

File tree

src/maxtext/configs/base.yml

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -223,16 +223,6 @@ wo_tile_drhs_batch_seq: 512
223223
wo_tile_drhs_embed_dim: 1024
224224
wo_tile_drhs_mlp_dim: 1024
225225

226-
wi_tile_fwd_buffer_count: 2
227-
wi_tile_dlhs_buffer_count: 2
228-
wi_tile_drhs_buffer_count: 2
229-
wo_tile_fwd_buffer_count: 2
230-
wo_tile_dlhs_buffer_count: 2
231-
wo_tile_drhs_buffer_count: 2
232-
233-
wi_combine_scopes: False
234-
wo_combine_scopes: False
235-
236226
merge_gating_gmm: False
237227

238228
norm_topk_prob: false # boolean to enable the top-k probability normalization. qwen3-specific normalization of router weights.

src/maxtext/configs/types.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -734,16 +734,6 @@ class MoEKernels(BaseModel):
734734
wo_tile_drhs_embed_dim: int = Field(1024, description="bwd pass drhs tiling dimension for embedding in GMM for wo.")
735735
wo_tile_drhs_mlp_dim: int = Field(1024, description="bwd pass drhs tiling dimension for MLP in GMM for wo.")
736736

737-
wi_tile_fwd_buffer_count: int = Field(2, description="forward pass tiling buffer count in GMM for wi.")
738-
wi_tile_dlhs_buffer_count: int = Field(2, description="bwd pass dlhs tiling buffer count in GMM for wi.")
739-
wi_tile_drhs_buffer_count: int = Field(2, description="bwd pass drhs tiling buffer count in GMM for wi.")
740-
wo_tile_fwd_buffer_count: int = Field(2, description="forward pass tiling buffer count in GMM for wo.")
741-
wo_tile_dlhs_buffer_count: int = Field(2, description="bwd pass dlhs tiling buffer count in GMM for wo.")
742-
wo_tile_drhs_buffer_count: int = Field(2, description="bwd pass drhs tiling buffer count in GMM for wo.")
743-
744-
wi_combine_scopes: bool = Field(False, description="whether to use combine_scopes features for tgmm for wi.")
745-
wo_combine_scopes: bool = Field(False, description="whether to use combine_scopes features for tgmm for wo.")
746-
747737
merge_gating_gmm: bool = Field(False, description="whether to merge the two gating gmm kernels into one.")
748738

749739

src/maxtext/kernels/megablox/ops.py

Lines changed: 54 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -16,23 +16,40 @@
1616

1717
# pylint: disable=too-many-positional-arguments
1818

19-
import functools
2019
import dataclasses
21-
from typing import Literal, List, Tuple
20+
import functools
21+
from typing import List, Literal, Tuple
2222
import jax
2323
import jax.numpy as jnp
2424
from maxtext.kernels.megablox import backend
25-
from tokamax._src.ops.ragged_dot import pallas_mosaic_tpu_kernel as tokamax_backend
2625
import qwix
2726
import qwix.pallas as qpl
27+
import tokamax
28+
29+
30+
DRHS_RAGGED_DOT_DIM_NUMS = jax.lax.RaggedDotDimensionNumbers(
31+
dot_dimension_numbers=(([0], [0]), ([], [])),
32+
lhs_ragged_dimensions=[0],
33+
rhs_group_dimensions=[],
34+
)
2835

2936

3037
def gmm(
3138
lhs: jnp.ndarray,
3239
rhs: jnp.ndarray,
3340
group_sizes: jnp.ndarray,
3441
preferred_element_type: jnp.dtype = jnp.float32,
35-
tiling: tuple[int, int, int, int, int, int, int, int, int] = (128, 128, 128, 128, 128, 128, 128, 128, 128),
42+
tiling: tuple[int, int, int, int, int, int, int, int, int] = (
43+
128,
44+
128,
45+
128,
46+
128,
47+
128,
48+
128,
49+
128,
50+
128,
51+
128,
52+
),
3653
group_offset: jnp.ndarray | None = None,
3754
existing_out: jnp.ndarray | None = None,
3855
transpose_rhs: bool = False,
@@ -42,8 +59,6 @@ def gmm(
4259
use_qwix_quantization: bool = False,
4360
use_tokamax_backend: bool = False,
4461
weight_gather_axes: List[Tuple[str, int]] | None = None,
45-
input_buffer_count: tuple[int, int, int] = (2, 2, 2),
46-
combine_scopes: bool = False,
4762
# TODO(amandaliang): get rid of the qwix_rule in favor of Qwix's interception feature
4863
qwix_rule: qwix.QtRule | None = None,
4964
):
@@ -65,16 +80,14 @@ def gmm(
6580
)
6681

6782
gmm_fwd_bwd = lambda *args: _gmm_fwd(*args)[0] # pylint: disable=C3001
68-
gmm_fwd_bwd = jax.custom_vjp(gmm_fwd_bwd, nondiff_argnums=(3, 4, 5, 6, 9, 10, 11, 12, 13))
83+
gmm_fwd_bwd = jax.custom_vjp(gmm_fwd_bwd, nondiff_argnums=(3, 4, 7, 8, 9, 10, 11))
6984
gmm_fwd_bwd.defvjp(_gmm_fwd, functools.partial(_gmm_bwd, lhs.dtype, rhs.dtype))
7085
return gmm_fwd_bwd(
7186
lhs,
7287
rhs,
7388
group_sizes,
7489
preferred_element_type,
7590
tiling,
76-
input_buffer_count,
77-
combine_scopes,
7891
group_offset,
7992
existing_out,
8093
transpose_rhs,
@@ -90,9 +103,17 @@ def _gmm_fwd(
90103
rhs: jnp.ndarray,
91104
group_sizes: jnp.ndarray,
92105
preferred_element_type: jnp.dtype = jnp.float32,
93-
tiling: tuple[int, int, int, int, int, int, int, int, int] = (128, 128, 128, 128, 128, 128, 128, 128, 128),
94-
input_buffer_count: tuple[int, int, int] = (2, 2, 2),
95-
combine_scopes: bool = False,
106+
tiling: tuple[int, int, int, int, int, int, int, int, int] = (
107+
128,
108+
128,
109+
128,
110+
128,
111+
128,
112+
128,
113+
128,
114+
128,
115+
128,
116+
),
96117
group_offset: jnp.ndarray | None = None,
97118
existing_out: jnp.ndarray | None = None,
98119
transpose_rhs: bool = False,
@@ -136,17 +157,18 @@ def _gmm_fwd(
136157
for axis_name, axis_idx in weight_gather_axes:
137158
rhs_qvalue = jax.lax.all_gather(rhs.qvalue, axis_name, axis=axis_idx, tiled=True)
138159
rhs = dataclasses.replace(rhs, qvalue=rhs_qvalue)
139-
out = tokamax_backend.gmm(
160+
# Handle transpose_rhs manually as ragged_dot assumes (G, K, N)
161+
if transpose_rhs:
162+
rhs = rhs.swapaxes(1, 2)
163+
164+
out = tokamax.ragged_dot(
140165
lhs=lhs,
141166
rhs=rhs,
142167
group_sizes=group_sizes,
143168
precision=jax.lax.Precision.DEFAULT,
144-
out_dtype=preferred_element_type,
145-
tiling=tiling[:3],
169+
preferred_element_type=preferred_element_type,
146170
group_offset=group_offset,
147-
transpose_rhs=transpose_rhs,
148-
interpret=interpret,
149-
input_buffer_count=input_buffer_count[0],
171+
implementation="mosaic",
150172
)
151173
else:
152174
out = backend.gmm(
@@ -168,8 +190,6 @@ def _gmm_bwd(
168190
rhs_dtype: jax.typing.DTypeLike,
169191
preferred_element_type: jnp.dtype,
170192
tiling: tuple[int, int, int, int, int, int, int, int, int],
171-
input_buffer_count: tuple[int, int, int],
172-
combine_scopes: bool,
173193
transpose_rhs: bool,
174194
interpret: bool,
175195
quantization_rule: qwix.QtRule | None,
@@ -224,30 +244,29 @@ def _gmm_bwd(
224244
calibration_method=quantization_rule.bwd_calibration_method,
225245
)
226246
if use_tokamax_backend:
227-
dlhs = tokamax_backend.gmm(
247+
# Handle transpose_rhs manually
248+
dlhs_rhs = rhs
249+
if not transpose_rhs:
250+
dlhs_rhs = dlhs_rhs.swapaxes(1, 2)
251+
252+
dlhs = tokamax.ragged_dot(
228253
lhs=dlhs_dout,
229-
rhs=rhs,
254+
rhs=dlhs_rhs,
230255
group_sizes=group_sizes,
231256
precision=jax.lax.Precision.DEFAULT,
232-
out_dtype=lhs_dtype,
233-
tiling=tiling[3:6],
257+
preferred_element_type=lhs_dtype,
234258
group_offset=group_offset,
235-
transpose_rhs=not transpose_rhs,
236-
interpret=interpret,
237-
input_buffer_count=input_buffer_count[1],
259+
implementation="mosaic",
238260
)
239-
drhs = tokamax_backend.tgmm(
240-
lhs=lhs.swapaxes(0, 1),
261+
drhs = tokamax.ragged_dot_general(
262+
lhs=lhs,
241263
rhs=drhs_dout,
242264
group_sizes=group_sizes,
265+
ragged_dot_dimension_numbers=DRHS_RAGGED_DOT_DIM_NUMS,
243266
precision=jax.lax.Precision.DEFAULT,
244-
out_dtype=rhs_dtype,
245-
tiling=tiling[-3:],
267+
preferred_element_type=rhs_dtype,
246268
group_offset=group_offset,
247-
num_actual_groups=num_actual_groups,
248-
interpret=interpret,
249-
input_buffer_count=input_buffer_count[2],
250-
combine_scopes=combine_scopes,
269+
implementation="mosaic",
251270
)
252271
if quantization_rule and quantization_rule.bwd_qtype and weight_gather_axes:
253272
# Scatter back in reverse order of gather

src/maxtext/layers/moe.py

Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -965,9 +965,7 @@ def get_quantization_dtypes():
965965
rhs_quantize_dtype = quant_dg.fwd.dg_quantizer.rhs.numerics.get_dtype()
966966
return lhs_quantize_dtype, rhs_quantize_dtype
967967

968-
def gmm(
969-
inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_axes, input_buffer_count, combine_scopes
970-
):
968+
def gmm(inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_axes):
971969
if inputs.shape[0] != expert_assignments.shape[0]:
972970
raise ValueError("The number of input tokens must match the number of expert assignments!")
973971

@@ -993,8 +991,6 @@ def gmm(
993991
use_qwix_quantization=self.config.use_qwix_quantization,
994992
use_tokamax_backend=self.config.use_tokamax_gmm,
995993
weight_gather_axes=weight_gather_axes,
996-
input_buffer_count=input_buffer_count,
997-
combine_scopes=combine_scopes,
998994
)
999995
else: # tokamax (unquantized)
1000996
output = tokamax.ragged_dot(
@@ -1250,26 +1246,12 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
12501246
self.config.wo_tile_drhs_embed_dim,
12511247
self.config.wo_tile_drhs_mlp_dim,
12521248
)
1253-
wi_input_buffer_count = (
1254-
self.config.wi_tile_fwd_buffer_count,
1255-
self.config.wi_tile_dlhs_buffer_count,
1256-
self.config.wi_tile_drhs_buffer_count,
1257-
)
1258-
wo_input_buffer_count = (
1259-
self.config.wo_tile_fwd_buffer_count,
1260-
self.config.wo_tile_dlhs_buffer_count,
1261-
self.config.wo_tile_drhs_buffer_count,
1262-
)
12631249

1264-
wi_combine_scopes = self.config.wi_combine_scopes
1265-
wo_combine_scopes = self.config.wo_combine_scopes
12661250
layer_w0 = gmm_fn(
12671251
x,
12681252
w0,
12691253
tiling=wi_tile_size,
12701254
weight_gather_axes=wi_gather_axes,
1271-
input_buffer_count=wi_input_buffer_count,
1272-
combine_scopes=wi_combine_scopes,
12731255
)
12741256
if self.get_tensor_transpose_parallelism_size() > 1:
12751257
layer_w0 = jax.lax.psum(layer_w0, "tensor_transpose")
@@ -1282,8 +1264,6 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
12821264
w1,
12831265
tiling=wi_tile_size,
12841266
weight_gather_axes=wi_gather_axes,
1285-
input_buffer_count=wi_input_buffer_count,
1286-
combine_scopes=wi_combine_scopes,
12871267
)
12881268
if self.get_tensor_transpose_parallelism_size() > 1:
12891269
layer_w1 = jax.lax.psum(layer_w1, "tensor_transpose")
@@ -1297,8 +1277,6 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
12971277
wo,
12981278
tiling=wo_tile_size,
12991279
weight_gather_axes=wo_gather_axes,
1300-
input_buffer_count=wo_input_buffer_count,
1301-
combine_scopes=wo_combine_scopes,
13021280
)
13031281
if self.get_tensor_parallelism_size() > 1:
13041282
intermediate_output = jax.lax.psum_scatter(

src/maxtext/models/deepseek_batchsplit_fp8.py

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -948,8 +948,6 @@ def gmm(
948948
group_sizes,
949949
preferred_element_type,
950950
weight_gather_axes,
951-
input_buffer_count,
952-
combine_scopes,
953951
):
954952
if config.use_qwix_quantization:
955953
output = megablox.gmm(
@@ -961,8 +959,6 @@ def gmm(
961959
use_qwix_quantization=config.use_qwix_quantization,
962960
use_tokamax_backend=config.use_tokamax_gmm,
963961
weight_gather_axes=weight_gather_axes,
964-
input_buffer_count=input_buffer_count,
965-
combine_scopes=combine_scopes,
966962
qwix_rule=quantizations.get_fp8_full_qwix_rule(config),
967963
)
968964
else:
@@ -1002,19 +998,7 @@ def gmm(
1002998
config.wo_tile_drhs_embed_dim,
1003999
config.wo_tile_drhs_mlp_dim,
10041000
)
1005-
wi_input_buffer_count = (
1006-
config.wi_tile_fwd_buffer_count,
1007-
config.wi_tile_dlhs_buffer_count,
1008-
config.wi_tile_drhs_buffer_count,
1009-
)
1010-
wo_input_buffer_count = (
1011-
config.wo_tile_fwd_buffer_count,
1012-
config.wo_tile_dlhs_buffer_count,
1013-
config.wo_tile_drhs_buffer_count,
1014-
)
10151001

1016-
wi_combine_scopes = config.wi_combine_scopes
1017-
wo_combine_scopes = config.wo_combine_scopes
10181002
if config.use_qwix_quantization:
10191003
gating_pspec, linear_pspec = moe_lib.get_batchsplit_init_kernel_axes()
10201004
w0_pspec = nn.logical_to_mesh_axes(gating_pspec)
@@ -1043,8 +1027,6 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
10431027
w01,
10441028
tiling=wi_tile_size,
10451029
weight_gather_axes=wi_gather_axes,
1046-
input_buffer_count=wi_input_buffer_count,
1047-
combine_scopes=wi_combine_scopes,
10481030
)
10491031
layer_w0, layer_w1 = jnp.split(layer_w01, 2, axis=-1)
10501032
else:
@@ -1053,16 +1035,12 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
10531035
w0,
10541036
tiling=wi_tile_size,
10551037
weight_gather_axes=wi_gather_axes,
1056-
input_buffer_count=wi_input_buffer_count,
1057-
combine_scopes=wi_combine_scopes,
10581038
)
10591039
layer_w1 = gmm_fn(
10601040
x,
10611041
w1,
10621042
tiling=wi_tile_size,
10631043
weight_gather_axes=wi_gather_axes,
1064-
input_buffer_count=wi_input_buffer_count,
1065-
combine_scopes=wi_combine_scopes,
10661044
)
10671045
layer_w0 = jax.ad_checkpoint.checkpoint_name(layer_w0, "mlpwi_0")
10681046
layer_w1 = jax.ad_checkpoint.checkpoint_name(layer_w1, "mlpwi_1")
@@ -1073,8 +1051,6 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
10731051
wo,
10741052
tiling=wo_tile_size,
10751053
weight_gather_axes=wo_gather_axes,
1076-
input_buffer_count=wo_input_buffer_count,
1077-
combine_scopes=wo_combine_scopes,
10781054
)
10791055
return layer_wo
10801056

0 commit comments

Comments
 (0)