Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 70 additions & 11 deletions src/maxtext/inference/maxengine/maxengine.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,14 +117,25 @@ def __init__(self, config: Any, devices: Any | None = None):
# Model and Optimizer definition.
quant = quantizations.configure_quantization(config)
if config.pure_nnx:
# `serve` only when the on-disk checkpoint already carries `qrhs.frozen`
# (no full-precision kernel). For `checkpoint_is_quantized=False` with
# quant enabled we stay in `train` mode and let AQT quantize per-forward
# against the full-precision kernel — same numerical result as `serve`
# for absmax calibration, just slower.
nnx_quant_mode_str = "serve" if (quant is not None and config.checkpoint_is_quantized) else "train"
# We need both PREFILL and AR abstract models because the cache vars inherit
# CACHE_BATCH_PREFILL vs CACHE_BATCH from the construction model_mode, and
# bulk_insert searches for the substring "cache_batch" in the AR-mode names.
_create_model = model_creation_utils.get_nnx_create_model_fn(config, mesh=self._mesh, model_mode=MODEL_MODE_PREFILL)
# Calling nnx.eval_shape directly (instead of create_nnx_abstract_model) avoids
# the jax.set_mesh wrap that trips Flax 0.12.6 on logical-only axes like "norm".
_create_model = model_creation_utils.get_nnx_create_model_fn(
config, mesh=self._mesh, model_mode=MODEL_MODE_PREFILL, quant_mode_str=nnx_quant_mode_str
)
_create_model_ar = model_creation_utils.get_nnx_create_model_fn(
config, mesh=self._mesh, model_mode=MODEL_MODE_AUTOREGRESSIVE
config, mesh=self._mesh, model_mode=MODEL_MODE_AUTOREGRESSIVE, quant_mode_str=nnx_quant_mode_str
)
with jax.set_mesh(self._mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
self._nnx_quant_mode_str = nnx_quant_mode_str
with nn_partitioning.axis_rules(config.logical_axis_rules):
abstract_model = nnx.eval_shape(_create_model)
abstract_model_ar = nnx.eval_shape(_create_model_ar)
self.model = abstract_model
Expand Down Expand Up @@ -370,9 +381,15 @@ def load_params(self, *args, params=None, rng: PRNGKeyType | None = None, **kwar
return params

def _load_params_nnx(self, params, rng):
"""NNX equivalent of load_params: returns an nnx.Param state and populates KV cache shardings."""
if self.model.quant is not None:
raise NotImplementedError("pure_nnx + quantization not yet supported. Use pure_nnx=False.")
"""NNX equivalent of load_params: returns an nnx.Param state and populates KV cache shardings.

Quantization handling:
* `checkpoint_is_quantized=True`: model built in `serve` mode (no full
kernel), `from_pretrained` reads `qrhs.frozen` from disk.
* `checkpoint_is_quantized=False` + `quantization=...`: model built in
`train` mode, full-precision kernel loaded; AQT layers quantize per
forward. Same output as serve mode (absmax calibration), slower.
"""

if params:
print("Resharding given NNX params")
Expand Down Expand Up @@ -401,13 +418,46 @@ def _load_params_nnx(self, params, rng):
max_logging.log("Loading NNX params via from_pretrained")
with self._mesh:
nnx_model = model_creation_utils.from_pretrained(
self.config, mesh=self._mesh, model_mode=MODEL_MODE_AUTOREGRESSIVE
self.config,
mesh=self._mesh,
model_mode=MODEL_MODE_AUTOREGRESSIVE,
quant_mode_str=self._nnx_quant_mode_str,
)
# Refresh graphdef from the concrete loaded model so subsequent merges line up.
graphdef, params_state, _, rest_state = nnx.split(nnx_model, nnx.Param, nnx.Cache, ...)
# 4-way split keeps the loaded AQT `qrhs.frozen` leaves (and any other
# non-Param/non-Cache vars) in `loaded_rest_state` so they survive into
# `_nnx_rest_state`. Param-only filtering would silently drop them and
# the model would run with random qrhs values.
_, params_state, _, loaded_rest_state = nnx.split(nnx_model, nnx.Param, nnx.Cache, ...)
# `_prefill_jit` re-merges with `self.graphdef`, which must be the PREFILL
# graphdef built in `__init__` (matching `_create_model_fn`). Don't
# overwrite with the AR-mode graphdef from `from_pretrained` — the
# PREFILL/AR attention ops have different cache variable shapes, and a
# mismatch trips the `assert prefill_kv_cache` check inside attention_op.
with nn_partitioning.axis_rules(self.config.logical_axis_rules):
concrete_model = self._create_model_fn()
graphdef, _, _, rest_state = nnx.split(concrete_model, nnx.Param, nnx.Cache, ...)
# Overlay loaded non-Param/non-Cache leaves (e.g. AQT qrhs.frozen) onto
# the PREFILL-mode rest_state. The PREFILL concrete_model already has
# placeholder qrhs vars at the right paths; we just swap in the loaded
# values. Anything only in `loaded_rest_state` (e.g. AR-only RNG slots)
# is ignored. We keep PREFILL rest_state as the base so RNG variables
# match the PREFILL graphdef's expectations.
loaded_rest_dict = loaded_rest_state.to_pure_dict()
rest_dict = rest_state.to_pure_dict()

def _overlay(dst, src):
if isinstance(dst, dict):
for k, v in dst.items():
if k in src:
dst[k] = _overlay(v, src[k])
return dst
return src if not isinstance(src, dict) else dst

rest_dict = _overlay(rest_dict, loaded_rest_dict)
nnx.replace_by_pure_dict(rest_state, rest_dict)
self.graphdef = graphdef
self._nnx_rest_state = rest_state
del nnx_model
del nnx_model, concrete_model

self.abstract_params = jax.tree.map(
lambda x: jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding)
Expand Down Expand Up @@ -495,7 +545,16 @@ def quantize_params(self, state, rng: PRNGKeyType | None = None):
if rng is None:
rng = jax.random.PRNGKey(0)
if self.config.pure_nnx:
raise NotImplementedError("pure_nnx + quantize_params not yet supported.")
# NNX takes a different code path: convert-on-load lives in `_load_params_nnx`
# via `_convert_and_quantize_nnx`, which runs the dummy forward against a
# CONVERT-mode model and transfers `qrhs.frozen` into the SERVE model.
# The standalone `quantize_params(state, rng)` API expects a Linen-shape
# `state.params` dict and isn't reachable on the NNX pathway in maxengine
# (load_params already dispatched to _load_params_nnx).
raise NotImplementedError(
"Use load_params() on NNX — the convert step runs inside _load_params_nnx via "
"_convert_and_quantize_nnx. quantize_params(state, rng) is the Linen API."
)

self.model.quant.quant_mode = quantizations.get_quant_mode("convert")

Expand Down
3 changes: 2 additions & 1 deletion src/maxtext/models/gpt3.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,14 @@
from flax import nnx

from maxtext.common.common_types import Config, DType, AxisNames, BATCH, LENGTH, EMBED, HEAD, D_KV, Array, MODEL_MODE_TRAIN
from maxtext.inference import kvcache
from maxtext.layers import initializers, nnx_wrappers
from maxtext.layers.linears import DenseGeneral, MlpBlock, canonicalize_tuple, normalize_axes
from maxtext.layers import quantizations
from maxtext.layers import linears
from maxtext.layers.attentions import AttentionOp, KVQuant
from maxtext.layers.initializers import Initializer, NdInitializer, nd_dense_init
from maxtext.layers.quantizations import AqtQuantization as Quant
from maxtext.inference import kvcache
from maxtext.utils import max_logging
from maxtext.utils import max_utils

Expand Down Expand Up @@ -258,6 +258,7 @@ def __init__(
mesh=self.mesh,
attention_kernel=self.attention_kernel,
max_target_length=self.max_target_length,
max_prefill_predict_length=self.max_prefill_predict_length,
float32_qk_product=self.float32_qk_product,
float32_logits=self.float32_logits,
quant=self.quant,
Expand Down
152 changes: 145 additions & 7 deletions src/maxtext/utils/layerwise_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@
from maxtext.utils import max_logging
from maxtext.utils import max_utils
from maxtext.utils import maxtext_utils
from maxtext.utils import maxtext_utils_nnx
from maxtext.utils import model_creation_utils
import orbax.checkpoint as ocp
from tqdm import tqdm
from maxtext.configs import pyconfig
Expand Down Expand Up @@ -164,18 +166,25 @@ def __init__(self, config: Any, rng: PRNGKeyType):
self.config = config
self.rng = rng

# TODO(ranlihao): Remove this assertion once the Layerwise quantization is supported for other decoder blocks.
assert (
config.decoder_block == common_types.DecoderBlockType.DEEPSEEK
), f"Layerwise quantization is only supported for {common_types.DecoderBlockType.DEEPSEEK}\
, but got {config.decoder_block}."
# The Linen path runs layer-by-layer (memory-efficient for big DeepSeek
# models) and is DeepSeek-specific because it relies on the per-layer
# `DeepSeek*ToLinen` wrappers. The NNX path runs whole-model convert
# forward and is model-agnostic — see `_load_and_quantize_nnx`.
if not config.pure_nnx:
assert config.decoder_block == common_types.DecoderBlockType.DEEPSEEK, (
f"Linen layerwise quantization only supports {common_types.DecoderBlockType.DEEPSEEK}, "
f"got {config.decoder_block}."
)
# Mesh definition
devices_array = maxtext_utils.create_device_mesh(config=config)
self._mesh = jax.sharding.Mesh(devices_array, config.mesh_axes)

# Input and output are both Linen-format (uses DeepSeek*ToLinen layers below).
# Route to Linen regardless of pure_nnx.
self.quant = quantizations.configure_quantization(config)
if config.pure_nnx:
# NNX takes a separate code path that builds the model via from_pretrained;
# no Linen abstract-state bookkeeping is needed here.
self.unboxed_abstract_state = None
return
model = models.transformer_as_linen(
config, mesh=self._mesh, quant=self.quant, model_mode=common_types.MODEL_MODE_TRAIN
)
Expand All @@ -187,6 +196,9 @@ def load_and_quantize(self) -> None:
"""
Load parameters layer by layer and quantize them.
"""
if self.config.pure_nnx:
self._load_and_quantize_nnx()
return
quantized_params = {}
quantized_params["params"] = {"decoder": {}}
quantized_params["aqt"] = {"decoder": {}}
Expand Down Expand Up @@ -272,6 +284,132 @@ def model_apply(_p, _rng, layer):

maxtext_utils.save_quantized_checkpoint_if_configured(self.config, quantized_params)

def _load_and_quantize_nnx(self) -> None:
"""Whole-model NNX convert: load full-precision via TRAIN-mode `from_pretrained`,
transfer kernels into a fresh CONVERT-mode model, run a forward (the
`ToNNX(AqtDotGeneral)` bridge auto-captures `qrhs.frozen`), strip kernels at
quantized paths, and save the serve-mode-shaped state.

Two-step load: input checkpoints are typically full-precision (no AQT state
on disk), so we can't `from_pretrained(quant_mode_str="convert")` directly —
orbax would fail to find the missing `qrhs.frozen` leaves. Instead we load
in TRAIN mode (which has only kernels), then copy them into a randomly
initialized CONVERT model that already has the AQT variables provisioned.
"""
config = self.config
# MODEL_MODE_TRAIN avoids the PREFILL/AUTOREGRESSIVE cache plumbing — AQT
# layers populate `qrhs.frozen` regardless of model_mode, so train mode is
# simpler and faster.
max_logging.log("Loading full-precision NNX checkpoint in TRAIN mode...")
with self._mesh:
train_model = model_creation_utils.from_pretrained(
config,
mesh=self._mesh,
model_mode=common_types.MODEL_MODE_TRAIN,
quant_mode_str="train",
)

max_logging.log("Building CONVERT-mode model (random init) and copying kernels in...")
rngs = maxtext_utils_nnx.create_nnx_rngs(config, rng_key=self.rng)
with nn_partitioning.axis_rules(config.logical_axis_rules):
convert_model = model_creation_utils.from_config(
config,
mesh=self._mesh,
rngs=rngs,
model_mode=common_types.MODEL_MODE_TRAIN,
quant_mode_str="convert",
)
self._copy_kernel_leaves_(convert_model, train_model)
del train_model

# Forward populates AqtDotGeneral_0.qrhs.frozen on every quantized layer.
L = config.max_target_length
decoder_input_tokens = jnp.zeros((1, L), dtype=jnp.int32)
decoder_positions = jnp.arange(L, dtype=jnp.int32)[None, :]
decoder_segment_ids = jnp.ones((1, L), dtype=jnp.int32)
max_logging.log("Running CONVERT-mode forward to populate AQT scale factors...")
with nn_partitioning.axis_rules(config.logical_axis_rules):
_ = convert_model(
decoder_input_tokens,
decoder_positions,
decoder_segment_ids=decoder_segment_ids,
enable_dropout=False,
model_mode=common_types.MODEL_MODE_TRAIN,
)

# Convert-mode state has both `kernel` (full precision) and `AqtDotGeneral_0.qrhs.frozen`
# at every quantized DenseGeneral; the serve-mode reader expects only the latter.
convert_state = nnx.state(convert_model).to_pure_dict()
serve_state = self._strip_kernels_at_quantized_paths(convert_state)

if config.save_quantized_params_path:
max_logging.log(f"Saving NNX-format quantized checkpoint to {config.save_quantized_params_path}")

# Wrap each leaf in `{"value": <array>}` so the on-disk shape matches what
# `from_pretrained`'s NNX-detection branch reads back (it later does
# `tree.map(lambda v: v["value"], ...)` on each leaf). Save directly via
# orbax — `save_params_to_path` would add an outer `{"params": ...}` wrap
# that the NNX path doesn't expect.
def _wrap_value(node):
if isinstance(node, dict):
return {k: _wrap_value(v) for k, v in node.items()}
return {"value": node}

wrapped = _wrap_value(serve_state)
orbax_checkpointer = ocp.PyTreeCheckpointer(
use_ocdbt=config.checkpoint_storage_use_ocdbt,
use_zarr3=config.checkpoint_storage_use_zarr3,
)
orbax_checkpointer.save(config.save_quantized_params_path, wrapped, force=True)
max_logging.log(f"Saved NNX-format quantized checkpoint at: {config.save_quantized_params_path}")
else:
max_logging.log("Skipping save: save_quantized_params_path is null.")

@staticmethod
def _copy_kernel_leaves_(dst_model, src_model):
"""Copy the full-precision parameter leaves (kernel/embedding/scale/bias)
from src into dst, leaving dst's AQT and RNG variables untouched.
"""
src_dict = nnx.state(src_model).to_pure_dict()
dst_state = nnx.state(dst_model)
dst_dict = dst_state.to_pure_dict()

def walk(d_node, s_node):
if not (isinstance(d_node, dict) and isinstance(s_node, dict)):
return
for key, d_child in d_node.items():
if key not in s_node:
continue
s_child = s_node[key]
if key in ("kernel", "embedding", "scale", "bias") and not isinstance(d_child, dict):
d_node[key] = s_child
elif isinstance(d_child, dict):
walk(d_child, s_child)

walk(dst_dict, src_dict)
nnx.replace_by_pure_dict(dst_state, dst_dict)
nnx.update(dst_model, dst_state)

@staticmethod
def _strip_kernels_at_quantized_paths(state_dict):
"""Drop `kernel` keys at any node that has a sibling `AqtDotGeneral_0`.

In convert mode each quantized DenseGeneral keeps both the full-precision
`kernel` (an nnx.Param) and the AQT-quantized `AqtDotGeneral_0.qrhs.frozen`
side-by-side. Serve mode (the on-disk shape `from_pretrained` reads back)
only carries the latter; the kernel is recreated as a dummy zero in
`linears.DenseGeneral.__call__`.
"""
if not isinstance(state_dict, dict):
return state_dict
has_aqt = "AqtDotGeneral_0" in state_dict
out = {}
for k, v in state_dict.items():
if k == "kernel" and has_aqt:
continue
out[k] = LayerwiseQuantization._strip_kernels_at_quantized_paths(v) if isinstance(v, dict) else v
return out

def _load_layer(self, layer_name):
"""Loads a specific layer's parameters from the checkpoint."""

Expand Down
13 changes: 11 additions & 2 deletions src/maxtext/utils/maxtext_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1631,8 +1631,17 @@ def get_nnx_named_sharding_with_scan_axis(abs_var_state: nnx.State, mesh) -> nnx
def _make_named_sharding(v):
val = v.get_value()
if not hasattr(val, "shape"):
# Non-tensor value (e.g., optax MaskedNode for non-trainable params). Preserve
# as-is so the treedef matches abs_var_state in the downstream jax.tree.map.
# `val` is either truly leafless (e.g. optax MaskedNode) or a composite
# pytree of tensors (e.g. AQT QTensor on serve-mode quantized variables —
# a `qvalue` int8 array + a list of `scale` bf16 arrays). For the latter
# we must emit a parallel tree of NamedSharding leaves so the downstream
# `jax.tree.map(lambda a, s: ShapeDtypeStruct(..., sharding=s), abs, names)`
# finds a real Sharding at every position. Replicated sharding is a safe
# default — AQT serve-mode QTensors are normally small (per-channel scale
# factors and packed int8 weights) and don't need axis-aware sharding.
if jax.tree_util.tree_leaves(val):
replicated = NamedSharding(mesh, PartitionSpec())
return v.replace(jax.tree.map(lambda _: replicated, val))
return v
metadata = v.get_metadata()
out_sharding = metadata.get("out_sharding") or metadata.get("sharding_names") or metadata.get("sharding")
Expand Down
Loading
Loading