Commit ca957ab
committed
NNX: AQT in MaxEngine + serve-mode reload + gpt3 prefill fix
Builds on PR9. Migrates the NNX + AQT integration so MaxEngine can both
load pre-quantized checkpoints directly and convert full-precision
checkpoints to int8 on load. Also bundles a pre-existing gpt3 prefill
bug surfaced by the AQT end-to-end validation.
NNX + AQT in MaxEngine:
- model_creation_utils threads quant_mode_str ("train" | "convert" |
"serve") through from_config / create_model /
get_nnx_create_model_fn / create_nnx_abstract_model /
from_pretrained. Default "train" preserves existing callers; "serve"
propagates to configure_quantization so AQT layers don't materialize
the full-precision kernel when the on-disk checkpoint already
carries qrhs scale factors.
- maxengine.__init__ selects the quant mode from
config.checkpoint_is_quantized; _load_params_nnx drops its
NotImplementedError. Two paths: pre-quantized
(checkpoint_is_quantized=True) loads via quant_mode_str="serve";
full-precision + quantization=int8 loads in TRAIN mode and AQT
layers quantize per-forward (same numerical result for absmax
calibration).
- layerwise_quantization._load_and_quantize_nnx: whole-model NNX
convert path. Loads full-precision in TRAIN mode, transfers kernels
into a CONVERT-mode model, runs forward to populate qrhs.frozen via
the ToNNX(AqtDotGeneral) bridge, strips kernels at quantized paths,
saves serve-mode-shaped state.
Sharding helpers and from_pretrained QTensor handling (5 chained fixes
that kept the serve-mode reload from working):
- maxtext_utils.get_nnx_named_sharding_with_scan_axis emits a
parallel-tree of replicated NamedSharding leaves when a Variable's
value is a composite pytree (AQT serve-mode QTensor with a qvalue
int8 leaf and a list of bf16 scale leaves).
- model_creation_utils.from_pretrained: drops a redundant
jax.set_mesh wrap in create_nnx_abstract_model that broke serve-mode
AQT under Flax 0.12.6. _build_value_target / _free_device_memory /
_unwrap_for_align use Variable.get_value() instead of v[...]
indexing for QTensor leaves (QTensor.__getitem__ trips on the
LogicallyPartitioned wrapper around qvalue). Widens the restore
filter beyond nnx.Param to cover the aqt-typed qrhs.frozen Variable
type. Skips QTensor leaves in the per-axis shape-alignment dispatch
(their saved shape already matches the model). _build_value_target
strips Partitioned wrappers around composite-leaf values so the
restore tree path matches the on-disk layout (LogicallyPartitioned
was adding an extra .value key under each QTensor leaf, which made
orbax silently fill the path with zero-init values).
gpt3 prefill / autoregressive fix (pre-existing, surfaced here):
- Gpt3MultiHeadAttention.__call__ invoked attention_op(...) without
ever calling update_kv_caches to build cached_values, so any
non-TRAIN forward (prefill or autoregressive) tripped the
`assert prefill_kv_cache` check. Mirror the standard Attention
plumbing in attentions.py: __init__ constructs a KVCache_0 module
when model_mode != MODEL_MODE_TRAIN, threads
max_prefill_predict_length into AttentionOp; __call__ calls
self.KVCache_0(...) and passes [prefill_kv_cache, ar_kv_cache] as
cached_values to attention_op. TRAIN-mode shape unchanged.
Tests:
- layerwise_quantization_nnx_test (new): 3 cases for
_strip_kernels_at_quantized_paths covering quantized removal,
non-quantized preservation (norms, embeddings), mixed-shape trees.
- aqt_serve_roundtrip_nnx_test (new): end-to-end regression test that
builds a small NNX model in CONVERT mode with int8, runs a forward
to populate qrhs.frozen via the ToNNX bridge, saves the
serve-mode-shape state to a tmp local orbax checkpoint, reloads via
from_pretrained(quant_mode_str="serve"), and asserts every saved
qrhs.frozen.qvalue array byte-matches what came back. Guards the
full chain of QTensor / Partitioned / filter fixes.
- maxengine_test: replaced test_quantize_raises_for_nnx with
test_quantize_passes_gate_for_nnx; added
test_load_pre_quantized_nnx_passes_quant_gate and
test_quantized_prefill_nnx_train_mode (real numerical verification
with quantization=int8 + random params + TRAIN mode).
End-to-end on TPU (gpt3-52k): convert-mode forward + qrhs.frozen
extraction + serve-mode-shape save + reload via
from_pretrained(quant_mode_str="serve") + maxengine.load_params +
quantized prefill forward all work; loaded qrhs.frozen.qvalue
byte-matches the on-disk state.1 parent 02ff5f7 commit ca957ab
8 files changed
Lines changed: 637 additions & 57 deletions
File tree
- src/maxtext
- inference/maxengine
- models
- utils
- tests/unit
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
117 | 117 | | |
118 | 118 | | |
119 | 119 | | |
| 120 | + | |
| 121 | + | |
| 122 | + | |
| 123 | + | |
| 124 | + | |
| 125 | + | |
120 | 126 | | |
121 | 127 | | |
122 | 128 | | |
123 | 129 | | |
124 | 130 | | |
125 | | - | |
| 131 | + | |
| 132 | + | |
| 133 | + | |
126 | 134 | | |
127 | | - | |
| 135 | + | |
128 | 136 | | |
| 137 | + | |
129 | 138 | | |
130 | 139 | | |
131 | 140 | | |
| |||
371 | 380 | | |
372 | 381 | | |
373 | 382 | | |
374 | | - | |
375 | | - | |
376 | | - | |
| 383 | + | |
| 384 | + | |
| 385 | + | |
| 386 | + | |
| 387 | + | |
| 388 | + | |
| 389 | + | |
| 390 | + | |
| 391 | + | |
377 | 392 | | |
378 | 393 | | |
379 | 394 | | |
| |||
396 | 411 | | |
397 | 412 | | |
398 | 413 | | |
399 | | - | |
| 414 | + | |
| 415 | + | |
| 416 | + | |
| 417 | + | |
400 | 418 | | |
401 | | - | |
402 | | - | |
| 419 | + | |
| 420 | + | |
| 421 | + | |
| 422 | + | |
| 423 | + | |
| 424 | + | |
| 425 | + | |
| 426 | + | |
| 427 | + | |
| 428 | + | |
| 429 | + | |
| 430 | + | |
| 431 | + | |
| 432 | + | |
| 433 | + | |
| 434 | + | |
| 435 | + | |
| 436 | + | |
| 437 | + | |
| 438 | + | |
| 439 | + | |
| 440 | + | |
| 441 | + | |
| 442 | + | |
| 443 | + | |
| 444 | + | |
| 445 | + | |
| 446 | + | |
| 447 | + | |
| 448 | + | |
403 | 449 | | |
404 | 450 | | |
405 | | - | |
| 451 | + | |
406 | 452 | | |
407 | 453 | | |
408 | 454 | | |
| |||
485 | 531 | | |
486 | 532 | | |
487 | 533 | | |
488 | | - | |
| 534 | + | |
| 535 | + | |
| 536 | + | |
| 537 | + | |
| 538 | + | |
| 539 | + | |
| 540 | + | |
| 541 | + | |
| 542 | + | |
| 543 | + | |
489 | 544 | | |
490 | 545 | | |
491 | 546 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
28 | 28 | | |
29 | 29 | | |
30 | 30 | | |
| 31 | + | |
31 | 32 | | |
32 | 33 | | |
33 | 34 | | |
| |||
235 | 236 | | |
236 | 237 | | |
237 | 238 | | |
| 239 | + | |
238 | 240 | | |
239 | 241 | | |
240 | 242 | | |
| |||
252 | 254 | | |
253 | 255 | | |
254 | 256 | | |
| 257 | + | |
255 | 258 | | |
256 | 259 | | |
257 | 260 | | |
| |||
260 | 263 | | |
261 | 264 | | |
262 | 265 | | |
| 266 | + | |
| 267 | + | |
| 268 | + | |
| 269 | + | |
| 270 | + | |
| 271 | + | |
| 272 | + | |
| 273 | + | |
| 274 | + | |
| 275 | + | |
| 276 | + | |
| 277 | + | |
| 278 | + | |
| 279 | + | |
| 280 | + | |
| 281 | + | |
| 282 | + | |
| 283 | + | |
| 284 | + | |
| 285 | + | |
| 286 | + | |
| 287 | + | |
| 288 | + | |
| 289 | + | |
263 | 290 | | |
264 | 291 | | |
265 | 292 | | |
| |||
328 | 355 | | |
329 | 356 | | |
330 | 357 | | |
331 | | - | |
| 358 | + | |
| 359 | + | |
| 360 | + | |
| 361 | + | |
| 362 | + | |
| 363 | + | |
| 364 | + | |
| 365 | + | |
| 366 | + | |
| 367 | + | |
| 368 | + | |
| 369 | + | |
332 | 370 | | |
333 | 371 | | |
334 | 372 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
47 | 47 | | |
48 | 48 | | |
49 | 49 | | |
| 50 | + | |
| 51 | + | |
50 | 52 | | |
51 | 53 | | |
52 | 54 | | |
| |||
164 | 166 | | |
165 | 167 | | |
166 | 168 | | |
167 | | - | |
168 | | - | |
169 | | - | |
170 | | - | |
171 | | - | |
| 169 | + | |
| 170 | + | |
| 171 | + | |
| 172 | + | |
| 173 | + | |
| 174 | + | |
| 175 | + | |
| 176 | + | |
| 177 | + | |
172 | 178 | | |
173 | 179 | | |
174 | 180 | | |
175 | 181 | | |
176 | | - | |
177 | | - | |
178 | | - | |
179 | | - | |
180 | 182 | | |
| 183 | + | |
| 184 | + | |
| 185 | + | |
| 186 | + | |
| 187 | + | |
181 | 188 | | |
182 | 189 | | |
183 | 190 | | |
| |||
189 | 196 | | |
190 | 197 | | |
191 | 198 | | |
| 199 | + | |
| 200 | + | |
| 201 | + | |
192 | 202 | | |
193 | 203 | | |
194 | 204 | | |
| |||
274 | 284 | | |
275 | 285 | | |
276 | 286 | | |
| 287 | + | |
| 288 | + | |
| 289 | + | |
| 290 | + | |
| 291 | + | |
| 292 | + | |
| 293 | + | |
| 294 | + | |
| 295 | + | |
| 296 | + | |
| 297 | + | |
| 298 | + | |
| 299 | + | |
| 300 | + | |
| 301 | + | |
| 302 | + | |
| 303 | + | |
| 304 | + | |
| 305 | + | |
| 306 | + | |
| 307 | + | |
| 308 | + | |
| 309 | + | |
| 310 | + | |
| 311 | + | |
| 312 | + | |
| 313 | + | |
| 314 | + | |
| 315 | + | |
| 316 | + | |
| 317 | + | |
| 318 | + | |
| 319 | + | |
| 320 | + | |
| 321 | + | |
| 322 | + | |
| 323 | + | |
| 324 | + | |
| 325 | + | |
| 326 | + | |
| 327 | + | |
| 328 | + | |
| 329 | + | |
| 330 | + | |
| 331 | + | |
| 332 | + | |
| 333 | + | |
| 334 | + | |
| 335 | + | |
| 336 | + | |
| 337 | + | |
| 338 | + | |
| 339 | + | |
| 340 | + | |
| 341 | + | |
| 342 | + | |
| 343 | + | |
| 344 | + | |
| 345 | + | |
| 346 | + | |
| 347 | + | |
| 348 | + | |
| 349 | + | |
| 350 | + | |
| 351 | + | |
| 352 | + | |
| 353 | + | |
| 354 | + | |
| 355 | + | |
| 356 | + | |
| 357 | + | |
| 358 | + | |
| 359 | + | |
| 360 | + | |
| 361 | + | |
| 362 | + | |
| 363 | + | |
| 364 | + | |
| 365 | + | |
| 366 | + | |
| 367 | + | |
| 368 | + | |
| 369 | + | |
| 370 | + | |
| 371 | + | |
| 372 | + | |
| 373 | + | |
| 374 | + | |
| 375 | + | |
| 376 | + | |
| 377 | + | |
| 378 | + | |
| 379 | + | |
| 380 | + | |
| 381 | + | |
| 382 | + | |
| 383 | + | |
| 384 | + | |
| 385 | + | |
| 386 | + | |
| 387 | + | |
| 388 | + | |
| 389 | + | |
| 390 | + | |
| 391 | + | |
| 392 | + | |
| 393 | + | |
| 394 | + | |
| 395 | + | |
| 396 | + | |
| 397 | + | |
| 398 | + | |
| 399 | + | |
| 400 | + | |
| 401 | + | |
| 402 | + | |
| 403 | + | |
| 404 | + | |
| 405 | + | |
| 406 | + | |
| 407 | + | |
| 408 | + | |
| 409 | + | |
| 410 | + | |
| 411 | + | |
277 | 412 | | |
278 | 413 | | |
279 | 414 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1534 | 1534 | | |
1535 | 1535 | | |
1536 | 1536 | | |
1537 | | - | |
1538 | | - | |
| 1537 | + | |
| 1538 | + | |
| 1539 | + | |
| 1540 | + | |
| 1541 | + | |
| 1542 | + | |
| 1543 | + | |
| 1544 | + | |
| 1545 | + | |
| 1546 | + | |
| 1547 | + | |
1539 | 1548 | | |
1540 | 1549 | | |
1541 | 1550 | | |
| |||
0 commit comments