Commit fdb421c

NNX: AQT in MaxEngine + serve-mode reload + gpt3 prefill fix
Builds on PR9. Migrates the NNX + AQT integration so MaxEngine can both load pre-quantized checkpoints directly and convert full-precision checkpoints to int8 on load. Also bundles a fix for a pre-existing gpt3 prefill bug surfaced by the AQT end-to-end validation.

NNX + AQT in MaxEngine:
- model_creation_utils threads quant_mode_str ("train" | "convert" | "serve") through from_config / create_model / get_nnx_create_model_fn / create_nnx_abstract_model / from_pretrained. Default "train" preserves existing callers; "serve" propagates to configure_quantization so AQT layers don't materialize the full-precision kernel when the on-disk checkpoint already carries qrhs scale factors.
- maxengine.__init__ selects the quant mode from config.checkpoint_is_quantized; _load_params_nnx drops its NotImplementedError. Two paths: pre-quantized (checkpoint_is_quantized=True) loads via quant_mode_str="serve"; full-precision + quantization=int8 loads in TRAIN mode and AQT layers quantize per-forward (same numerical result for absmax calibration).
- layerwise_quantization._load_and_quantize_nnx: whole-model NNX convert path. Loads full-precision in TRAIN mode, transfers kernels into a CONVERT-mode model, runs forward to populate qrhs.frozen via the ToNNX(AqtDotGeneral) bridge, strips kernels at quantized paths, saves serve-mode-shaped state.

Sharding helpers and from_pretrained QTensor handling (5 chained fixes for issues that kept the serve-mode reload from working):
- maxtext_utils.get_nnx_named_sharding_with_scan_axis emits a parallel tree of replicated NamedSharding leaves when a Variable's value is a composite pytree (AQT serve-mode QTensor with a qvalue int8 leaf and a list of bf16 scale leaves).
- model_creation_utils.from_pretrained: drops a redundant jax.set_mesh wrap in create_nnx_abstract_model that broke serve-mode AQT under Flax 0.12.6. _build_value_target / _free_device_memory / _unwrap_for_align use Variable.get_value() instead of v[...] indexing for QTensor leaves (QTensor.__getitem__ trips on the LogicallyPartitioned wrapper around qvalue). Widens the restore filter beyond nnx.Param to cover the aqt-typed qrhs.frozen Variable type. Skips QTensor leaves in the per-axis shape-alignment dispatch (their saved shape already matches the model). _build_value_target strips Partitioned wrappers around composite-leaf values so the restore tree path matches the on-disk layout (LogicallyPartitioned was adding an extra .value key under each QTensor leaf, which made orbax silently fill the path with zero-init values).

gpt3 prefill / autoregressive fix (pre-existing, surfaced here):
- Gpt3MultiHeadAttention.__call__ invoked attention_op(...) without ever calling update_kv_caches to build cached_values, so any non-TRAIN forward (prefill or autoregressive) tripped the `assert prefill_kv_cache` check. Mirror the standard Attention plumbing in attentions.py: __init__ constructs a KVCache_0 module when model_mode != MODEL_MODE_TRAIN, threads max_prefill_predict_length into AttentionOp; __call__ calls self.KVCache_0(...) and passes [prefill_kv_cache, ar_kv_cache] as cached_values to attention_op. TRAIN-mode shape unchanged.

Tests:
- layerwise_quantization_nnx_test (new): 3 cases for _strip_kernels_at_quantized_paths covering quantized removal, non-quantized preservation (norms, embeddings), mixed-shape trees.
- aqt_serve_roundtrip_nnx_test (new): end-to-end regression test that builds a small NNX model in CONVERT mode with int8, runs a forward to populate qrhs.frozen via the ToNNX bridge, saves the serve-mode-shape state to a tmp local orbax checkpoint, reloads via from_pretrained(quant_mode_str="serve"), and asserts every saved qrhs.frozen.qvalue array byte-matches what came back. Guards the full chain of QTensor / Partitioned / filter fixes.
- maxengine_test: replaced test_quantize_raises_for_nnx with test_quantize_passes_gate_for_nnx; added test_load_pre_quantized_nnx_passes_quant_gate and test_quantized_prefill_nnx_train_mode (real numerical verification with quantization=int8 + random params + TRAIN mode).

End-to-end on TPU (gpt3-52k): convert-mode forward + qrhs.frozen extraction + serve-mode-shape save + reload via from_pretrained(quant_mode_str="serve") + maxengine.load_params + quantized prefill forward all work; loaded qrhs.frozen.qvalue byte-matches the on-disk state.
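The last from_pretrained fix in the message comes down to tree paths: if every leaf of the restore target sits one key deeper than the saved layout, none of the target paths exist on disk. A minimal sketch with plain dicts, assuming nothing about Flax or Orbax internals (`flatten_paths`, the key names, and the string stand-ins for arrays are all illustrative):

```python
def flatten_paths(tree, prefix=()):
  """Return {path_tuple: leaf} for a nested dict."""
  if not isinstance(tree, dict):
    return {prefix: tree}
  out = {}
  for key, child in tree.items():
    out.update(flatten_paths(child, prefix + (key,)))
  return out


# Hypothetical on-disk layout for one quantized layer (strings stand in for arrays).
on_disk = {"qrhs": {"frozen": {"qvalue": "int8[...]", "scale": "bf16[...]"}}}

# Restore target where a wrapper contributes an extra "value" key under each leaf.
wrapped_target = {"qrhs": {"frozen": {"qvalue": {"value": None}, "scale": {"value": None}}}}

# Restore target after the wrapper has been stripped.
stripped_target = {"qrhs": {"frozen": {"qvalue": None, "scale": None}}}

disk_paths = set(flatten_paths(on_disk))
missing_with_wrapper = set(flatten_paths(wrapped_target)) - disk_paths
missing_after_strip = set(flatten_paths(stripped_target)) - disk_paths

print(sorted(missing_with_wrapper))  # target paths with no counterpart on disk
print(sorted(missing_after_strip))   # empty once the paths line up
```

In this toy setup both wrapped paths are missing from the saved tree, which is the kind of mismatch the message says was being filled with zero-init values instead of raising.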
1 parent edf5d3f commit fdb421c

8 files changed

Lines changed: 641 additions & 58 deletions

src/maxtext/inference/maxengine/maxengine.py

Lines changed: 70 additions & 11 deletions
@@ -117,14 +117,25 @@ def __init__(self, config: Any, devices: Any | None = None):
     # Model and Optimizer definition.
     quant = quantizations.configure_quantization(config)
     if config.pure_nnx:
+      # `serve` only when the on-disk checkpoint already carries `qrhs.frozen`
+      # (no full-precision kernel). For `checkpoint_is_quantized=False` with
+      # quant enabled we stay in `train` mode and let AQT quantize per-forward
+      # against the full-precision kernel — same numerical result as `serve`
+      # for absmax calibration, just slower.
+      nnx_quant_mode_str = "serve" if (quant is not None and config.checkpoint_is_quantized) else "train"
       # We need both PREFILL and AR abstract models because the cache vars inherit
       # CACHE_BATCH_PREFILL vs CACHE_BATCH from the construction model_mode, and
       # bulk_insert searches for the substring "cache_batch" in the AR-mode names.
-      _create_model = model_creation_utils.get_nnx_create_model_fn(config, mesh=self._mesh, model_mode=MODEL_MODE_PREFILL)
+      # Calling nnx.eval_shape directly (instead of create_nnx_abstract_model) avoids
+      # the jax.set_mesh wrap that trips Flax 0.12.6 on logical-only axes like "norm".
+      _create_model = model_creation_utils.get_nnx_create_model_fn(
+          config, mesh=self._mesh, model_mode=MODEL_MODE_PREFILL, quant_mode_str=nnx_quant_mode_str
+      )
       _create_model_ar = model_creation_utils.get_nnx_create_model_fn(
-          config, mesh=self._mesh, model_mode=MODEL_MODE_AUTOREGRESSIVE
+          config, mesh=self._mesh, model_mode=MODEL_MODE_AUTOREGRESSIVE, quant_mode_str=nnx_quant_mode_str
       )
-      with jax.set_mesh(self._mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
+      self._nnx_quant_mode_str = nnx_quant_mode_str
+      with nn_partitioning.axis_rules(config.logical_axis_rules):
         abstract_model = nnx.eval_shape(_create_model)
         abstract_model_ar = nnx.eval_shape(_create_model_ar)
       self.model = abstract_model
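The mode selection in the hunk above is a two-input decision, and the case where a quantized checkpoint is flagged but quantization is not configured is easy to overlook. A small sketch on plain booleans (the function name `select_nnx_quant_mode` is hypothetical; only the conditional expression comes from the diff):

```python
def select_nnx_quant_mode(quant_enabled: bool, checkpoint_is_quantized: bool) -> str:
  """Same conditional as the diff, with `quant is not None` reduced to a bool."""
  return "serve" if (quant_enabled and checkpoint_is_quantized) else "train"


# Enumerate all four combinations.
for quant_enabled in (False, True):
  for ckpt_quantized in (False, True):
    mode = select_nnx_quant_mode(quant_enabled, ckpt_quantized)
    print(f"quant={quant_enabled!s:5} checkpoint_is_quantized={ckpt_quantized!s:5} -> {mode}")
```

Only the combination with both inputs true selects "serve"; every other combination, including a quantized checkpoint with quantization disabled, falls back to "train".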
@@ -370,9 +381,15 @@ def load_params(self, *args, params=None, rng: PRNGKeyType | None = None, **kwar
     return params

   def _load_params_nnx(self, params, rng):
-    """NNX equivalent of load_params: returns an nnx.Param state and populates KV cache shardings."""
-    if self.model.quant is not None:
-      raise NotImplementedError("pure_nnx + quantization not yet supported. Use pure_nnx=False.")
+    """NNX equivalent of load_params: returns an nnx.Param state and populates KV cache shardings.
+
+    Quantization handling:
+      * `checkpoint_is_quantized=True`: model built in `serve` mode (no full
+        kernel), `from_pretrained` reads `qrhs.frozen` from disk.
+      * `checkpoint_is_quantized=False` + `quantization=...`: model built in
+        `train` mode, full-precision kernel loaded; AQT layers quantize per
+        forward. Same output as serve mode (absmax calibration), slower.
+    """

     if params:
       print("Resharding given NNX params")
@@ -401,13 +418,46 @@ def _load_params_nnx(self, params, rng):
       max_logging.log("Loading NNX params via from_pretrained")
       with self._mesh:
         nnx_model = model_creation_utils.from_pretrained(
-            self.config, mesh=self._mesh, model_mode=MODEL_MODE_AUTOREGRESSIVE
+            self.config,
+            mesh=self._mesh,
+            model_mode=MODEL_MODE_AUTOREGRESSIVE,
+            quant_mode_str=self._nnx_quant_mode_str,
         )
-      # Refresh graphdef from the concrete loaded model so subsequent merges line up.
-      graphdef, params_state, _, rest_state = nnx.split(nnx_model, nnx.Param, nnx.Cache, ...)
+      # 4-way split keeps the loaded AQT `qrhs.frozen` leaves (and any other
+      # non-Param/non-Cache vars) in `loaded_rest_state` so they survive into
+      # `_nnx_rest_state`. Param-only filtering would silently drop them and
+      # the model would run with random qrhs values.
+      _, params_state, _, loaded_rest_state = nnx.split(nnx_model, nnx.Param, nnx.Cache, ...)
+      # `_prefill_jit` re-merges with `self.graphdef`, which must be the PREFILL
+      # graphdef built in `__init__` (matching `_create_model_fn`). Don't
+      # overwrite with the AR-mode graphdef from `from_pretrained` — the
+      # PREFILL/AR attention ops have different cache variable shapes, and a
+      # mismatch trips the `assert prefill_kv_cache` check inside attention_op.
+      with nn_partitioning.axis_rules(self.config.logical_axis_rules):
+        concrete_model = self._create_model_fn()
+      graphdef, _, _, rest_state = nnx.split(concrete_model, nnx.Param, nnx.Cache, ...)
+      # Overlay loaded non-Param/non-Cache leaves (e.g. AQT qrhs.frozen) onto
+      # the PREFILL-mode rest_state. The PREFILL concrete_model already has
+      # placeholder qrhs vars at the right paths; we just swap in the loaded
+      # values. Anything only in `loaded_rest_state` (e.g. AR-only RNG slots)
+      # is ignored. We keep PREFILL rest_state as the base so RNG variables
+      # match the PREFILL graphdef's expectations.
+      loaded_rest_dict = loaded_rest_state.to_pure_dict()
+      rest_dict = rest_state.to_pure_dict()
+
+      def _overlay(dst, src):
+        if isinstance(dst, dict):
+          for k, v in dst.items():
+            if k in src:
+              dst[k] = _overlay(v, src[k])
+          return dst
+        return src if not isinstance(src, dict) else dst
+
+      rest_dict = _overlay(rest_dict, loaded_rest_dict)
+      nnx.replace_by_pure_dict(rest_state, rest_dict)
       self.graphdef = graphdef
       self._nnx_rest_state = rest_state
-      del nnx_model
+      del nnx_model, concrete_model

     self.abstract_params = jax.tree.map(
         lambda x: jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding)
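To see what the overlay in this hunk keeps and what it ignores, the same recursion can be run on plain nested dicts. This is only a sketch: the state contents below are hypothetical stand-ins (strings and ints instead of arrays and RNG keys), and it assumes both trees have dicts at the same levels.

```python
def overlay(dst, src):
  """Same recursion as `_overlay` in the diff: dst keeps its keys, src supplies leaf values."""
  if isinstance(dst, dict):
    for k, v in dst.items():
      if k in src:
        dst[k] = overlay(v, src[k])
    return dst
  return src if not isinstance(src, dict) else dst


# Hypothetical PREFILL-side state: placeholder qrhs value plus a PREFILL RNG slot.
prefill_rest = {
    "layer_0": {"AqtDotGeneral_0": {"qrhs": {"frozen": "placeholder"}}},
    "rngs": {"prefill_key": 0},
}
# Hypothetical loaded (AR-mode) state: real qrhs value plus an AR-only RNG slot.
loaded_rest = {
    "layer_0": {"AqtDotGeneral_0": {"qrhs": {"frozen": "loaded"}}},
    "rngs": {"ar_key": 1},
}

merged = overlay(prefill_rest, loaded_rest)
print(merged)
```

The loaded `frozen` value replaces the placeholder, while `ar_key` never appears in the result because only keys already present in the PREFILL tree are visited.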
@@ -495,7 +545,16 @@ def quantize_params(self, state, rng: PRNGKeyType | None = None):
     if rng is None:
       rng = jax.random.PRNGKey(0)
     if self.config.pure_nnx:
-      raise NotImplementedError("pure_nnx + quantize_params not yet supported.")
+      # NNX takes a different code path: convert-on-load lives in `_load_params_nnx`
+      # via `_convert_and_quantize_nnx`, which runs the dummy forward against a
+      # CONVERT-mode model and transfers `qrhs.frozen` into the SERVE model.
+      # The standalone `quantize_params(state, rng)` API expects a Linen-shape
+      # `state.params` dict and isn't reachable on the NNX pathway in maxengine
+      # (load_params already dispatched to _load_params_nnx).
+      raise NotImplementedError(
+          "Use load_params() on NNX — the convert step runs inside _load_params_nnx via "
+          "_convert_and_quantize_nnx. quantize_params(state, rng) is the Linen API."
+      )

     self.model.quant.quant_mode = quantizations.get_quant_mode("convert")

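The comments and docstring in this file claim that quantizing per forward from the full-precision kernel gives the same output as serving from frozen values under absmax calibration. The reason is that an absmax scale is a deterministic function of the kernel alone. A pure-Python sketch of that argument, assuming a simplified symmetric per-output-channel int8 scheme (this is not AQT's implementation, and all names are illustrative):

```python
def absmax_quantize(kernel):
  """Symmetric int8 quantization of an [in, out] kernel, one scale per output column."""
  n_out = len(kernel[0])
  scales = []
  for j in range(n_out):
    col_max = max(abs(row[j]) for row in kernel)
    scales.append(col_max / 127.0 if col_max > 0 else 1.0)
  qvalue = [[round(row[j] / scales[j]) for j in range(n_out)] for row in kernel]
  return qvalue, scales


def matmul_dequant(x, qvalue, scales):
  """Compute x @ (qvalue * scale) for a flat input vector x."""
  return [
      sum(x[i] * qvalue[i][j] for i in range(len(x))) * scales[j]
      for j in range(len(scales))
  ]


kernel = [[0.5, -1.0], [0.25, 2.0], [-0.75, 0.1]]
x = [1.0, 2.0, 3.0]

# "serve"-like path: quantize once, keep only qvalue and scales.
frozen_q, frozen_s = absmax_quantize(kernel)
serve_out = matmul_dequant(x, frozen_q, frozen_s)


# "train"-like path with quantization on: re-derive qvalue and scales on every call.
def per_forward(inputs):
  q, s = absmax_quantize(kernel)
  return matmul_dequant(inputs, q, s)


print(serve_out == per_forward(x))
```

Both paths agree by construction here; the per-forward path simply repeats the quantization work on each call, which matches the "just slower" remark in the diff.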
src/maxtext/models/gpt3.py

Lines changed: 2 additions & 1 deletion
@@ -28,14 +28,14 @@
 from flax import nnx

 from maxtext.common.common_types import Config, DType, AxisNames, BATCH, LENGTH, EMBED, HEAD, D_KV, Array, MODEL_MODE_TRAIN
+from maxtext.inference import kvcache
 from maxtext.layers import initializers, nnx_wrappers
 from maxtext.layers.linears import DenseGeneral, MlpBlock, canonicalize_tuple, normalize_axes
 from maxtext.layers import quantizations
 from maxtext.layers import linears
 from maxtext.layers.attentions import AttentionOp, KVQuant
 from maxtext.layers.initializers import Initializer, NdInitializer, nd_dense_init
 from maxtext.layers.quantizations import AqtQuantization as Quant
-from maxtext.inference import kvcache
 from maxtext.utils import max_logging
 from maxtext.utils import max_utils

@@ -258,6 +258,7 @@ def __init__(
         mesh=self.mesh,
         attention_kernel=self.attention_kernel,
         max_target_length=self.max_target_length,
+        max_prefill_predict_length=self.max_prefill_predict_length,
         float32_qk_product=self.float32_qk_product,
         float32_logits=self.float32_logits,
         quant=self.quant,

src/maxtext/utils/layerwise_quantization.py

Lines changed: 145 additions & 7 deletions
@@ -47,6 +47,8 @@
 from maxtext.utils import max_logging
 from maxtext.utils import max_utils
 from maxtext.utils import maxtext_utils
+from maxtext.utils import maxtext_utils_nnx
+from maxtext.utils import model_creation_utils
 import orbax.checkpoint as ocp
 from tqdm import tqdm
 from maxtext.configs import pyconfig
@@ -164,18 +166,25 @@ def __init__(self, config: Any, rng: PRNGKeyType):
     self.config = config
     self.rng = rng

-    # TODO(ranlihao): Remove this assertion once the Layerwise quantization is supported for other decoder blocks.
-    assert (
-        config.decoder_block == common_types.DecoderBlockType.DEEPSEEK
-    ), f"Layerwise quantization is only supported for {common_types.DecoderBlockType.DEEPSEEK}\
-, but got {config.decoder_block}."
+    # The Linen path runs layer-by-layer (memory-efficient for big DeepSeek
+    # models) and is DeepSeek-specific because it relies on the per-layer
+    # `DeepSeek*ToLinen` wrappers. The NNX path runs whole-model convert
+    # forward and is model-agnostic — see `_load_and_quantize_nnx`.
+    if not config.pure_nnx:
+      assert config.decoder_block == common_types.DecoderBlockType.DEEPSEEK, (
+          f"Linen layerwise quantization only supports {common_types.DecoderBlockType.DEEPSEEK}, "
+          f"got {config.decoder_block}."
+      )
     # Mesh definition
     devices_array = maxtext_utils.create_device_mesh(config=config)
     self._mesh = jax.sharding.Mesh(devices_array, config.mesh_axes)

-    # Input and output are both Linen-format (uses DeepSeek*ToLinen layers below).
-    # Route to Linen regardless of pure_nnx.
     self.quant = quantizations.configure_quantization(config)
+    if config.pure_nnx:
+      # NNX takes a separate code path that builds the model via from_pretrained;
+      # no Linen abstract-state bookkeeping is needed here.
+      self.unboxed_abstract_state = None
+      return
     model = models.transformer_as_linen(
         config, mesh=self._mesh, quant=self.quant, model_mode=common_types.MODEL_MODE_TRAIN
     )
@@ -187,6 +196,9 @@ def load_and_quantize(self) -> None:
     """
     Load parameters layer by layer and quantize them.
     """
+    if self.config.pure_nnx:
+      self._load_and_quantize_nnx()
+      return
     quantized_params = {}
     quantized_params["params"] = {"decoder": {}}
     quantized_params["aqt"] = {"decoder": {}}
@@ -272,6 +284,132 @@ def model_apply(_p, _rng, layer):

     maxtext_utils.save_quantized_checkpoint_if_configured(self.config, quantized_params)

+  def _load_and_quantize_nnx(self) -> None:
+    """Whole-model NNX convert: load full-precision via TRAIN-mode `from_pretrained`,
+    transfer kernels into a fresh CONVERT-mode model, run a forward (the
+    `ToNNX(AqtDotGeneral)` bridge auto-captures `qrhs.frozen`), strip kernels at
+    quantized paths, and save the serve-mode-shaped state.
+
+    Two-step load: input checkpoints are typically full-precision (no AQT state
+    on disk), so we can't `from_pretrained(quant_mode_str="convert")` directly —
+    orbax would fail to find the missing `qrhs.frozen` leaves. Instead we load
+    in TRAIN mode (which has only kernels), then copy them into a randomly
+    initialized CONVERT model that already has the AQT variables provisioned.
+    """
+    config = self.config
+    # MODEL_MODE_TRAIN avoids the PREFILL/AUTOREGRESSIVE cache plumbing — AQT
+    # layers populate `qrhs.frozen` regardless of model_mode, so train mode is
+    # simpler and faster.
+    max_logging.log("Loading full-precision NNX checkpoint in TRAIN mode...")
+    with self._mesh:
+      train_model = model_creation_utils.from_pretrained(
+          config,
+          mesh=self._mesh,
+          model_mode=common_types.MODEL_MODE_TRAIN,
+          quant_mode_str="train",
+      )
+
+    max_logging.log("Building CONVERT-mode model (random init) and copying kernels in...")
+    rngs = maxtext_utils_nnx.create_nnx_rngs(config, rng_key=self.rng)
+    with nn_partitioning.axis_rules(config.logical_axis_rules):
+      convert_model = model_creation_utils.from_config(
+          config,
+          mesh=self._mesh,
+          rngs=rngs,
+          model_mode=common_types.MODEL_MODE_TRAIN,
+          quant_mode_str="convert",
+      )
+    self._copy_kernel_leaves_(convert_model, train_model)
+    del train_model
+
+    # Forward populates AqtDotGeneral_0.qrhs.frozen on every quantized layer.
+    L = config.max_target_length
+    decoder_input_tokens = jnp.zeros((1, L), dtype=jnp.int32)
+    decoder_positions = jnp.arange(L, dtype=jnp.int32)[None, :]
+    decoder_segment_ids = jnp.ones((1, L), dtype=jnp.int32)
+    max_logging.log("Running CONVERT-mode forward to populate AQT scale factors...")
+    with nn_partitioning.axis_rules(config.logical_axis_rules):
+      _ = convert_model(
+          decoder_input_tokens,
+          decoder_positions,
+          decoder_segment_ids=decoder_segment_ids,
+          enable_dropout=False,
+          model_mode=common_types.MODEL_MODE_TRAIN,
+      )
+
+    # Convert-mode state has both `kernel` (full precision) and `AqtDotGeneral_0.qrhs.frozen`
+    # at every quantized DenseGeneral; the serve-mode reader expects only the latter.
+    convert_state = nnx.state(convert_model).to_pure_dict()
+    serve_state = self._strip_kernels_at_quantized_paths(convert_state)
+
+    if config.save_quantized_params_path:
+      max_logging.log(f"Saving NNX-format quantized checkpoint to {config.save_quantized_params_path}")
+
+      # Wrap each leaf in `{"value": <array>}` so the on-disk shape matches what
+      # `from_pretrained`'s NNX-detection branch reads back (it later does
+      # `tree.map(lambda v: v["value"], ...)` on each leaf). Save directly via
+      # orbax — `save_params_to_path` would add an outer `{"params": ...}` wrap
+      # that the NNX path doesn't expect.
+      def _wrap_value(node):
+        if isinstance(node, dict):
+          return {k: _wrap_value(v) for k, v in node.items()}
+        return {"value": node}
+
+      wrapped = _wrap_value(serve_state)
+      orbax_checkpointer = ocp.PyTreeCheckpointer(
+          use_ocdbt=config.checkpoint_storage_use_ocdbt,
+          use_zarr3=config.checkpoint_storage_use_zarr3,
+      )
+      orbax_checkpointer.save(config.save_quantized_params_path, wrapped, force=True)
+      max_logging.log(f"Saved NNX-format quantized checkpoint at: {config.save_quantized_params_path}")
+    else:
+      max_logging.log("Skipping save: save_quantized_params_path is null.")
+
+  @staticmethod
+  def _copy_kernel_leaves_(dst_model, src_model):
+    """Copy the full-precision parameter leaves (kernel/embedding/scale/bias)
+    from src into dst, leaving dst's AQT and RNG variables untouched.
+    """
+    src_dict = nnx.state(src_model).to_pure_dict()
+    dst_state = nnx.state(dst_model)
+    dst_dict = dst_state.to_pure_dict()
+
+    def walk(d_node, s_node):
+      if not (isinstance(d_node, dict) and isinstance(s_node, dict)):
+        return
+      for key, d_child in d_node.items():
+        if key not in s_node:
+          continue
+        s_child = s_node[key]
+        if key in ("kernel", "embedding", "scale", "bias") and not isinstance(d_child, dict):
+          d_node[key] = s_child
+        elif isinstance(d_child, dict):
+          walk(d_child, s_child)
+
+    walk(dst_dict, src_dict)
+    nnx.replace_by_pure_dict(dst_state, dst_dict)
+    nnx.update(dst_model, dst_state)
+
+  @staticmethod
+  def _strip_kernels_at_quantized_paths(state_dict):
+    """Drop `kernel` keys at any node that has a sibling `AqtDotGeneral_0`.
+
+    In convert mode each quantized DenseGeneral keeps both the full-precision
+    `kernel` (an nnx.Param) and the AQT-quantized `AqtDotGeneral_0.qrhs.frozen`
+    side-by-side. Serve mode (the on-disk shape `from_pretrained` reads back)
+    only carries the latter; the kernel is recreated as a dummy zero in
+    `linears.DenseGeneral.__call__`.
+    """
+    if not isinstance(state_dict, dict):
+      return state_dict
+    has_aqt = "AqtDotGeneral_0" in state_dict
+    out = {}
+    for k, v in state_dict.items():
+      if k == "kernel" and has_aqt:
+        continue
+      out[k] = LayerwiseQuantization._strip_kernels_at_quantized_paths(v) if isinstance(v, dict) else v
+    return out
+
   def _load_layer(self, layer_name):
     """Loads a specific layer's parameters from the checkpoint."""

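The three dict-level steps of the NNX convert path (copy parameter leaves, strip kernels next to `AqtDotGeneral_0`, wrap leaves for saving) can be traced end to end on plain dicts. This is a sketch under loud assumptions: strings stand in for arrays, the layer names are hypothetical, and the convert-mode forward is faked by assigning a value by hand.

```python
PARAM_KEYS = ("kernel", "embedding", "scale", "bias")


def copy_param_leaves(dst, src):
  """In-place copy of parameter leaves from src into dst (same walk as the diff)."""
  if not (isinstance(dst, dict) and isinstance(src, dict)):
    return
  for key, d_child in dst.items():
    if key not in src:
      continue
    if key in PARAM_KEYS and not isinstance(d_child, dict):
      dst[key] = src[key]
    elif isinstance(d_child, dict):
      copy_param_leaves(d_child, src[key])


def strip_kernels(node):
  """Drop `kernel` wherever a sibling `AqtDotGeneral_0` exists."""
  if not isinstance(node, dict):
    return node
  has_aqt = "AqtDotGeneral_0" in node
  return {k: strip_kernels(v) for k, v in node.items() if not (k == "kernel" and has_aqt)}


def wrap_value(node):
  """Wrap every leaf as {"value": leaf}, mirroring `_wrap_value` in the diff."""
  if isinstance(node, dict):
    return {k: wrap_value(v) for k, v in node.items()}
  return {"value": node}


# Hypothetical states; strings stand in for arrays.
train_state = {"mlp": {"wi": {"kernel": "W_fp"}}, "norm": {"scale": "g_fp"}}
convert_state = {
    "mlp": {"wi": {"kernel": "W_rand", "AqtDotGeneral_0": {"qrhs": {"frozen": "unset"}}}},
    "norm": {"scale": "g_rand"},
}

copy_param_leaves(convert_state, train_state)
# A real convert-mode forward would populate qrhs.frozen; here it is set by hand.
convert_state["mlp"]["wi"]["AqtDotGeneral_0"]["qrhs"]["frozen"] = "W_int8"

serve_state = strip_kernels(convert_state)
on_disk = wrap_value(serve_state)
print(on_disk)
```

In this toy tree the AQT subtree is never touched by the copy because the train-mode state has no `AqtDotGeneral_0` key, and the norm `scale` survives the strip because it has no AQT sibling.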
src/maxtext/utils/maxtext_utils.py

Lines changed: 11 additions & 2 deletions
@@ -1631,8 +1631,17 @@ def get_nnx_named_sharding_with_scan_axis(abs_var_state: nnx.State, mesh) -> nnx
   def _make_named_sharding(v):
     val = v.get_value()
     if not hasattr(val, "shape"):
-      # Non-tensor value (e.g., optax MaskedNode for non-trainable params). Preserve
-      # as-is so the treedef matches abs_var_state in the downstream jax.tree.map.
+      # `val` is either truly leafless (e.g. optax MaskedNode) or a composite
+      # pytree of tensors (e.g. AQT QTensor on serve-mode quantized variables —
+      # a `qvalue` int8 array + a list of `scale` bf16 arrays). For the latter
+      # we must emit a parallel tree of NamedSharding leaves so the downstream
+      # `jax.tree.map(lambda a, s: ShapeDtypeStruct(..., sharding=s), abs, names)`
+      # finds a real Sharding at every position. Replicated sharding is a safe
+      # default — AQT serve-mode QTensors are normally small (per-channel scale
+      # factors and packed int8 weights) and don't need axis-aware sharding.
+      if jax.tree_util.tree_leaves(val):
+        replicated = NamedSharding(mesh, PartitionSpec())
+        return v.replace(jax.tree.map(lambda _: replicated, val))
       return v
     metadata = v.get_metadata()
     out_sharding = metadata.get("out_sharding") or metadata.get("sharding_names") or metadata.get("sharding")

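The requirement behind this hunk is structural: a two-tree map needs a sharding at every leaf position of the composite value. The sketch below illustrates that with a tiny pure-Python stand-in for a multi-tree map, not JAX itself; `tree_map`, the placeholder string, and the QTensor-like dict are all assumptions made for illustration.

```python
def tree_map(fn, tree, *rest):
  """Tiny stand-in for a multi-tree map over nested dicts and lists."""
  if isinstance(tree, dict):
    return {k: tree_map(fn, v, *(r[k] for r in rest)) for k, v in tree.items()}
  if isinstance(tree, list):
    return [tree_map(fn, v, *(r[i] for r in rest)) for i, v in enumerate(tree)]
  return fn(tree, *rest)


# Placeholder string, not a real sharding object.
REPLICATED = "NamedSharding(mesh, PartitionSpec())"

# Hypothetical composite value: one int8 qvalue plus a list of scales,
# each leaf described as a (dtype, shape) tuple.
qtensor_like = {"qvalue": ("int8", (4, 8)), "scale": [("bf16", (1, 8))]}

# Parallel tree: same structure, every leaf replaced by the replicated placeholder.
shardings = tree_map(lambda _leaf: REPLICATED, qtensor_like)

# A downstream two-tree map now finds a sharding at every leaf position.
specs = tree_map(
    lambda leaf, s: {"dtype": leaf[0], "shape": leaf[1], "sharding": s},
    qtensor_like,
    shardings,
)
print(specs)
```

Returning the composite value unchanged would instead hand the downstream map array leaves where it expects shardings, which is the situation the new branch in the diff avoids.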