Skip to content

Commit b88c6ac

Browse files
committed
clean tokamax gmm and add test
1 parent a8505ff commit b88c6ac

6 files changed

Lines changed: 170 additions & 81 deletions

File tree

src/maxtext/kernels/megablox/ops.py

Lines changed: 50 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def gmm(
6464
rhs_vma_axes: tuple = tuple(),
6565
# TODO(amandaliang): get rid of the qwix_rule in favor of Qwix's interception feature
6666
qwix_rule: qwix.QtRule | None = None,
67-
use_manual_quantization: bool = False,
67+
use_manual_quantization: bool = False, # used in batchsplit
6868
):
6969
"""Grouped matrix multiplication operation."""
7070
quantization_rule = None
@@ -163,10 +163,10 @@ def _gmm_fwd(
163163
else:
164164
rhs = quantizations.manual_quantize(
165165
rhs,
166-
quantization_rule.weight_calibration_method,
167166
quantization_rule.weight_qtype,
167+
calibration_method=quantization_rule.weight_calibration_method,
168168
)
169-
# QAG is only supported for following conditions
169+
# QAG is only supported under the following conditions
170170
if use_tokamax_backend:
171171
if quantization_rule and quantization_rule.bwd_qtype:
172172
if quantization_rule.weight_calibration_method.startswith("fixed") and isinstance(rhs, qpl.QArray):
@@ -178,27 +178,23 @@ def _gmm_fwd(
178178
if transpose_rhs:
179179
rhs = rhs.swapaxes(1, 2)
180180

181+
# manual_axis_type is for gmm with shard_map check_vma=True, needs tokamax > 0.0.12
182+
out_kwargs = {}
181183
if use_manual_quantization:
182-
out = tokamax.ragged_dot(
183-
lhs=lhs,
184-
rhs=rhs,
185-
group_sizes=group_sizes,
186-
precision=jax.lax.Precision.DEFAULT,
187-
preferred_element_type=preferred_element_type,
188-
group_offset=group_offset,
189-
implementation="mosaic",
190-
manual_axis_type=jax.sharding.ManualAxisType(varying=frozenset(["data", "fsdp", "expert"])),
191-
)
192-
else:
193-
out = tokamax.ragged_dot(
194-
lhs=lhs,
195-
rhs=rhs,
196-
group_sizes=group_sizes,
197-
precision=jax.lax.Precision.DEFAULT,
198-
preferred_element_type=preferred_element_type,
199-
group_offset=group_offset,
200-
implementation="mosaic",
201-
)
184+
# used in batchsplit
185+
out_kwargs["manual_axis_type"] = jax.sharding.ManualAxisType(varying=frozenset(["data", "fsdp", "expert"]))
186+
187+
out = tokamax.ragged_dot(
188+
lhs=lhs,
189+
rhs=rhs,
190+
group_sizes=group_sizes,
191+
precision=jax.lax.Precision.DEFAULT,
192+
preferred_element_type=preferred_element_type,
193+
# `group_offset` is not yet supported
194+
group_offset=None,
195+
implementation="mosaic",
196+
**out_kwargs,
197+
)
202198
else:
203199
out = backend.gmm(
204200
lhs,
@@ -284,53 +280,39 @@ def _gmm_bwd(
284280
if not transpose_rhs:
285281
dlhs_rhs = dlhs_rhs.swapaxes(1, 2)
286282

283+
# manual_axis_type is for gmm with shard_map check_vma=True, needs tokamax > 0.0.12
284+
dlhs_kwargs = {}
285+
drhs_kwargs = {}
287286
if use_manual_quantization:
288-
dlhs = tokamax.ragged_dot(
289-
lhs=dlhs_dout,
290-
rhs=dlhs_rhs,
291-
group_sizes=group_sizes,
292-
precision=jax.lax.Precision.DEFAULT,
293-
preferred_element_type=lhs_dtype,
294-
group_offset=group_offset,
295-
implementation="mosaic",
296-
manual_axis_type=jax.sharding.ManualAxisType(varying=frozenset(["data", "fsdp", "expert"])),
297-
)
298-
else:
299-
dlhs = tokamax.ragged_dot(
300-
lhs=dlhs_dout,
301-
rhs=dlhs_rhs,
302-
group_sizes=group_sizes,
303-
precision=jax.lax.Precision.DEFAULT,
304-
preferred_element_type=lhs_dtype,
305-
group_offset=group_offset,
306-
implementation="mosaic",
307-
)
308-
if use_manual_quantization:
309-
drhs = tokamax.ragged_dot_general(
310-
lhs=lhs,
311-
rhs=drhs_dout,
312-
group_sizes=group_sizes,
313-
ragged_dot_dimension_numbers=DRHS_RAGGED_DOT_DIM_NUMS,
314-
precision=jax.lax.Precision.DEFAULT,
315-
preferred_element_type=rhs_dtype,
316-
group_offset=group_offset,
317-
implementation="mosaic",
318-
manual_axis_type=jax.sharding.ManualAxisType(
319-
varying=frozenset(["expert"]),
320-
unreduced=frozenset(["data", "fsdp"]),
321-
),
322-
)
323-
else:
324-
drhs = tokamax.ragged_dot_general(
325-
lhs=lhs,
326-
rhs=drhs_dout,
327-
group_sizes=group_sizes,
328-
ragged_dot_dimension_numbers=DRHS_RAGGED_DOT_DIM_NUMS,
329-
precision=jax.lax.Precision.DEFAULT,
330-
preferred_element_type=rhs_dtype,
331-
group_offset=group_offset,
332-
implementation="mosaic",
287+
# used in batchsplit
288+
dlhs_kwargs["manual_axis_type"] = jax.sharding.ManualAxisType(varying=frozenset(["data", "fsdp", "expert"]))
289+
drhs_kwargs["manual_axis_type"] = jax.sharding.ManualAxisType(
290+
varying=frozenset(["expert"]), unreduced=frozenset(["data", "fsdp"])
333291
)
292+
293+
dlhs = tokamax.ragged_dot(
294+
lhs=dlhs_dout,
295+
rhs=dlhs_rhs,
296+
group_sizes=group_sizes,
297+
precision=jax.lax.Precision.DEFAULT,
298+
preferred_element_type=lhs_dtype,
299+
# `group_offset` is not yet supported
300+
group_offset=None,
301+
implementation="mosaic",
302+
**dlhs_kwargs,
303+
)
304+
drhs = tokamax.ragged_dot_general(
305+
lhs=lhs,
306+
rhs=drhs_dout,
307+
group_sizes=group_sizes,
308+
ragged_dot_dimension_numbers=DRHS_RAGGED_DOT_DIM_NUMS,
309+
precision=jax.lax.Precision.DEFAULT,
310+
preferred_element_type=rhs_dtype,
311+
# `group_offset` is not yet supported
312+
group_offset=None,
313+
implementation="mosaic",
314+
**drhs_kwargs,
315+
)
334316
if quantization_rule and quantization_rule.bwd_qtype and weight_gather_axes:
335317
# Scatter back in reverse order of gather
336318
for axis_name, axis_idx in reversed(weight_gather_axes):

src/maxtext/layers/moe.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1321,7 +1321,8 @@ def extract_vma(tensor):
13211321
precision=jax.lax.Precision.DEFAULT,
13221322
preferred_element_type=self.dtype,
13231323
implementation="mosaic",
1324-
group_offset=group_offset,
1324+
# `group_offset` is not yet supported
1325+
group_offset=None,
13251326
)
13261327
elif self.config.megablox: # Older forked megablox
13271328
output = mblx.gmm(

src/maxtext/layers/quantizations.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,6 @@ def __call__(
245245
*,
246246
out_sharding=None,
247247
) -> jax.Array:
248-
249248
return dot_general_qt.dot_general_qt(lhs, rhs, dimension_numbers, self.config)
250249

251250

@@ -264,7 +263,6 @@ def __call__(
264263
_dot_general: Callable[..., jax.Array] | None = None,
265264
out_sharding=None,
266265
) -> jax.Array:
267-
268266
def custom_dot_general(*args, **kwargs):
269267
return dot_general_qt.dot_general_qt(*args[:3], self.config)
270268

@@ -509,9 +507,14 @@ def _get_aqt_fp8_default_config(config):
509507
constant_bound_config = None
510508

511509
if len(config.constant_bound_config) == 6:
512-
fwd_lhs_bound, fwd_rhs_bound, dlhs_lhs_bound, dlhs_rhs_bound, drhs_lhs_bound, drhs_rhs_bound = (
513-
config.constant_bound_config
514-
)
510+
(
511+
fwd_lhs_bound,
512+
fwd_rhs_bound,
513+
dlhs_lhs_bound,
514+
dlhs_rhs_bound,
515+
drhs_lhs_bound,
516+
drhs_rhs_bound,
517+
) = config.constant_bound_config
515518
constant_bound_config = ConstantBoundConfig(
516519
fwd_lhs_bound=fwd_lhs_bound,
517520
fwd_rhs_bound=fwd_rhs_bound,
@@ -839,26 +842,28 @@ def _get_max_min(target_dtype):
839842
return jnp.finfo(target_dtype).max.astype(jnp.bfloat16), jnp.finfo(target_dtype).min.astype(jnp.bfloat16)
840843

841844

842-
def manual_quantize(tensor, calibration_method, dtype=jnp.float8_e4m3fn):
845+
def manual_quantize(tensor: jax.Array, dtype: jax.typing.DTypeLike, calibration_method: str) -> qwix.QArray:
843846
"""Manually quantizes a tensor based on a fixed calibration method.
844847
845848
Args:
846849
tensor: The tensor to quantize.
850+
dtype: The logical type of the quantized value, e.g. jnp.float8_e4m3fn
847851
calibration_method: A string specifying the calibration method. Expected
848-
format is "fixed,{scale},{max_val}".
852+
format is "fixed,{scale},{max_val}", e.g. "fixed,-224,224".
849853
850854
Returns:
851855
A qwix.QArray containing the quantized value and the scale.
852856
853857
Raises:
854858
ValueError: If calibration_method is None or has an unexpected format.
855859
"""
860+
# validate calibration method and parse
856861
calib_method = calibration_method
857862
if calib_method is None:
858863
raise ValueError("calibration_method cannot be None for manual quantization")
859864
if not calib_method.startswith("fixed"):
860-
raise ValueError("Only static weight/activation quantization is supported, but got" f" {calib_method}")
861-
865+
# A static scale works for weights/activations, but gradients usually need a dynamic scale.
866+
raise ValueError("Only static scale quantization is supported, but got" f" {calib_method}")
862867
parts = calib_method.split(",")
863868
if len(parts) != 3:
864869
raise ValueError(f"Unexpected format for weight calibration method: {calib_method}")

tests/conftest.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
"""
2222

2323
import pytest
24+
import sys
2425
import warnings
2526

2627
warnings.filterwarnings(
@@ -32,6 +33,14 @@
3233
warnings.filterwarnings(
3334
"ignore", message="builtin type SwigPyObject has no __module__ attribute", category=DeprecationWarning
3435
)
36+
37+
# Prevent libraries that use absl flags (e.g. tokamax) from lazily parsing sys.argv,
38+
# which would pick up pytest flags like `-v -m` and fail to parse them as integers.
39+
from absl import flags as _absl_flags
40+
41+
if not _absl_flags.FLAGS.is_parsed():
42+
_absl_flags.FLAGS(sys.argv[:1])
43+
3544
import jax
3645
import os
3746
import importlib.util

tests/integration/sparsity_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
"""Smoke test for sparsity.
15-
"""
14+
15+
"""Smoke test for sparsity."""
1616

1717
import os
1818
import tempfile
@@ -28,7 +28,7 @@
2828

2929
@pytest.mark.integration_test
3030
class Train(parameterized.TestCase):
31-
"""Smoke test for sparsity in G3 only."""
31+
"""Smoke test for sparsity."""
3232

3333
@parameterized.named_parameters(
3434
{

tests/integration/tokamax_test.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
# Copyright 2023–2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Test for tokamax gmm and splash."""
16+
17+
import os
18+
import tempfile
19+
from absl.testing import absltest
20+
from absl.testing import parameterized
21+
import pytest
22+
from maxtext.trainers.pre_train import train
23+
from tests.utils.test_helpers import get_test_config_path
24+
25+
train_main = train.main
26+
gettempdir = tempfile.gettempdir
27+
28+
29+
@pytest.mark.integration_test
30+
class Train(parameterized.TestCase):
31+
"""Test for tokamax gmm and splash."""
32+
33+
@parameterized.named_parameters(
34+
{
35+
"testcase_name": "gmm bf16",
36+
"quantization": "",
37+
},
38+
{
39+
"testcase_name": "gmm fp8",
40+
"quantization": "fp8_full",
41+
},
42+
)
43+
@pytest.mark.tpu_only
44+
def test_different_configs(self, quantization: str):
45+
"""Smoke train with small config."""
46+
test_tmpdir = os.environ.get("TEST_TMPDIR", gettempdir())
47+
outputs_dir = os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR", test_tmpdir)
48+
args = [
49+
None,
50+
get_test_config_path(),
51+
f"base_output_directory={test_tmpdir}",
52+
"run_name=tokamax_test",
53+
# model
54+
"base_emb_dim=256",
55+
"base_num_query_heads=1",
56+
"base_num_kv_heads=1",
57+
"base_mlp_dim=256",
58+
"base_moe_mlp_dim=256",
59+
"base_num_decoder_layers=2",
60+
"head_dim=64",
61+
"decoder_block=deepseek",
62+
"attention_type=mla",
63+
"num_experts=2",
64+
"shared_experts=1",
65+
# tokamax gmm
66+
"sparse_matmul=True",
67+
"megablox=False",
68+
"use_tokamax_gmm=True",
69+
# tokamax splash
70+
"max_target_length=128",
71+
"attention=flash",
72+
"use_tokamax_splash=True",
73+
# quantization
74+
f"quantization={quantization}",
75+
"use_qwix_quantization=True",
76+
"weight_quantization_calibration_method=fixed,-224,224",
77+
"act_quantization_calibration_method=fixed,-224,224",
78+
# train
79+
"per_device_batch_size=1",
80+
"dataset_type=synthetic",
81+
"steps=3",
82+
"enable_checkpointing=False",
83+
"enable_goodput_recording=False",
84+
"enable_checkpoint_cloud_logger=False",
85+
"monitor_goodput=False",
86+
f"metrics_file={os.path.join(outputs_dir, 'metrics.json')}",
87+
]
88+
train_main(args)
89+
90+
91+
if __name__ == "__main__":
92+
absltest.main()

0 commit comments

Comments
 (0)