Skip to content

Commit 17800bf

Browse files
committed
Explicitly pass qwix config for deepseek batch split
1 parent 00ef5de commit 17800bf

3 files changed

Lines changed: 149 additions & 19 deletions

File tree

src/maxtext/configs/types.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2525,6 +2525,15 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
25252525
self.use_grpo = True
25262526
else:
25272527
self.use_grpo = False
2528+
2529+
if self.use_batch_split_schedule:
2530+
if not (self.decoder_block == DecoderBlockType.DEEPSEEK and self.sparse_matmul and self.use_tokamax_gmm):
2531+
raise ValueError("Batch split only supports deepseek, with `sparse_matmul=True` and `use_tokamax_gmm=True`")
2532+
if self.quantization and not (self.use_qwix_quantization and self.quantization=="fp8_full"):
2533+
raise ValueError(
2534+
"Batch split quantization only supports `use_qwix_quantization=True` and `quantization=fp8_full`"
2535+
)
2536+
25282537
if self.opt_type == "muon" and self.decoder_block not in [
25292538
DecoderBlockType.DEEPSEEK,
25302539
DecoderBlockType.QWEN3,
@@ -2533,7 +2542,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
25332542
]:
25342543
raise ValueError(
25352544
"Muon dimension numbers haven't been tested for this model. Run this command first: "
2536-
f"`python3 -m MaxText.muon_utils {self.model_name} True`"
2545+
f"`python3 -m maxtext.utils.muon_utils {self.model_name} True`"
25372546
)
25382547
if self.force_q_layout and not self.use_jax_splash:
25392548
raise ValueError("`force_q_layout` can only be true if `use_jax_splash` is also true.")

src/maxtext/layers/quantizations.py

Lines changed: 91 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import functools
1818
import json
1919
import re
20-
from typing import Tuple, Sequence
20+
from typing import Tuple, Sequence, Callable
2121
from dataclasses import dataclass
2222

2323
from aqt.jax.v2 import config as aqt_config
@@ -27,6 +27,7 @@
2727
from aqt.jax.v2 import calibration
2828

2929
import qwix
30+
from qwix._src.core import dot_general_qt
3031

3132
import jax
3233
import jax.numpy as jnp
@@ -113,7 +114,7 @@ def _rhs_axis_metadata_wrapper(
113114

114115

115116
@dataclass
116-
class AqtQuantization:
117+
class AqtQuantization(Quantization):
117118
"""Configures AQT quantization github.com/google/aqt."""
118119

119120
quant_dg: aqt_config.DotGeneral
@@ -194,6 +195,83 @@ def einsum(self, mesh_axes: Tuple[str, ...] = ()):
194195
return aqt_einsum
195196

196197

198+
@dataclass
class QwixQuantization(Quantization):
  """Qwix-based quantization (github.com/google/qwix), for training only.

  Supplies drop-in `dot_general` / `einsum` replacements that route matmuls
  through Qwix's `dot_general_qt` with an fp8 ("fp8_full") configuration.
  """

  # Kept for API parity with AqtQuantization; external callers read this field.
  quant_mode: aqt_flax.QuantMode = aqt_flax.QuantMode.TRAIN
  act_calibration_method: str = "absmax"     # forward activations (lhs)
  weight_calibration_method: str = "absmax"  # forward weights (rhs)
  bwd_calibration_method: str = "absmax"     # gradients on both operands

  def _get_fp8_full_qwix_config(self) -> dot_general_qt.DotGeneralQtConfig:
    """Centralized factory for the shared Qwix dot_general config."""
    fwd_qtype = jnp.float8_e4m3fn  # forward pass: activation and weight
    bwd_qtype = jnp.float8_e5m2    # backward pass: activation and weight gradients
    return dot_general_qt.DotGeneralQtConfig(
        lhs_qtype=fwd_qtype,
        rhs_qtype=fwd_qtype,
        dlhs_grad_qtype=bwd_qtype,
        drhs_grad_qtype=bwd_qtype,
        lhs_calibration_method=self.act_calibration_method,
        rhs_calibration_method=self.weight_calibration_method,
        dlhs_grad_calibration_method=self.bwd_calibration_method,
        drhs_grad_calibration_method=self.bwd_calibration_method,
        tile_size=None,
    )

  def dot_general_cls(self, mesh_axes: Tuple[str, ...] = ()):
    """Return a QwixDotGeneral constructor pre-bound to the fp8 config."""
    return functools.partial(QwixDotGeneral, config=self._get_fp8_full_qwix_config())

  def einsum(self, mesh_axes: Tuple[str, ...] = ()):
    """Return a QwixEinsum module bound to the fp8 config."""
    return QwixEinsum(config=self._get_fp8_full_qwix_config())
228+
229+
230+
class QwixDotGeneral(nn.Module):
  """Flax module wrapping Qwix's quantized dot_general.

  Callable with the `jax.lax.dot_general` signature so it can stand in for
  the stock dot_general inside layers.
  """

  # Qwix dot_general configuration applied to every contraction.
  config: dot_general_qt.DotGeneralQtConfig

  @nn.compact
  def __call__(
      self,
      lhs: jax.Array,
      rhs: jax.Array,
      dimension_numbers: jax.lax.DotDimensionNumbers,
      precision: jax.lax.PrecisionLike = None,
      preferred_element_type: jax.typing.DTypeLike | None = None,
      *,
      out_sharding=None,
  ) -> jax.Array:
    # precision / preferred_element_type / out_sharding are accepted only for
    # signature compatibility and are not forwarded; the Qwix config governs
    # the quantized contraction.
    quantized = dot_general_qt.dot_general_qt(lhs, rhs, dimension_numbers, self.config)
    return quantized
247+
248+
249+
class QwixEinsum(nn.Module):
  """A callable class for Qwix einsum.

  Drop-in replacement for `jnp.einsum` that routes every contraction through
  Qwix's quantized `dot_general_qt` using the bound `config`.
  """

  # Qwix dot_general configuration applied to every contraction.
  config: dot_general_qt.DotGeneralQtConfig

  @nn.compact
  def __call__(
      self,
      einsum_str: str,
      *operands: jax.Array,
      precision: jax.lax.PrecisionLike = None,
      preferred_element_type: jax.typing.DTypeLike | None = None,
      # Accepted only for jnp.einsum signature compatibility: any caller-supplied
      # dot_general is intentionally ignored in favor of the Qwix one built below.
      _dot_general: Callable[..., jax.Array] | None = None,
      out_sharding=None,
  ) -> jax.Array:
    # Forward only the first three positional args (lhs, rhs, dimension_numbers);
    # the precision/dtype kwargs jnp.einsum passes to its dot_general are dropped,
    # since the Qwix config controls those choices.
    custom_dot_general = lambda *args, **kwargs: dot_general_qt.dot_general_qt(*args[:3], self.config)
    # NOTE(review): disabling jit around the einsum looks like a workaround for
    # Qwix tracing under jit — confirm it is still required; it is a performance
    # hazard if this path is ever hit outside an outer jit.
    with jax.disable_jit():
      return jnp.einsum(
          einsum_str,
          *operands,
          precision=precision,
          preferred_element_type=preferred_element_type,
          _dot_general=custom_dot_general,
          out_sharding=out_sharding,
      )
273+
274+
197275
@dataclass
198276
class Fp8Quantization(Quantization):
199277
"""Configures Fp8 quantization for NVIDIA GPUs"""
@@ -546,6 +624,15 @@ def get_quant_mode(quant_mode_str: str = "train"):
546624

547625
def configure_quantization(config: Config, quant_mode_str: str = "train"):
548626
"""Configure quantization based on user config and quant mode."""
627+
if config.use_batch_split_schedule and config.quantization:
628+
if not (config.use_qwix_quantization and config.quantization == "fp8_full"):
629+
raise ValueError("Batch split quantization only supports `use_qwix_quantization=True` and `quantization=fp8_full`")
630+
return QwixQuantization(
631+
weight_calibration_method=config.weight_quantization_calibration_method,
632+
act_calibration_method=config.act_quantization_calibration_method,
633+
bwd_calibration_method=config.bwd_quantization_calibration_method,
634+
)
635+
549636
if config.use_qwix_quantization:
550637
return None
551638
quant_cfg = _get_quant_config(config)
@@ -726,7 +813,8 @@ def get_qt_provider(config):
726813

727814
def maybe_quantize_model(model, config):
728815
"""Quantize the model if quantization is enabled."""
729-
if config.use_qwix_quantization:
816+
# Batch split is not using Qwix's interception feature but manual plumbing
817+
if config.use_qwix_quantization and not config.use_batch_split_schedule:
730818
quantization_provider = get_qt_provider(config)
731819
if quantization_provider:
732820
model = qwix.quantize_model(model, quantization_provider)

0 commit comments

Comments
 (0)