Skip to content

Commit cd2d86f

Browse files
committed
Explicitly pass qwix config for deepseek batch split
1 parent ca7e2df commit cd2d86f

3 files changed

Lines changed: 152 additions & 27 deletions

File tree

src/maxtext/configs/types.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2527,6 +2527,15 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
25272527
self.use_grpo = True
25282528
else:
25292529
self.use_grpo = False
2530+
2531+
if self.use_batch_split_schedule:
2532+
if not (self.decoder_block == DecoderBlockType.DEEPSEEK and self.sparse_matmul and self.use_tokamax_gmm):
2533+
raise ValueError("Batch split only supports deepseek, with `sparse_matmul=True` and `use_tokamax_gmm=True`")
2534+
if self.quantization and not (self.use_qwix_quantization and self.quantization == "fp8_full"):
2535+
raise ValueError(
2536+
"Batch split quantization only supports `use_qwix_quantization=True` and `quantization=fp8_full`"
2537+
)
2538+
25302539
if self.opt_type == "muon" and self.decoder_block not in [
25312540
DecoderBlockType.DEEPSEEK,
25322541
DecoderBlockType.QWEN3,
@@ -2535,7 +2544,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
25352544
]:
25362545
raise ValueError(
25372546
"Muon dimension numbers haven't been tested for this model. Run this command first: "
2538-
f"`python3 -m MaxText.muon_utils {self.model_name} True`"
2547+
f"`python3 -m maxtext.utils.muon_utils {self.model_name} True`"
25392548
)
25402549
if self.force_q_layout and not self.use_jax_splash:
25412550
raise ValueError("`force_q_layout` can only be true if `use_jax_splash` is also true.")

src/maxtext/layers/quantizations.py

Lines changed: 95 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import functools
1818
import json
1919
import re
20-
from typing import Tuple, Sequence
20+
from typing import Tuple, Sequence, Callable
2121
from dataclasses import dataclass
2222

2323
from aqt.jax.v2 import config as aqt_config
@@ -27,6 +27,7 @@
2727
from aqt.jax.v2 import calibration
2828

2929
import qwix
30+
from qwix._src.core import dot_general_qt
3031

3132
import jax
3233
import jax.numpy as jnp
@@ -194,6 +195,88 @@ def einsum(self, mesh_axes: Tuple[str, ...] = ()):
194195
return aqt_einsum
195196

196197

198+
@dataclass
class QwixQuantization:
  """Configures Qwix quantization github.com/google/qwix, for training only."""

  quant_mode: aqt_flax.QuantMode = aqt_flax.QuantMode.TRAIN  # needed by external call
  act_calibration_method: str = "absmax"
  weight_calibration_method: str = "absmax"
  bwd_calibration_method: str = "absmax"

  def _get_fp8_full_qwix_config(self) -> dot_general_qt.DotGeneralQtConfig:
    """Returns Qwix dot_general config for fp8_full quantization."""
    # Forward operands use e4m3 (more mantissa); gradients use e5m2
    # (more exponent range) — the standard fp8 training recipe.
    fwd_qtype = jnp.float8_e4m3fn
    bwd_qtype = jnp.float8_e5m2
    return dot_general_qt.DotGeneralQtConfig(
        lhs_qtype=fwd_qtype,  # activation
        rhs_qtype=fwd_qtype,  # weight
        dlhs_grad_qtype=bwd_qtype,  # activation gradient
        drhs_grad_qtype=bwd_qtype,  # weight gradient
        lhs_calibration_method=self.act_calibration_method,
        rhs_calibration_method=self.weight_calibration_method,
        dlhs_grad_calibration_method=self.bwd_calibration_method,
        drhs_grad_calibration_method=self.bwd_calibration_method,
        tile_size=None,  # presumably no sub-channel tiling — confirm against qwix docs
    )

  def dot_general_cls(self, mesh_axes: Tuple[str, ...] = ()):
    """Returns Qwix dot_general."""
    # `mesh_axes` exists for interface parity with the AQT-based configs in
    # this file; the Qwix path does not read it.
    return functools.partial(QwixDotGeneral, config=self._get_fp8_full_qwix_config())

  def einsum(self, mesh_axes: Tuple[str, ...] = ()):
    """Returns Qwix einsum."""
    # Same parity note as dot_general_cls: `mesh_axes` is unused here.
    return QwixEinsum(config=self._get_fp8_full_qwix_config())
228+
229+
230+
class QwixDotGeneral(nn.Module):
  """A callable class for Qwix dot_general."""

  config: dot_general_qt.DotGeneralQtConfig

  @nn.compact
  def __call__(
      self,
      lhs: jax.Array,
      rhs: jax.Array,
      dimension_numbers: jax.lax.DotDimensionNumbers,
      precision: jax.lax.PrecisionLike = None,
      preferred_element_type: jax.typing.DTypeLike | None = None,
      *,
      out_sharding=None,
  ) -> jax.Array:
    """Quantized drop-in for jax.lax.dot_general driven by `self.config`.

    `precision`, `preferred_element_type` and `out_sharding` are accepted so
    callers can pass the full lax.dot_general signature, but the Qwix kernel
    derives its behavior entirely from the config, so they are ignored.
    """
    del precision, preferred_element_type, out_sharding
    return dot_general_qt.dot_general_qt(lhs, rhs, dimension_numbers, self.config)
248+
249+
250+
class QwixEinsum(nn.Module):
  """A callable class for Qwix einsum."""

  # Qwix dot_general config applied to every contraction inside the einsum.
  config: dot_general_qt.DotGeneralQtConfig

  @nn.compact
  def __call__(
      self,
      einsum_str: str,
      *operands: jax.Array,
      precision: jax.lax.PrecisionLike = None,
      preferred_element_type: jax.typing.DTypeLike | None = None,
      _dot_general: Callable[..., jax.Array] | None = None,
      out_sharding=None,
  ) -> jax.Array:

    # Route the contractions through the Qwix quantized dot_general. Any
    # caller-supplied `_dot_general` is deliberately ignored, as are extra
    # kwargs beyond (lhs, rhs, dimension_numbers): quantization behavior
    # comes solely from `self.config`.
    def custom_dot_general(*args, **kwargs):
      return dot_general_qt.dot_general_qt(*args[:3], self.config)

    # NOTE(review): disabling jit here presumably keeps jnp.einsum from
    # compiling around the injected private `_dot_general` hook; confirm it
    # is actually required and does not regress performance when this module
    # is called from inside an already-jitted train step.
    with jax.disable_jit():
      return jnp.einsum(
          einsum_str,
          *operands,
          precision=precision,
          preferred_element_type=preferred_element_type,
          _dot_general=custom_dot_general,
          out_sharding=out_sharding,
      )
278+
279+
197280
@dataclass
198281
class Fp8Quantization(Quantization):
199282
"""Configures Fp8 quantization for NVIDIA GPUs"""
@@ -546,6 +629,15 @@ def get_quant_mode(quant_mode_str: str = "train"):
546629

547630
def configure_quantization(config: Config, quant_mode_str: str = "train"):
548631
"""Configure quantization based on user config and quant mode."""
632+
if config.use_batch_split_schedule and config.quantization:
633+
if not (config.use_qwix_quantization and config.quantization == "fp8_full"):
634+
raise ValueError("Batch split quantization only supports `use_qwix_quantization=True` and `quantization=fp8_full`")
635+
return QwixQuantization(
636+
weight_calibration_method=config.weight_quantization_calibration_method,
637+
act_calibration_method=config.act_quantization_calibration_method,
638+
bwd_calibration_method=config.bwd_quantization_calibration_method,
639+
)
640+
549641
if config.use_qwix_quantization:
550642
return None
551643
quant_cfg = _get_quant_config(config)
@@ -726,7 +818,8 @@ def get_qt_provider(config):
726818

727819
def maybe_quantize_model(model, config):
728820
"""Quantize the model if quantization is enabled."""
729-
if config.use_qwix_quantization:
821+
# Batch split is not using Qwix's interception feature but manual plumbing
822+
if config.use_qwix_quantization and not config.use_batch_split_schedule:
730823
quantization_provider = get_qt_provider(config)
731824
if quantization_provider:
732825
model = qwix.quantize_model(model, quantization_provider)

0 commit comments

Comments
 (0)