Commit b525127
NNX: vocab tiling custom_vjp with output-head carve-out
Replaces the PR9.5 NNX vocab-tiling MVP (chunked forward + default
autograd backward) with a jax.custom_vjp that mirrors the Linen path's
backward-memory savings, then carves out the output-head params so the
custom_vjp's residuals + grad accumulator scale with LM-head size, not
with the full model. Linen vocab_tiling_linen_loss is byte-for-byte
unchanged. Call sites in train.py / pyconfig_deprecated.py /
configs/types.py are unchanged.
Custom_vjp + output-head carve-out (vocabulary_tiling.py):
- Outside the custom_vjp: 3-way nnx.split with a callable path filter
(_is_output_head_param_path) matching {token_embedder,
shared_embedding, decoder_norm, logits_dense} — the only nnx.Param
paths apply_output_head touches. Returns (graphdef, head_params,
other_params, rest).
- Custom_vjp primals: (head_params, other_params, rest, hidden_states,
labels, segmentation). Only head_params and hidden_states are
differentiated; other_params + rest are threaded through as
non-differentiated primals so their tracers don't have to cross both
the custom_vjp and the inner lax.scan boundary (which previously
caused UnexpectedTracerError under logits_via_embedding=True).
- Forward (_chunked_cross_entropy_loss_fwd): reshapes to
(num_vocab_tiling, vocab_tile_size, ...) and runs lax.scan whose body
rebuilds the model per chunk via nnx.merge(graphdef, chunk_head,
chunk_other, chunk_rest, copy=True) and calls
logits_from_hidden_states. Initial scan accumulator is fp32 (was
hidden_states.dtype previously — caused a lax.scan carry dtype
mismatch with bf16 hidden_states since cross_entropy_with_logits
always returns fp32). Residuals are (chunk_head, chunk_other,
chunk_rest, reshaped_*, batch/seq/emb).
- Backward (_chunked_cross_entropy_loss_bwd): a second lax.scan whose
body builds loss_fn_for_vjp = lambda p, h: ..., calls
jax.vjp(loss_fn_for_vjp, chunk_head_params, hidden_chunk),
accumulates grad_head via tree.map(add), emits per-chunk grad_hidden.
Chain-rules grad_head *= loss_cotangent and dtype-casts to each
primal's dtype (custom_vjp requires this). chunk_other_params and
chunk_rest cotangents are explicit tree_map(jnp.zeros_like, ...) zero
pytrees, NOT None: None makes JAX synthesize zeros at AOT trace time
with axis-0 stacking (lax.scan convention) for nnx.scan-stacked
transformer-layer params, which carry axis-1 stacking (nnx
convention), and the cotangent-shape check fails as
"Expected cotangent type bfloat16[E,M] for primal type bfloat16[E,M],
but got bfloat16[L,E,M]". Materializing the zeros ties the cotangent
shape to the primal shape exactly.
- Correctness: logits_from_hidden_states provably depends only on
head_params; the gradient w.r.t. other_params through this loss is
exactly zero. When train.py also calls the full model forward (which
produces hidden_states), transformer-layer gradients flow back
through grad_hidden_states → outer backward, unaffected by the
carve-out.
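The forward/backward structure described above can be sketched in plain JAX.
This is a minimal illustration under simplifying assumptions, not the code in
vocabulary_tiling.py: the LM head is a single matrix w, other stands in for
the non-head params, hidden_states is already flattened to (tokens, emb), and
the names NUM_TILES, chunk_loss and tiled_xent are hypothetical. It shows the
fp32 scan accumulator, the per-chunk jax.vjp in the backward scan, the cast of
each cotangent to its primal's dtype, and explicit zeros for the threaded
primal.

```python
import jax
import jax.numpy as jnp
from jax import lax

NUM_TILES = 4  # hypothetical stand-in for num_vocab_tiling


def chunk_loss(w, h, y):
  """Summed cross entropy for one chunk of tokens; always returns fp32."""
  logits = jnp.dot(h, w).astype(jnp.float32)
  log_z = jax.nn.logsumexp(logits, axis=-1)
  picked = jnp.take_along_axis(logits, y[:, None], axis=-1)[:, 0]
  return jnp.sum(log_z - picked)


@jax.custom_vjp
def tiled_xent(w, other, hidden, labels):
  return _fwd(w, other, hidden, labels)[0]


def _fwd(w, other, hidden, labels):
  # Split the token axis into NUM_TILES chunks so full logits never coexist.
  h = hidden.reshape(NUM_TILES, -1, hidden.shape[-1])
  y = labels.reshape(NUM_TILES, -1)

  def body(acc, xs):
    return acc + chunk_loss(w, *xs), None

  # fp32 carry: chunk_loss returns fp32 even when hidden is bf16.
  loss, _ = lax.scan(body, jnp.zeros((), jnp.float32), (h, y))
  return loss, (w, other, h, y)


def _bwd(res, g):
  w, other, h, y = res

  def body(acc, xs):
    h_c, y_c = xs
    _, vjp_fn = jax.vjp(lambda p, x: chunk_loss(p, x, y_c), w, h_c)
    g_w, g_h = vjp_fn(jnp.ones((), jnp.float32))
    return acc + g_w.astype(jnp.float32), g_h  # accumulate head grad in fp32

  g_w, g_h = lax.scan(body, jnp.zeros(w.shape, jnp.float32), (h, y))
  # Chain rule with the incoming cotangent, then cast to each primal's dtype.
  g_w = (g_w * g).astype(w.dtype)
  g_h = (g_h * g).reshape(-1, h.shape[-1]).astype(h.dtype)
  # Explicit zeros tie the cotangent shape to the threaded primal's shape.
  zeros_other = jax.tree_util.tree_map(jnp.zeros_like, other)
  # Integer labels carry no cotangent; None is used for them in this sketch.
  return g_w, zeros_other, g_h, None


tiled_xent.defvjp(_fwd, _bwd)

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
w = jax.random.normal(k1, (8, 16))
hidden = jax.random.normal(k2, (12, 8))
labels = jax.random.randint(k3, (12,), 0, 16)
other = {"layers": jnp.ones((3, 8, 8))}
grad_w, grad_other, grad_h = jax.grad(tiled_xent, argnums=(0, 1, 2))(
    w, other, hidden, labels)
```

In this sketch the residuals hold only w, other and the reshaped inputs, so
backward memory is bounded by one chunk of logits plus the head-sized
accumulator.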
Supporting fixes (touched for the carve-out to work end-to-end):
- nnx_decoders.py::apply_output_head logits_via_embedding=True branch
reads embedding_table = shared_embedding.embedding[...] instead of
the deprecated .value shim. The .value shim registers the access in
NNX mutation tracking, which JAX detects as a tracer leak when the
embedding is closure-captured / threaded across the custom_vjp +
lax.scan boundaries. The Linen branch is unchanged.
- models.py: deletes dead-code self.hidden_states = None and
if num_vocab_tiling > 1: self.hidden_states = hidden_state from the
NNX Transformer class. Two lines left over from an early PR5
implementation idea — neither path actually reads
model.hidden_states (Linen reads via mutable=["intermediates"]; NNX
reads via nnx.pop(model, nnx.Intermediate) from the decoder's sown
("decoder", "hidden_states") intermediate). Without this fix, AOT
compile under pure_nnx=True + num_vocab_tiling>1 raised
ValueError: Cannot assign data value of type 'LinearizeTracer' to
static attribute 'hidden_states' of Pytree type 'Transformer' —
would have silently broken any post-PR11 user with vocab tiling on.
Tests (tiling_test.py — new VocabTilingNNXTest class with 9 TPU tests):
- test_nnx_vocab_tiling_non_tied_embedding / _tied_embedding: loss +
grad parity vs. full-vocab xent reference for both LM-head modes.
- test_nnx_vocab_tiling_total_z_loss_value_parity: asserts the second
tuple element matches the reference (was untested before).
- test_nnx_vocab_tiling_padded_segmentation: half-padded mask;
exercises the segmentation != 0 mask branch and asserts padded loss
is strictly less than unpadded.
- test_nnx_vocab_tiling_grad_over_hidden_states: argnums=1
differentiation; exercises the custom_vjp's second-primal cotangent
path (grad_reshaped_hidden_states), shape + dtype + value parity.
- test_nnx_vocab_tiling_bf16_hidden_states: bf16 inputs with rtol/atol
loosened to 5e-2; asserts grad_h.dtype == bf16 (the bwd dtype-cast
preserves the primal's dtype). Caught the fp32-accumulator bug.
- test_nnx_vocab_tiling_z_loss_zero: z_loss_multiplier=0;
total_z_loss == 0.0 exactly and grad parity holds.
- test_nnx_vocab_tiling_num_vocab_tiling_variants: runs n ∈ {2, 4, 8}
and asserts identical loss + grads (catches off-by-one in
vocab_tile_size and scan/reshape interactions).
- test_nnx_vocab_tiling_other_params_get_zero_grad (carve-out
invariant): asserts every non-head leaf has gradient exactly zero
AND at least one head leaf has non-zero gradient (so the test can't
trivially pass by zeroing everything). Catches filter bugs (e.g.
forgetting that NNX names the embedder token_embedder while Linen
names it shared_embedding) and bwd zero-shape bugs.
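The carve-out invariant checked by the last test can be illustrated with a
small, self-contained sketch. It assumes a plain dict of params rather than an
NNX state; is_head_path and head_loss are hypothetical helpers written for
this example (the real filter is _is_output_head_param_path), and the "layers"
entry stands in for transformer-layer params.

```python
import jax
import jax.numpy as jnp

# Path components treated as output-head params, as listed in this commit.
HEAD_NAMES = ("token_embedder", "shared_embedding", "decoder_norm", "logits_dense")


def is_head_path(path):
  """True if any dict key along the pytree path names an output-head module."""
  return any(getattr(entry, "key", None) in HEAD_NAMES for entry in path)


def head_loss(params, hidden):
  """Toy loss that reads only head params (norm scale + logits kernel)."""
  normed = hidden * params["decoder_norm"]["scale"]
  logits = normed @ params["logits_dense"]["kernel"]
  return jnp.sum(jax.nn.logsumexp(logits, axis=-1))


params = {
    "decoder_norm": {"scale": jnp.ones((8,))},
    "logits_dense": {"kernel": jnp.full((8, 16), 0.1)},
    "layers": {"mlp": {"kernel": jnp.ones((2, 8, 8))}},  # never read by head_loss
}
hidden = jnp.arange(32, dtype=jnp.float32).reshape(4, 8) / 32.0

grads = jax.grad(head_loss)(params, hidden)
leaves = jax.tree_util.tree_leaves_with_path(grads)

# Every non-head leaf must be exactly zero ...
non_head_all_zero = all(
    bool(jnp.all(g == 0)) for path, g in leaves if not is_head_path(path))
# ... and at least one head leaf must be non-zero, so the check is not vacuous.
some_head_nonzero = any(
    bool(jnp.any(g != 0)) for path, g in leaves if is_head_path(path))
```

Pairing the two conditions is what keeps the check from passing trivially when
a backward rule zeroes every gradient.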
AOT compile coverage (train_compile_test.py):
- Removed the now-stale pytest.skip("Vocab tiling not supported on
NNX.") in test_vocab_tiling_bf16.
- Added test_vocab_tiling_bf16_nnx (cpu_only): AOT-compiles the train
step under pure_nnx=true + enable_nnx=true + pure_nnx_decoder=true
with num_vocab_tiling=4 and weight_dtype=bfloat16. Surfaced both the
models.py dead-code regression and the cotangent-axis-ordering issue
the explicit-zeros bwd fixes.
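For readers unfamiliar with AOT compilation in JAX, the general mechanism can
be sketched as follows. This is not the train_compile_test.py harness:
toy_step is a hypothetical stand-in for the train step, and the shapes are
arbitrary. Lowering against jax.ShapeDtypeStruct specs traces and compiles the
function without real data, which is the stage where shape and dtype
mismatches such as the ones described above are reported.

```python
import jax
import jax.numpy as jnp


def toy_step(w, hidden):
  """Hypothetical train-step stand-in: fp32 loss and its grad w.r.t. bf16 weights."""

  def loss_fn(w_):
    logits = jnp.dot(hidden, w_).astype(jnp.float32)
    return jnp.mean(jax.nn.logsumexp(logits, axis=-1))

  return jax.value_and_grad(loss_fn)(w)


# Abstract inputs: only shape and dtype are needed to trace and compile.
w_spec = jax.ShapeDtypeStruct((8, 16), jnp.bfloat16)
h_spec = jax.ShapeDtypeStruct((4, 8), jnp.bfloat16)

lowered = jax.jit(toy_step).lower(w_spec, h_spec)
compiled = lowered.compile()

# The compiled executable accepts concrete arrays matching the specs.
loss, grad = compiled(jnp.ones((8, 16), jnp.bfloat16), jnp.ones((4, 8), jnp.bfloat16))
```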
Tests pass: 18 in tiling + AOT (7 Linen UTs + 9 NNX UTs + 2 AOT, one
Linen and one NNX); 52 in adjacent NNX surfaces (train_nnx, dpo_nnx,
grpo_nnx, lora_utils_nnx, maxengine, qk_clip, aqt_serve_roundtrip_nnx)
as a regression check for the nnx_decoders.py change.
1 parent 88417d0 commit b525127
5 files changed
Lines changed: 538 additions & 44 deletions
File tree
- src/maxtext
- layers
- models
- utils
- tests/unit