Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions .github/workflows/run_tests_against_package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -131,12 +131,10 @@ jobs:
else
SPLIT_ARGS=""
fi
# TODO: Fix the skipped tests and remove the deselect flags
.venv/bin/python3 -m pytest ${INPUTS_PYTEST_ADDOPTS} \
-v \
-m "${FINAL_PYTEST_MARKER}" \
--durations=0 \
--deselect "tests/unit/tokenizer_test.py::TokenizerTest::test_detokenize" \
--cov=MaxText \
--cov=maxtext \
--cov-report=xml \
Expand Down
Binary file modified src/maxtext/assets/tokenizers/tokenizer.default
Binary file not shown.
11 changes: 10 additions & 1 deletion src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2531,6 +2531,15 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
self.use_grpo = True
else:
self.use_grpo = False

if self.use_batch_split_schedule:
if not (self.decoder_block == DecoderBlockType.DEEPSEEK and self.sparse_matmul and self.use_tokamax_gmm):
raise ValueError("Batch split only supports deepseek, with `sparse_matmul=True` and `use_tokamax_gmm=True`")
if self.quantization and not (self.use_qwix_quantization and self.quantization == "fp8_full"):
raise ValueError(
"Batch split quantization only supports `use_qwix_quantization=True` and `quantization=fp8_full`"
)

if self.opt_type == "muon" and self.decoder_block not in [
DecoderBlockType.DEEPSEEK,
DecoderBlockType.QWEN3,
Expand All @@ -2539,7 +2548,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
]:
raise ValueError(
"Muon dimension numbers haven't been tested for this model. Run this command first: "
f"`python3 -m MaxText.muon_utils {self.model_name} True`"
f"`python3 -m maxtext.utils.muon_utils {self.model_name} True`"
)
if self.force_q_layout and not self.use_jax_splash:
raise ValueError("`force_q_layout` can only be true if `use_jax_splash` is also true.")
Expand Down
101 changes: 96 additions & 5 deletions src/maxtext/layers/quantizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import functools
import json
import re
from typing import Tuple, Sequence
from typing import Tuple, Sequence, Callable
from dataclasses import dataclass

from aqt.jax.v2 import config as aqt_config
Expand All @@ -27,6 +27,7 @@
from aqt.jax.v2 import calibration

import qwix
from qwix._src.core import dot_general_qt

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -194,6 +195,88 @@ def einsum(self, mesh_axes: Tuple[str, ...] = ()):
return aqt_einsum


@dataclass
class QwixQuantization:
  """Configures Qwix quantization (github.com/google/qwix), for training only.

  Calibration methods for forward activations/weights and for backward
  gradients are configurable; everything else is pinned to the fp8_full
  recipe (e4m3 forward operands, e5m2 backward gradients, no tiling).
  """

  # Class attribute, deliberately left unannotated so the dataclass does not
  # turn it into an init field; read by external callers.
  quant_mode = "train"
  act_calibration_method: str = "absmax"
  weight_calibration_method: str = "absmax"
  bwd_calibration_method: str = "absmax"

  def _get_fp8_full_qwix_config(self) -> dot_general_qt.DotGeneralQtConfig:
    """Builds the Qwix dot_general config implementing the fp8_full recipe."""
    return dot_general_qt.DotGeneralQtConfig(
        # Forward pass: e4m3 for activations (lhs) and weights (rhs).
        lhs_qtype=jnp.float8_e4m3fn,
        rhs_qtype=jnp.float8_e4m3fn,
        # Backward pass: e5m2 for both gradient dot-generals.
        dlhs_grad_qtype=jnp.float8_e5m2,
        drhs_grad_qtype=jnp.float8_e5m2,
        lhs_calibration_method=self.act_calibration_method,
        rhs_calibration_method=self.weight_calibration_method,
        dlhs_grad_calibration_method=self.bwd_calibration_method,
        drhs_grad_calibration_method=self.bwd_calibration_method,
        tile_size=None,  # per-tensor scaling, no sub-channel tiling
    )

  def dot_general_cls(self, mesh_axes: Tuple[str, ...] = ()):
    """Returns a factory for Qwix dot_general modules; `mesh_axes` is unused."""
    return functools.partial(QwixDotGeneral, config=self._get_fp8_full_qwix_config())

  def einsum(self, mesh_axes: Tuple[str, ...] = ()):
    """Returns a Qwix einsum module; `mesh_axes` is unused."""
    return QwixEinsum(config=self._get_fp8_full_qwix_config())


class QwixDotGeneral(nn.Module):
  """Flax module wrapping Qwix's quantized dot_general.

  Drop-in replacement for `jax.lax.dot_general`: `precision`,
  `preferred_element_type` and `out_sharding` are accepted for signature
  compatibility but ignored — all quantization behavior comes from `config`.
  """

  config: dot_general_qt.DotGeneralQtConfig

  @nn.compact
  def __call__(
      self,
      lhs: jax.Array,
      rhs: jax.Array,
      dimension_numbers: jax.lax.DotDimensionNumbers,
      precision: jax.lax.PrecisionLike = None,
      preferred_element_type: jax.typing.DTypeLike | None = None,
      *,
      out_sharding=None,
  ) -> jax.Array:
    # Unused; kept only so callers can pass the standard dot_general kwargs.
    del precision, preferred_element_type, out_sharding
    return dot_general_qt.dot_general_qt(lhs, rhs, dimension_numbers, self.config)


class QwixEinsum(nn.Module):
  """Flax module wrapping `jnp.einsum` with Qwix's quantized dot_general.

  Any caller-supplied `_dot_general` is intentionally discarded and replaced
  by a Qwix-configured one. NOTE(review): the `jax.disable_jit()` context is
  presumably a workaround for Qwix tracing inside einsum — confirm before
  removing it.
  """

  config: dot_general_qt.DotGeneralQtConfig

  @nn.compact
  def __call__(
      self,
      einsum_str: str,
      *operands: jax.Array,
      precision: jax.lax.PrecisionLike = None,
      preferred_element_type: jax.typing.DTypeLike | None = None,
      _dot_general: Callable[..., jax.Array] | None = None,
      out_sharding=None,
  ) -> jax.Array:

    def qwix_dot_general(lhs, rhs, dimension_numbers, *unused_args, **unused_kwargs):
      # Route every contraction through Qwix, dropping whatever extra
      # arguments jnp.einsum forwards (precision, preferred_element_type, ...).
      return dot_general_qt.dot_general_qt(lhs, rhs, dimension_numbers, self.config)

    with jax.disable_jit():
      return jnp.einsum(
          einsum_str,
          *operands,
          precision=precision,
          preferred_element_type=preferred_element_type,
          _dot_general=qwix_dot_general,
          out_sharding=out_sharding,
      )


@dataclass
class Fp8Quantization(Quantization):
"""Configures Fp8 quantization for NVIDIA GPUs"""
Expand Down Expand Up @@ -539,13 +622,20 @@ def get_quant_mode(quant_mode_str: str = "train"):
return aqt_flax.QuantMode.SERVE
elif quant_mode_str == "convert":
return aqt_flax.QuantMode.CONVERT
else:
raise ValueError(f"Invalid quantization mode {quant_mode_str}.")
return None
raise ValueError(f"Invalid quantization mode {quant_mode_str}.")


def configure_quantization(config: Config, quant_mode_str: str = "train"):
"""Configure quantization based on user config and quant mode."""
if config.use_batch_split_schedule and config.quantization:
if not (config.use_qwix_quantization and config.quantization == "fp8_full"):
raise ValueError("Batch split quantization only supports `use_qwix_quantization=True` and `quantization=fp8_full`")
return QwixQuantization(
weight_calibration_method=config.weight_quantization_calibration_method,
act_calibration_method=config.act_quantization_calibration_method,
bwd_calibration_method=config.bwd_quantization_calibration_method,
)

if config.use_qwix_quantization:
return None
quant_cfg = _get_quant_config(config)
Expand Down Expand Up @@ -726,7 +816,8 @@ def get_qt_provider(config):

def maybe_quantize_model(model, config):
"""Quantize the model if quantization is enabled."""
if config.use_qwix_quantization:
# Batch split is not using Qwix's interception feature but manual plumbing
if config.use_qwix_quantization and not config.use_batch_split_schedule:
quantization_provider = get_qt_provider(config)
if quantization_provider:
model = qwix.quantize_model(model, quantization_provider)
Expand Down
Loading
Loading