Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2529,6 +2529,15 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
self.use_grpo = True
else:
self.use_grpo = False

if self.use_batch_split_schedule:
if not (self.decoder_block == DecoderBlockType.DEEPSEEK and self.sparse_matmul and self.use_tokamax_gmm):
raise ValueError("Batch split only supports deepseek, with `sparse_matmul=True` and `use_tokamax_gmm=True`")
if self.quantization and not (self.use_qwix_quantization and self.quantization == "fp8_full"):
raise ValueError(
"Batch split quantization only supports `use_qwix_quantization=True` and `quantization=fp8_full`"
)

if self.opt_type == "muon" and self.decoder_block not in [
DecoderBlockType.DEEPSEEK,
DecoderBlockType.QWEN3,
Expand All @@ -2537,7 +2546,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
]:
raise ValueError(
"Muon dimension numbers haven't been tested for this model. Run this command first: "
f"`python3 -m MaxText.muon_utils {self.model_name} True`"
f"`python3 -m maxtext.utils.muon_utils {self.model_name} True`"
)
if self.force_q_layout and not self.use_jax_splash:
raise ValueError("`force_q_layout` can only be true if `use_jax_splash` is also true.")
Expand Down
101 changes: 96 additions & 5 deletions src/maxtext/layers/quantizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import functools
import json
import re
from typing import Tuple, Sequence
from typing import Tuple, Sequence, Callable
from dataclasses import dataclass

from aqt.jax.v2 import config as aqt_config
Expand All @@ -27,6 +27,7 @@
from aqt.jax.v2 import calibration

import qwix
from qwix._src.core import dot_general_qt

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -194,6 +195,88 @@ def einsum(self, mesh_axes: Tuple[str, ...] = ()):
return aqt_einsum


@dataclass
class QwixQuantization:
  """Configures Qwix quantization github.com/google/qwix, for training only.

  Exposes the same `dot_general_cls` / `einsum` surface as the other
  quantization config classes so callers can swap implementations.
  """

  # Plain class attribute (deliberately not a dataclass field): read by
  # external callers that expect a `quant_mode` on every quantization config.
  quant_mode = "train"  # needed by external call
  act_calibration_method: str = "absmax"
  weight_calibration_method: str = "absmax"
  bwd_calibration_method: str = "absmax"

  def _get_fp8_full_qwix_config(self) -> dot_general_qt.DotGeneralQtConfig:
    """Builds the Qwix dot_general config for `fp8_full` quantization.

    Forward operands (activations and weights) are quantized to e4m3;
    both backward-pass gradients use e5m2.
    """
    fwd_qtype = jnp.float8_e4m3fn
    bwd_qtype = jnp.float8_e5m2
    return dot_general_qt.DotGeneralQtConfig(
        lhs_qtype=fwd_qtype,  # activation
        rhs_qtype=fwd_qtype,  # weight
        dlhs_grad_qtype=bwd_qtype,  # activation gradient
        drhs_grad_qtype=bwd_qtype,  # weight gradient
        lhs_calibration_method=self.act_calibration_method,
        rhs_calibration_method=self.weight_calibration_method,
        dlhs_grad_calibration_method=self.bwd_calibration_method,
        drhs_grad_calibration_method=self.bwd_calibration_method,
        tile_size=None,  # no sub-channel tiling
    )

  def dot_general_cls(self, mesh_axes: Tuple[str, ...] = ()):
    """Returns a `QwixDotGeneral` factory preconfigured for fp8_full."""
    return functools.partial(QwixDotGeneral, config=self._get_fp8_full_qwix_config())

  def einsum(self, mesh_axes: Tuple[str, ...] = ()):
    """Returns a `QwixEinsum` module preconfigured for fp8_full."""
    return QwixEinsum(config=self._get_fp8_full_qwix_config())


class QwixDotGeneral(nn.Module):
  """A callable class for Qwix dot_general.

  Follows the `jax.lax.dot_general` calling convention; the contraction is
  delegated to `dot_general_qt` using this module's quantization `config`.
  """

  # Quantization settings applied to the contraction.
  config: dot_general_qt.DotGeneralQtConfig

  @nn.compact
  def __call__(
      self,
      lhs: jax.Array,
      rhs: jax.Array,
      dimension_numbers: jax.lax.DotDimensionNumbers,
      precision: jax.lax.PrecisionLike = None,
      preferred_element_type: jax.typing.DTypeLike | None = None,
      *,
      out_sharding=None,
  ) -> jax.Array:
    # Accepted for signature compatibility only; the quantized kernel is
    # fully determined by `self.config`.
    del precision, preferred_element_type, out_sharding
    return dot_general_qt.dot_general_qt(lhs, rhs, dimension_numbers, self.config)


class QwixEinsum(nn.Module):
  """A callable class for Qwix einsum.

  Mirrors the `jnp.einsum` call signature (including the private
  `_dot_general` hook) so it can be dropped in where an einsum is expected,
  while routing every contraction through Qwix's `dot_general_qt` with this
  module's quantization `config`.
  """

  # Quantization settings applied to each dot_general inside the einsum.
  config: dot_general_qt.DotGeneralQtConfig

  @nn.compact
  def __call__(
      self,
      einsum_str: str,
      *operands: jax.Array,
      precision: jax.lax.PrecisionLike = None,
      preferred_element_type: jax.typing.DTypeLike | None = None,
      _dot_general: Callable[..., jax.Array] | None = None,
      out_sharding=None,
  ) -> jax.Array:

    # Only the first three positional args (lhs, rhs, dimension_numbers) are
    # forwarded; any extra args/kwargs supplied by `jnp.einsum` (e.g.
    # precision) are dropped and this module's config is substituted in.
    # A caller-provided `_dot_general` is intentionally ignored.
    def custom_dot_general(*args, **kwargs):
      return dot_general_qt.dot_general_qt(*args[:3], self.config)

    # NOTE(review): jit is disabled around the einsum — presumably required
    # for the `_dot_general` injection to take effect; confirm whether this
    # workaround is still needed and what it costs at trace time.
    with jax.disable_jit():
      return jnp.einsum(
          einsum_str,
          *operands,
          precision=precision,
          preferred_element_type=preferred_element_type,
          _dot_general=custom_dot_general,
          out_sharding=out_sharding,
      )


@dataclass
class Fp8Quantization(Quantization):
"""Configures Fp8 quantization for NVIDIA GPUs"""
Expand Down Expand Up @@ -539,13 +622,20 @@ def get_quant_mode(quant_mode_str: str = "train"):
return aqt_flax.QuantMode.SERVE
elif quant_mode_str == "convert":
return aqt_flax.QuantMode.CONVERT
else:
raise ValueError(f"Invalid quantization mode {quant_mode_str}.")
return None
raise ValueError(f"Invalid quantization mode {quant_mode_str}.")


def configure_quantization(config: Config, quant_mode_str: str = "train"):
"""Configure quantization based on user config and quant mode."""
if config.use_batch_split_schedule and config.quantization:
if not (config.use_qwix_quantization and config.quantization == "fp8_full"):
raise ValueError("Batch split quantization only supports `use_qwix_quantization=True` and `quantization=fp8_full`")
return QwixQuantization(
weight_calibration_method=config.weight_quantization_calibration_method,
act_calibration_method=config.act_quantization_calibration_method,
bwd_calibration_method=config.bwd_quantization_calibration_method,
)

if config.use_qwix_quantization:
return None
quant_cfg = _get_quant_config(config)
Expand Down Expand Up @@ -726,7 +816,8 @@ def get_qt_provider(config):

def maybe_quantize_model(model, config):
"""Quantize the model if quantization is enabled."""
if config.use_qwix_quantization:
# Batch split is not using Qwix's interception feature but manual plumbing
if config.use_qwix_quantization and not config.use_batch_split_schedule:
quantization_provider = get_qt_provider(config)
if quantization_provider:
model = qwix.quantize_model(model, quantization_provider)
Expand Down
51 changes: 42 additions & 9 deletions src/maxtext/models/deepseek_batchsplit.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,7 @@ def batch_split_schedule(
rope_factor=cfg.rope_factor,
mscale=cfg.mscale,
dtype=cfg.dtype,
quant=quant,
)

xs = moe(
Expand All @@ -297,6 +298,7 @@ def batch_split_schedule(
expert_axis_name="expert",
use_gather_mosaic_kernel=False,
config=cfg,
quant=quant,
)
return xs

Expand All @@ -319,7 +321,21 @@ def with_data_parallel_constraint(x, mesh):
return jax.lax.with_sharding_constraint(x, jax.NamedSharding(mesh, activation_pspec))


def dot(x, y, axes=1):
def dot(x, y, quant=None, axes=1):
  """Computes the dot product of two arrays, optionally using quantization.

  Args:
    x: Left operand.
    y: Right operand.
    quant: Optional quantization config providing `dot_general_cls()`; when
      None, the contraction falls back to plain `jnp.tensordot`.
    axes: Either an int (contract the last `axes` dims of `x` with the first
      `axes` dims of `y`) or a tensordot-style pair `(x_axes, y_axes)`, where
      each side may be a sequence of ints or a bare int.

  Returns:
    The (possibly quantized) tensordot of `x` and `y`.
  """
  if quant is None:
    # Unquantized path.
    return jnp.tensordot(x, y, axes=axes)

  # Convert tensordot-style `axes` to jax.lax.dot_general dimension_numbers.
  if isinstance(axes, int):
    x_contract = tuple(range(x.ndim - axes, x.ndim))
    y_contract = tuple(range(axes))
  else:
    x_contract, y_contract = axes
    # tensordot also accepts a pair of bare ints (e.g. axes=(1, 0));
    # dot_general needs sequences, so normalize.
    if isinstance(x_contract, int):
      x_contract = (x_contract,)
    if isinstance(y_contract, int):
      y_contract = (y_contract,)
  dimension_numbers = ((tuple(x_contract), tuple(y_contract)), ((), ()))
  # Instantiate and call qwix dot_general.
  custom_dot = quant.dot_general_cls()()
  return custom_dot(lhs=x, rhs=y, dimension_numbers=dimension_numbers)


Expand All @@ -345,6 +361,7 @@ def mla_with_norms(
rope_factor,
mscale,
dtype,
quant,
):
"""Performs MLA with pre- and post-normalization."""
(pre_attn_scale, post_attn_scale), attn_ws = weights
Expand Down Expand Up @@ -379,6 +396,7 @@ def fn(args):
dtype=dtype,
mscale=mscale,
attention_op_fn=attn_op,
quant=quant,
),
mesh,
)
Expand Down Expand Up @@ -414,6 +432,7 @@ def mla(
mscale,
attention_op_fn,
dtype,
quant,
):
"""Performs MLA."""
(
Expand Down Expand Up @@ -442,6 +461,7 @@ def mla(
dtype=dtype,
qk_nope_head_dim=qk_nope_head_dim,
mscale=mscale,
quant=quant,
)
query = jax.ad_checkpoint.checkpoint_name(query, "query_proj")
key, value = kv_projection(
Expand All @@ -462,6 +482,7 @@ def mla(
dtype=dtype,
qk_nope_head_dim=qk_nope_head_dim,
num_query_heads=num_query_heads,
quant=quant,
)
key = jax.ad_checkpoint.checkpoint_name(key, "key_proj")
value = jax.ad_checkpoint.checkpoint_name(value, "value_proj")
Expand All @@ -474,7 +495,7 @@ def mla(
cached_values=[None, None],
)
out = jax.ad_checkpoint.checkpoint_name(out, "attention_out")
out = dot(out, out_weights, axes=2)
out = dot(out, out_weights, quant=quant, axes=2)
out = jax.ad_checkpoint.checkpoint_name(out, "out_proj")
return out

Expand All @@ -497,6 +518,7 @@ def query_projection(
rope_factor,
dtype,
mscale,
quant,
):
"""Performs query projection."""
# Set softmax scaling.
Expand All @@ -507,15 +529,15 @@ def query_projection(
softmax_scale = softmax_scale * m * m

# LoRA path
low_rank_q = dot(inputs_q, wq_a_weights)
low_rank_q = dot(inputs_q, wq_a_weights, quant=quant)
low_rank_q = rms_norm(
low_rank_q,
q_norm_scale_weights,
epsilon=epsilon,
dtype=dtype,
)
low_rank_q = jax.ad_checkpoint.checkpoint_name(low_rank_q, "mla_q")
q = dot(low_rank_q, wq_b_weights)
q = dot(low_rank_q, wq_b_weights, quant=quant)

# Split into non-positional and rotary parts.
q_nope, q_pe = jnp.split(q, [qk_nope_head_dim], axis=-1)
Expand Down Expand Up @@ -554,9 +576,10 @@ def kv_projection(
dtype,
qk_nope_head_dim,
num_query_heads,
quant,
):
"""Performs KV projection."""
low_rank = dot(inputs, wkv_a_weights)
low_rank = dot(inputs, wkv_a_weights, quant=quant)
low_rank_main, low_rank_rope = jnp.split(low_rank, [kv_lora_rank], axis=-1)
low_rank_main = rms_norm(
low_rank_main,
Expand Down Expand Up @@ -585,12 +608,13 @@ def kv_projection(
wkv_b_weights,
qk_nope_head_dim=qk_nope_head_dim,
num_query_heads=num_query_heads,
quant=quant,
)


def get_key_value(low_rank_main, key_rope, wkv_b_weights, *, qk_nope_head_dim, num_query_heads):
def get_key_value(low_rank_main, key_rope, wkv_b_weights, *, qk_nope_head_dim, num_query_heads, quant):
"""Gets key and value from compressed KV latent vector and key rope."""
kv_out = dot(low_rank_main, wkv_b_weights)
kv_out = dot(low_rank_main, wkv_b_weights, quant=quant)

# Split kv_out into key_nope and value parts.
key_nope, value = jnp.split(kv_out, [qk_nope_head_dim], axis=-1)
Expand Down Expand Up @@ -686,6 +710,7 @@ def moe(
expert_axis_name,
use_gather_mosaic_kernel,
config,
quant,
):
"""Performs dropless MoE with tensor/expert parallelism."""
xs, ys = list(zip(*inputs))
Expand All @@ -700,6 +725,7 @@ def moe(
expert_axis_name=expert_axis_name,
use_gather_mosaic_kernel=use_gather_mosaic_kernel,
config=config,
quant=quant,
),
mesh,
)
Expand Down Expand Up @@ -730,9 +756,10 @@ def expert_selection(
num_experts,
num_experts_per_tok,
routed_scaling_factor,
quant,
):
"""Selects experts for each token and calculates group sizes for each expert."""
pre_bias_logits = jax.nn.sigmoid(dot(x, routing_kernel))
pre_bias_logits = jax.nn.sigmoid(dot(x, routing_kernel, quant=quant))
logits = pre_bias_logits + routing_bias

selected_experts, weights = expert_indices_and_weights(
Expand Down Expand Up @@ -946,6 +973,7 @@ def route_compute_unroute(
use_gather_mosaic_kernel,
config,
mesh,
quant,
):
"""Routes, processes, and unroutes activations."""
orig_shape = xs[0].shape
Expand All @@ -957,7 +985,9 @@ def route_compute_unroute(

def route_fn(inputs):
# Shared expert.
y = dot(jax.nn.silu(dot(inputs, shared_w0)) * dot(inputs, shared_w1), shared_wo)
y = dot(
jax.nn.silu(dot(inputs, shared_w0, quant=quant)) * dot(inputs, shared_w1, quant=quant), shared_wo, quant=quant
)

inputs = jnp.reshape(inputs, (-1, inputs.shape[-1]))
selected_experts, weights, group_sizes = expert_selection(
Expand All @@ -967,6 +997,7 @@ def route_fn(inputs):
num_experts=num_experts,
num_experts_per_tok=num_experts_per_tok,
routed_scaling_factor=routed_scaling_factor,
quant=quant,
)
x, selected_experts, weights, group_sizes = route(
inputs,
Expand Down Expand Up @@ -1019,6 +1050,7 @@ def process_activations(
expert_axis_name,
use_gather_mosaic_kernel,
config,
quant,
):
"""Processes activations, which are fully sharded on the batch axis, with tensor/expert sharded weights."""
activation_pspec = jax.sharding.PartitionSpec(
Expand All @@ -1043,6 +1075,7 @@ def process_activations(
use_gather_mosaic_kernel=use_gather_mosaic_kernel,
config=config,
mesh=mesh,
quant=quant,
),
mesh=mesh,
in_specs=(
Expand Down
Loading