# NOTE(review): this patch was recovered from a whitespace-collapsed copy. Line breaks, indentation and blank
# context lines are reconstructed, not verified against the repository; regenerate the diff before applying it.
diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py
index 388247b5a0..9dd3624584 100644
--- a/src/maxtext/configs/types.py
+++ b/src/maxtext/configs/types.py
@@ -2529,6 +2529,15 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
       self.use_grpo = True
     else:
       self.use_grpo = False
+
+    if self.use_batch_split_schedule:
+      if not (self.decoder_block == DecoderBlockType.DEEPSEEK and self.sparse_matmul and self.use_tokamax_gmm):
+        raise ValueError("Batch split only supports deepseek, with `sparse_matmul=True` and `use_tokamax_gmm=True`")
+      if self.quantization and not (self.use_qwix_quantization and self.quantization == "fp8_full"):
+        raise ValueError(
+            "Batch split quantization only supports `use_qwix_quantization=True` and `quantization=fp8_full`"
+        )
+
+
     if self.opt_type == "muon" and self.decoder_block not in [
         DecoderBlockType.DEEPSEEK,
         DecoderBlockType.QWEN3,
@@ -2537,7 +2546,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
     ]:
       raise ValueError(
           "Muon dimension numbers haven't been tested for this model. Run this command first: "
-          f"`python3 -m MaxText.muon_utils {self.model_name} True`"
+          f"`python3 -m maxtext.utils.muon_utils {self.model_name} True`"
       )
     if self.force_q_layout and not self.use_jax_splash:
       raise ValueError("`force_q_layout` can only be true if `use_jax_splash` is also true.")
diff --git a/src/maxtext/layers/quantizations.py b/src/maxtext/layers/quantizations.py
index abc47529b7..503e7e0b04 100644
--- a/src/maxtext/layers/quantizations.py
+++ b/src/maxtext/layers/quantizations.py
@@ -17,7 +17,7 @@
 import functools
 import json
 import re
-from typing import Tuple, Sequence
+from typing import Tuple, Sequence, Callable
 from dataclasses import dataclass
 
 from aqt.jax.v2 import config as aqt_config
@@ -27,6 +27,7 @@
 from aqt.jax.v2 import calibration
 
 import qwix
+from qwix._src.core import dot_general_qt
 
 import jax
 import jax.numpy as jnp
@@ -194,6 +195,88 @@ def einsum(self, mesh_axes: Tuple[str, ...] = ()):
     return aqt_einsum
 
 
+@dataclass
+class QwixQuantization:
+  """Configures Qwix (github.com/google/qwix) fp8_full quantization, for training only."""
+
+  quant_mode = "train"  # needed by external call; unannotated, so not a dataclass field
+  act_calibration_method: str = "absmax"
+  weight_calibration_method: str = "absmax"
+  bwd_calibration_method: str = "absmax"
+
+  def _get_fp8_full_qwix_config(self) -> dot_general_qt.DotGeneralQtConfig:
+    """Returns the Qwix dot_general config for fp8_full: e4m3fn operands, e5m2 gradients, `tile_size=None`."""
+    return dot_general_qt.DotGeneralQtConfig(
+        lhs_qtype=jnp.float8_e4m3fn,  # activation
+        rhs_qtype=jnp.float8_e4m3fn,  # weight
+        dlhs_grad_qtype=jnp.float8_e5m2,  # activation gradient
+        drhs_grad_qtype=jnp.float8_e5m2,  # weight gradient
+        lhs_calibration_method=self.act_calibration_method,
+        rhs_calibration_method=self.weight_calibration_method,
+        dlhs_grad_calibration_method=self.bwd_calibration_method,
+        drhs_grad_calibration_method=self.bwd_calibration_method,
+        tile_size=None,
+    )
+
+  def dot_general_cls(self, mesh_axes: Tuple[str, ...] = ()):
+    """Returns a `QwixDotGeneral` constructor pre-bound to the fp8_full config; `mesh_axes` is unused."""
+    return functools.partial(QwixDotGeneral, config=self._get_fp8_full_qwix_config())
+
+  def einsum(self, mesh_axes: Tuple[str, ...] = ()):
+    """Returns a `QwixEinsum` module built with the fp8_full config; `mesh_axes` is unused."""
+    return QwixEinsum(config=self._get_fp8_full_qwix_config())
+
+
+class QwixDotGeneral(nn.Module):
+  """Module wrapping Qwix `dot_general_qt`; `precision`, `preferred_element_type` and `out_sharding` are ignored."""
+
+  config: dot_general_qt.DotGeneralQtConfig
+
+  @nn.compact
+  def __call__(
+      self,
+      lhs: jax.Array,
+      rhs: jax.Array,
+      dimension_numbers: jax.lax.DotDimensionNumbers,
+      precision: jax.lax.PrecisionLike = None,
+      preferred_element_type: jax.typing.DTypeLike | None = None,
+      *,
+      out_sharding=None,
+  ) -> jax.Array:
+
+    return dot_general_qt.dot_general_qt(lhs, rhs, dimension_numbers, self.config)
+
+
+class QwixEinsum(nn.Module):
+  """Module running `jnp.einsum` (jit disabled) on Qwix `dot_general_qt`; a caller-supplied `_dot_general` is ignored."""
+
+  config: dot_general_qt.DotGeneralQtConfig
+
+  @nn.compact
+  def __call__(
+      self,
+      einsum_str: str,
+      *operands: jax.Array,
+      precision: jax.lax.PrecisionLike = None,
+      preferred_element_type: jax.typing.DTypeLike | None = None,
+      _dot_general: Callable[..., jax.Array] | None = None,
+      out_sharding=None,
+  ) -> jax.Array:
+
+    def custom_dot_general(*args, **kwargs):
+      return dot_general_qt.dot_general_qt(*args[:3], self.config)
+
+    with jax.disable_jit():
+      return jnp.einsum(
+          einsum_str,
+          *operands,
+          precision=precision,
+          preferred_element_type=preferred_element_type,
+          _dot_general=custom_dot_general,
+          out_sharding=out_sharding,
+      )
+
+
 @dataclass
 class Fp8Quantization(Quantization):
   """Configures Fp8 quantization for NVIDIA GPUs"""
@@ -539,13 +622,20 @@ def get_quant_mode(quant_mode_str: str = "train"):
     return aqt_flax.QuantMode.SERVE
   elif quant_mode_str == "convert":
     return aqt_flax.QuantMode.CONVERT
-  else:
-    raise ValueError(f"Invalid quantization mode {quant_mode_str}.")
-  return None
+  raise ValueError(f"Invalid quantization mode {quant_mode_str}.")
 
 
 def configure_quantization(config: Config, quant_mode_str: str = "train"):
   """Configure quantization based on user config and quant mode."""
+  if config.use_batch_split_schedule and config.quantization:
+    if not (config.use_qwix_quantization and config.quantization == "fp8_full"):
+      raise ValueError("Batch split quantization only supports `use_qwix_quantization=True` and `quantization=fp8_full`")
+    return QwixQuantization(
+        weight_calibration_method=config.weight_quantization_calibration_method,
+        act_calibration_method=config.act_quantization_calibration_method,
+        bwd_calibration_method=config.bwd_quantization_calibration_method,
+    )
+
   if config.use_qwix_quantization:
     return None
   quant_cfg = _get_quant_config(config)
@@ -726,7 +816,8 @@ def get_qt_provider(config):
 
 def maybe_quantize_model(model, config):
   """Quantize the model if quantization is enabled."""
-  if config.use_qwix_quantization:
+  # Batch split does not use Qwix's interception (`qwix.quantize_model`); its quantization is plumbed manually
+  if config.use_qwix_quantization and not config.use_batch_split_schedule:
     quantization_provider = get_qt_provider(config)
     if quantization_provider:
       model = qwix.quantize_model(model, quantization_provider)
diff --git a/src/maxtext/models/deepseek_batchsplit.py b/src/maxtext/models/deepseek_batchsplit.py
index fabf0bacad..066b8cabbc 100644
--- a/src/maxtext/models/deepseek_batchsplit.py
+++ b/src/maxtext/models/deepseek_batchsplit.py
@@ -285,6 +285,7 @@ def batch_split_schedule(
       rope_factor=cfg.rope_factor,
       mscale=cfg.mscale,
       dtype=cfg.dtype,
+      quant=quant,
   )
 
   xs = moe(
@@ -297,6 +298,7 @@ def batch_split_schedule(
       expert_axis_name="expert",
       use_gather_mosaic_kernel=False,
       config=cfg,
+      quant=quant,
   )
   return xs
 
@@ -319,7 +321,21 @@ def with_data_parallel_constraint(x, mesh):
   return jax.lax.with_sharding_constraint(x, jax.NamedSharding(mesh, activation_pspec))
 
 
-def dot(x, y, axes=1):
+def dot(x, y, quant=None, axes=1):
+  """Contracts `x` with `y` like `jnp.tensordot`; when `quant` is set, uses `quant.dot_general_cls()` instead."""
+  if quant is not None:
+    # Convert axes to jax.lax.dot_general dimension_numbers (contracting dims only, no batch dims)
+    if isinstance(axes, int):
+      x_contract = tuple(range(x.ndim - axes, x.ndim))
+      y_contract = tuple(range(axes))
+    else:
+      x_contract, y_contract = axes
+    dimension_numbers = ((x_contract, y_contract), ((), ()))
+    # Instantiate and call qwix dot_general
+    custom_dot = quant.dot_general_cls()()
+    return custom_dot(lhs=x, rhs=y, dimension_numbers=dimension_numbers)
+
+  # Unquantized path
   return jnp.tensordot(x, y, axes=axes)
 
 
@@ -345,6 +361,7 @@ def mla_with_norms(
     rope_factor,
     mscale,
     dtype,
+    quant,
 ):
   """Performs MLA with pre- and post-normalization."""
   (pre_attn_scale, post_attn_scale), attn_ws = weights
@@ -379,6 +396,7 @@ def fn(args):
             dtype=dtype,
             mscale=mscale,
             attention_op_fn=attn_op,
+            quant=quant,
         ),
         mesh,
     )
@@ -414,6 +432,7 @@ def mla(
     mscale,
     attention_op_fn,
     dtype,
+    quant,
 ):
   """Performs MLA."""
   (
@@ -442,6 +461,7 @@ def mla(
       dtype=dtype,
       qk_nope_head_dim=qk_nope_head_dim,
       mscale=mscale,
+      quant=quant,
   )
   query = jax.ad_checkpoint.checkpoint_name(query, "query_proj")
   key, value = kv_projection(
@@ -462,6 +482,7 @@ def mla(
       dtype=dtype,
       qk_nope_head_dim=qk_nope_head_dim,
       num_query_heads=num_query_heads,
+      quant=quant,
   )
   key = jax.ad_checkpoint.checkpoint_name(key, "key_proj")
   value = jax.ad_checkpoint.checkpoint_name(value, "value_proj")
@@ -474,7 +495,7 @@ def mla(
       cached_values=[None, None],
   )
   out = jax.ad_checkpoint.checkpoint_name(out, "attention_out")
-  out = dot(out, out_weights, axes=2)
+  out = dot(out, out_weights, quant=quant, axes=2)
   out = jax.ad_checkpoint.checkpoint_name(out, "out_proj")
   return out
 
@@ -497,6 +518,7 @@ def query_projection(
     rope_factor,
     dtype,
     mscale,
+    quant,
 ):
   """Performs query projection."""
   # Set softmax scaling.
@@ -507,7 +529,7 @@ def query_projection(
     softmax_scale = softmax_scale * m * m
 
   # LoRA path
-  low_rank_q = dot(inputs_q, wq_a_weights)
+  low_rank_q = dot(inputs_q, wq_a_weights, quant=quant)
   low_rank_q = rms_norm(
       low_rank_q,
       q_norm_scale_weights,
@@ -515,7 +537,7 @@ def query_projection(
       dtype=dtype,
   )
   low_rank_q = jax.ad_checkpoint.checkpoint_name(low_rank_q, "mla_q")
-  q = dot(low_rank_q, wq_b_weights)
+  q = dot(low_rank_q, wq_b_weights, quant=quant)
 
   # Split into non-positional and rotary parts.
   q_nope, q_pe = jnp.split(q, [qk_nope_head_dim], axis=-1)
@@ -554,9 +576,10 @@ def kv_projection(
     dtype,
     qk_nope_head_dim,
     num_query_heads,
+    quant,
 ):
   """Performs KV projection."""
-  low_rank = dot(inputs, wkv_a_weights)
+  low_rank = dot(inputs, wkv_a_weights, quant=quant)
   low_rank_main, low_rank_rope = jnp.split(low_rank, [kv_lora_rank], axis=-1)
   low_rank_main = rms_norm(
       low_rank_main,
@@ -585,12 +608,13 @@ def kv_projection(
       wkv_b_weights,
       qk_nope_head_dim=qk_nope_head_dim,
       num_query_heads=num_query_heads,
+      quant=quant,
   )
 
 
-def get_key_value(low_rank_main, key_rope, wkv_b_weights, *, qk_nope_head_dim, num_query_heads):
+def get_key_value(low_rank_main, key_rope, wkv_b_weights, *, qk_nope_head_dim, num_query_heads, quant):
   """Gets key and value from compressed KV latent vector and key rope."""
-  kv_out = dot(low_rank_main, wkv_b_weights)
+  kv_out = dot(low_rank_main, wkv_b_weights, quant=quant)
 
   # Split kv_out into key_nope and value parts.
   key_nope, value = jnp.split(kv_out, [qk_nope_head_dim], axis=-1)
@@ -686,6 +710,7 @@ def moe(
     expert_axis_name,
     use_gather_mosaic_kernel,
     config,
+    quant,
 ):
   """Performs dropless MoE with tensor/expert parallelism."""
   xs, ys = list(zip(*inputs))
@@ -700,6 +725,7 @@ def moe(
           expert_axis_name=expert_axis_name,
           use_gather_mosaic_kernel=use_gather_mosaic_kernel,
           config=config,
+          quant=quant,
       ),
       mesh,
   )
@@ -730,9 +756,10 @@ def expert_selection(
     num_experts,
     num_experts_per_tok,
     routed_scaling_factor,
+    quant,
 ):
   """Selects experts for each token and calculates group sizes for each expert."""
-  pre_bias_logits = jax.nn.sigmoid(dot(x, routing_kernel))
+  pre_bias_logits = jax.nn.sigmoid(dot(x, routing_kernel, quant=quant))
   logits = pre_bias_logits + routing_bias
 
   selected_experts, weights = expert_indices_and_weights(
@@ -946,6 +973,7 @@ def route_compute_unroute(
     use_gather_mosaic_kernel,
     config,
     mesh,
+    quant,
 ):
   """Routes, processes, and unroutes activations."""
   orig_shape = xs[0].shape
@@ -957,7 +985,9 @@ def route_compute_unroute(
 
   def route_fn(inputs):
     # Shared expert.
-    y = dot(jax.nn.silu(dot(inputs, shared_w0)) * dot(inputs, shared_w1), shared_wo)
+    y = dot(
+        jax.nn.silu(dot(inputs, shared_w0, quant=quant)) * dot(inputs, shared_w1, quant=quant), shared_wo, quant=quant
+    )
 
     inputs = jnp.reshape(inputs, (-1, inputs.shape[-1]))
     selected_experts, weights, group_sizes = expert_selection(
@@ -967,6 +997,7 @@ def route_fn(inputs):
         num_experts=num_experts,
         num_experts_per_tok=num_experts_per_tok,
         routed_scaling_factor=routed_scaling_factor,
+        quant=quant,
     )
     x, selected_experts, weights, group_sizes = route(
         inputs,
@@ -1019,6 +1050,7 @@ def process_activations(
     expert_axis_name,
     use_gather_mosaic_kernel,
     config,
+    quant,
 ):
   """Processes activations, which are fully sharded on the batch axis, with tensor/expert sharded weights."""
   activation_pspec = jax.sharding.PartitionSpec(
@@ -1043,6 +1075,7 @@ def process_activations(
           use_gather_mosaic_kernel=use_gather_mosaic_kernel,
           config=config,
           mesh=mesh,
+          quant=quant,
       ),
       mesh=mesh,
       in_specs=(