Commit fdb421c
committed
NNX: AQT in MaxEngine + serve-mode reload + gpt3 prefill fix
Builds on PR9. Migrates the NNX + AQT integration so MaxEngine can both
load pre-quantized checkpoints directly and convert full-precision
checkpoints to int8 on load. Also bundles a pre-existing gpt3 prefill
bug surfaced by the AQT end-to-end validation.
NNX + AQT in MaxEngine:
- model_creation_utils threads quant_mode_str ("train" | "convert" |
"serve") through from_config / create_model /
get_nnx_create_model_fn / create_nnx_abstract_model /
from_pretrained. Default "train" preserves existing callers; "serve"
propagates to configure_quantization so AQT layers don't materialize
the full-precision kernel when the on-disk checkpoint already
carries qrhs scale factors.
- maxengine.__init__ selects the quant mode from
config.checkpoint_is_quantized; _load_params_nnx drops its
NotImplementedError. Two paths: pre-quantized
(checkpoint_is_quantized=True) loads via quant_mode_str="serve";
full-precision + quantization=int8 loads in TRAIN mode and AQT
layers quantize per-forward (same numerical result for absmax
calibration).
- layerwise_quantization._load_and_quantize_nnx: whole-model NNX
convert path. Loads full-precision in TRAIN mode, transfers kernels
into a CONVERT-mode model, runs forward to populate qrhs.frozen via
the ToNNX(AqtDotGeneral) bridge, strips kernels at quantized paths,
saves serve-mode-shaped state.
Sharding helpers and from_pretrained QTensor handling (5 chained fixes
that kept the serve-mode reload from working):
- maxtext_utils.get_nnx_named_sharding_with_scan_axis emits a
parallel-tree of replicated NamedSharding leaves when a Variable's
value is a composite pytree (AQT serve-mode QTensor with a qvalue
int8 leaf and a list of bf16 scale leaves).
- model_creation_utils.from_pretrained: drops a redundant
jax.set_mesh wrap in create_nnx_abstract_model that broke serve-mode
AQT under Flax 0.12.6. _build_value_target / _free_device_memory /
_unwrap_for_align use Variable.get_value() instead of v[...]
indexing for QTensor leaves (QTensor.__getitem__ trips on the
LogicallyPartitioned wrapper around qvalue). Widens the restore
filter beyond nnx.Param to cover the aqt-typed qrhs.frozen Variable
type. Skips QTensor leaves in the per-axis shape-alignment dispatch
(their saved shape already matches the model). _build_value_target
strips Partitioned wrappers around composite-leaf values so the
restore tree path matches the on-disk layout (LogicallyPartitioned
was adding an extra .value key under each QTensor leaf, which made
orbax silently fill the path with zero-init values).
gpt3 prefill / autoregressive fix (pre-existing, surfaced here):
- Gpt3MultiHeadAttention.__call__ invoked attention_op(...) without
ever calling update_kv_caches to build cached_values, so any
non-TRAIN forward (prefill or autoregressive) tripped the
`assert prefill_kv_cache` check. Mirror the standard Attention
plumbing in attentions.py: __init__ constructs a KVCache_0 module
when model_mode != MODEL_MODE_TRAIN, threads
max_prefill_predict_length into AttentionOp; __call__ calls
self.KVCache_0(...) and passes [prefill_kv_cache, ar_kv_cache] as
cached_values to attention_op. TRAIN-mode shape unchanged.
Tests:
- layerwise_quantization_nnx_test (new): 3 cases for
_strip_kernels_at_quantized_paths covering quantized removal,
non-quantized preservation (norms, embeddings), mixed-shape trees.
- aqt_serve_roundtrip_nnx_test (new): end-to-end regression test that
builds a small NNX model in CONVERT mode with int8, runs a forward
to populate qrhs.frozen via the ToNNX bridge, saves the
serve-mode-shape state to a tmp local orbax checkpoint, reloads via
from_pretrained(quant_mode_str="serve"), and asserts every saved
qrhs.frozen.qvalue array byte-matches what came back. Guards the
full chain of QTensor / Partitioned / filter fixes.
- maxengine_test: replaced test_quantize_raises_for_nnx with
test_quantize_passes_gate_for_nnx; added
test_load_pre_quantized_nnx_passes_quant_gate and
test_quantized_prefill_nnx_train_mode (real numerical verification
with quantization=int8 + random params + TRAIN mode).
End-to-end on TPU (gpt3-52k): convert-mode forward + qrhs.frozen
extraction + serve-mode-shape save + reload via
from_pretrained(quant_mode_str="serve") + maxengine.load_params +
quantized prefill forward all work; loaded qrhs.frozen.qvalue
byte-matches the on-disk state.1 parent edf5d3f commit fdb421c
8 files changed
Lines changed: 641 additions & 58 deletions
File tree
- src/maxtext
- inference/maxengine
- models
- utils
- tests
- integration
- unit
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
117 | 117 | | |
118 | 118 | | |
119 | 119 | | |
| 120 | + | |
| 121 | + | |
| 122 | + | |
| 123 | + | |
| 124 | + | |
| 125 | + | |
120 | 126 | | |
121 | 127 | | |
122 | 128 | | |
123 | | - | |
| 129 | + | |
| 130 | + | |
| 131 | + | |
| 132 | + | |
| 133 | + | |
124 | 134 | | |
125 | | - | |
| 135 | + | |
126 | 136 | | |
127 | | - | |
| 137 | + | |
| 138 | + | |
128 | 139 | | |
129 | 140 | | |
130 | 141 | | |
| |||
370 | 381 | | |
371 | 382 | | |
372 | 383 | | |
373 | | - | |
374 | | - | |
375 | | - | |
| 384 | + | |
| 385 | + | |
| 386 | + | |
| 387 | + | |
| 388 | + | |
| 389 | + | |
| 390 | + | |
| 391 | + | |
| 392 | + | |
376 | 393 | | |
377 | 394 | | |
378 | 395 | | |
| |||
401 | 418 | | |
402 | 419 | | |
403 | 420 | | |
404 | | - | |
| 421 | + | |
| 422 | + | |
| 423 | + | |
| 424 | + | |
405 | 425 | | |
406 | | - | |
407 | | - | |
| 426 | + | |
| 427 | + | |
| 428 | + | |
| 429 | + | |
| 430 | + | |
| 431 | + | |
| 432 | + | |
| 433 | + | |
| 434 | + | |
| 435 | + | |
| 436 | + | |
| 437 | + | |
| 438 | + | |
| 439 | + | |
| 440 | + | |
| 441 | + | |
| 442 | + | |
| 443 | + | |
| 444 | + | |
| 445 | + | |
| 446 | + | |
| 447 | + | |
| 448 | + | |
| 449 | + | |
| 450 | + | |
| 451 | + | |
| 452 | + | |
| 453 | + | |
| 454 | + | |
| 455 | + | |
| 456 | + | |
| 457 | + | |
408 | 458 | | |
409 | 459 | | |
410 | | - | |
| 460 | + | |
411 | 461 | | |
412 | 462 | | |
413 | 463 | | |
| |||
495 | 545 | | |
496 | 546 | | |
497 | 547 | | |
498 | | - | |
| 548 | + | |
| 549 | + | |
| 550 | + | |
| 551 | + | |
| 552 | + | |
| 553 | + | |
| 554 | + | |
| 555 | + | |
| 556 | + | |
| 557 | + | |
499 | 558 | | |
500 | 559 | | |
501 | 560 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
28 | 28 | | |
29 | 29 | | |
30 | 30 | | |
| 31 | + | |
31 | 32 | | |
32 | 33 | | |
33 | 34 | | |
34 | 35 | | |
35 | 36 | | |
36 | 37 | | |
37 | 38 | | |
38 | | - | |
39 | 39 | | |
40 | 40 | | |
41 | 41 | | |
| |||
258 | 258 | | |
259 | 259 | | |
260 | 260 | | |
| 261 | + | |
261 | 262 | | |
262 | 263 | | |
263 | 264 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
47 | 47 | | |
48 | 48 | | |
49 | 49 | | |
| 50 | + | |
| 51 | + | |
50 | 52 | | |
51 | 53 | | |
52 | 54 | | |
| |||
164 | 166 | | |
165 | 167 | | |
166 | 168 | | |
167 | | - | |
168 | | - | |
169 | | - | |
170 | | - | |
171 | | - | |
| 169 | + | |
| 170 | + | |
| 171 | + | |
| 172 | + | |
| 173 | + | |
| 174 | + | |
| 175 | + | |
| 176 | + | |
| 177 | + | |
172 | 178 | | |
173 | 179 | | |
174 | 180 | | |
175 | 181 | | |
176 | | - | |
177 | | - | |
178 | 182 | | |
| 183 | + | |
| 184 | + | |
| 185 | + | |
| 186 | + | |
| 187 | + | |
179 | 188 | | |
180 | 189 | | |
181 | 190 | | |
| |||
187 | 196 | | |
188 | 197 | | |
189 | 198 | | |
| 199 | + | |
| 200 | + | |
| 201 | + | |
190 | 202 | | |
191 | 203 | | |
192 | 204 | | |
| |||
272 | 284 | | |
273 | 285 | | |
274 | 286 | | |
| 287 | + | |
| 288 | + | |
| 289 | + | |
| 290 | + | |
| 291 | + | |
| 292 | + | |
| 293 | + | |
| 294 | + | |
| 295 | + | |
| 296 | + | |
| 297 | + | |
| 298 | + | |
| 299 | + | |
| 300 | + | |
| 301 | + | |
| 302 | + | |
| 303 | + | |
| 304 | + | |
| 305 | + | |
| 306 | + | |
| 307 | + | |
| 308 | + | |
| 309 | + | |
| 310 | + | |
| 311 | + | |
| 312 | + | |
| 313 | + | |
| 314 | + | |
| 315 | + | |
| 316 | + | |
| 317 | + | |
| 318 | + | |
| 319 | + | |
| 320 | + | |
| 321 | + | |
| 322 | + | |
| 323 | + | |
| 324 | + | |
| 325 | + | |
| 326 | + | |
| 327 | + | |
| 328 | + | |
| 329 | + | |
| 330 | + | |
| 331 | + | |
| 332 | + | |
| 333 | + | |
| 334 | + | |
| 335 | + | |
| 336 | + | |
| 337 | + | |
| 338 | + | |
| 339 | + | |
| 340 | + | |
| 341 | + | |
| 342 | + | |
| 343 | + | |
| 344 | + | |
| 345 | + | |
| 346 | + | |
| 347 | + | |
| 348 | + | |
| 349 | + | |
| 350 | + | |
| 351 | + | |
| 352 | + | |
| 353 | + | |
| 354 | + | |
| 355 | + | |
| 356 | + | |
| 357 | + | |
| 358 | + | |
| 359 | + | |
| 360 | + | |
| 361 | + | |
| 362 | + | |
| 363 | + | |
| 364 | + | |
| 365 | + | |
| 366 | + | |
| 367 | + | |
| 368 | + | |
| 369 | + | |
| 370 | + | |
| 371 | + | |
| 372 | + | |
| 373 | + | |
| 374 | + | |
| 375 | + | |
| 376 | + | |
| 377 | + | |
| 378 | + | |
| 379 | + | |
| 380 | + | |
| 381 | + | |
| 382 | + | |
| 383 | + | |
| 384 | + | |
| 385 | + | |
| 386 | + | |
| 387 | + | |
| 388 | + | |
| 389 | + | |
| 390 | + | |
| 391 | + | |
| 392 | + | |
| 393 | + | |
| 394 | + | |
| 395 | + | |
| 396 | + | |
| 397 | + | |
| 398 | + | |
| 399 | + | |
| 400 | + | |
| 401 | + | |
| 402 | + | |
| 403 | + | |
| 404 | + | |
| 405 | + | |
| 406 | + | |
| 407 | + | |
| 408 | + | |
| 409 | + | |
| 410 | + | |
| 411 | + | |
| 412 | + | |
275 | 413 | | |
276 | 414 | | |
277 | 415 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1631 | 1631 | | |
1632 | 1632 | | |
1633 | 1633 | | |
1634 | | - | |
1635 | | - | |
| 1634 | + | |
| 1635 | + | |
| 1636 | + | |
| 1637 | + | |
| 1638 | + | |
| 1639 | + | |
| 1640 | + | |
| 1641 | + | |
| 1642 | + | |
| 1643 | + | |
| 1644 | + | |
1636 | 1645 | | |
1637 | 1646 | | |
1638 | 1647 | | |
| |||
0 commit comments