Commit a56510e
committed
NNX: vocab tiling custom_vjp with output-head carve-out
Replaces the PR9.5 NNX vocab-tiling MVP (chunked forward + default
autograd backward) with a jax.custom_vjp that mirrors the Linen path's
backward-memory savings, then carves out the output-head params so the
custom_vjp's residuals + grad accumulator scale with LM-head size, not
with the full model. Linen vocab_tiling_linen_loss is byte-for-byte
unchanged. Call sites in train.py / pyconfig_deprecated.py /
configs/types.py are unchanged.
Custom_vjp + output-head carve-out (vocabulary_tiling.py):
- Outside the custom_vjp: 3-way nnx.split with a callable path filter
(_is_output_head_param_path) matching {token_embedder,
shared_embedding, decoder_norm, logits_dense} — the only nnx.Param
paths apply_output_head touches. Returns (graphdef, head_params,
other_params, rest).
- Custom_vjp primals: (head_params, other_params, rest, hidden_states,
labels, segmentation). Only head_params and hidden_states are
differentiated; other_params + rest are threaded through as
non-differentiated primals so their tracers don't have to cross both
the custom_vjp and the inner lax.scan boundary (which previously
caused UnexpectedTracerError under logits_via_embedding=True).
- Forward (_chunked_cross_entropy_loss_fwd): reshapes to
(num_vocab_tiling, vocab_tile_size, ...) and runs lax.scan whose body
rebuilds the model per chunk via nnx.merge(graphdef, chunk_head,
chunk_other, chunk_rest, copy=True) and calls
logits_from_hidden_states. Initial scan accumulator is fp32 (was
hidden_states.dtype previously — caused a lax.scan carry dtype
mismatch with bf16 hidden_states since cross_entropy_with_logits
always returns fp32). Residuals are (chunk_head, chunk_other,
chunk_rest, reshaped_*, batch/seq/emb).
- Backward (_chunked_cross_entropy_loss_bwd): a second lax.scan whose
body builds loss_fn_for_vjp = lambda p, h: ..., calls
jax.vjp(loss_fn_for_vjp, chunk_head_params, hidden_chunk),
accumulates grad_head via tree.map(add), emits per-chunk grad_hidden.
Chain-rules grad_head *= loss_cotangent and dtype-casts to each
primal's dtype (custom_vjp requires this). chunk_other_params and
chunk_rest cotangents are explicit tree_map(jnp.zeros_like, ...) zero
pytrees, NOT None — None makes JAX synthesize zeros at AOT trace time
with axis-0 stacking (jax.scan convention) for nnx.scan-stacked
transformer-layer params, which carry axis-1 stacking (nnx
convention), and the cotangent-shape check fails as
"Expected cotangent type bfloat16[E,M] for primal type bfloat16[E,M],
but got bfloat16[L,E,M]". Materializing the zeros ties the cotangent
shape to the primal shape exactly.
- Correctness: logits_from_hidden_states provably depends only on
head_params; the gradient w.r.t. other_params through this loss is
exactly zero. When train.py also calls the full model forward (which
produces hidden_states), transformer-layer gradients flow back
through grad_hidden_states → outer backward, unaffected by the
carve-out.
Supporting fixes (touched for the carve-out to work end-to-end):
- nnx_decoders.py::apply_output_head logits_via_embedding=True branch
reads embedding_table = shared_embedding.embedding[...] instead of
the deprecated .value shim. The .value shim registers the access in
NNX mutation tracking, which JAX detects as a tracer leak when the
embedding is closure-captured / threaded across the custom_vjp +
lax.scan boundaries. The Linen branch is unchanged.
- models.py: deletes dead-code self.hidden_states = None and
if num_vocab_tiling > 1: self.hidden_states = hidden_state from the
NNX Transformer class. Two lines left over from an early PR5
implementation idea — neither path actually reads
model.hidden_states (Linen reads via mutable=["intermediates"]; NNX
reads via nnx.pop(model, nnx.Intermediate) from the decoder's sown
("decoder", "hidden_states") intermediate). Without this fix, AOT
compile under pure_nnx=True + num_vocab_tiling>1 raised
ValueError: Cannot assign data value of type 'LinearizeTracer' to
static attribute 'hidden_states' of Pytree type 'Transformer' —
would have silently broken any post-PR11 user with vocab tiling on.
Tests (tiling_test.py — new VocabTilingNNXTest class with 9 TPU tests):
- test_nnx_vocab_tiling_non_tied_embedding / _tied_embedding: loss +
grad parity vs. full-vocab xent reference for both LM-head modes.
- test_nnx_vocab_tiling_total_z_loss_value_parity: asserts the second
tuple element matches the reference (was untested before).
- test_nnx_vocab_tiling_padded_segmentation: half-padded mask;
exercises the segmentation != 0 mask branch and asserts padded loss
is strictly less than unpadded.
- test_nnx_vocab_tiling_grad_over_hidden_states: argnums=1
differentiation; exercises the custom_vjp's second-primal cotangent
path (grad_reshaped_hidden_states), shape + dtype + value parity.
- test_nnx_vocab_tiling_bf16_hidden_states: bf16 inputs with rtol/atol
loosened to 5e-2; asserts grad_h.dtype == bf16 (the bwd dtype-cast
preserves the primal's dtype). Caught the fp32-accumulator bug.
- test_nnx_vocab_tiling_z_loss_zero: z_loss_multiplier=0;
total_z_loss == 0.0 exactly and grad parity holds.
- test_nnx_vocab_tiling_num_vocab_tiling_variants: runs n ∈ {2, 4, 8}
and asserts identical loss + grads (catches off-by-one in
vocab_tile_size and scan/reshape interactions).
- test_nnx_vocab_tiling_other_params_get_zero_grad (carve-out
invariant): asserts every non-head leaf has gradient exactly zero
AND at least one head leaf has non-zero gradient (so the test can't
trivially pass by zeroing everything). Catches filter bugs (e.g.
forgetting that NNX names the embedder token_embedder while Linen
names it shared_embedding) and bwd zero-shape bugs.
AOT compile coverage (train_compile_test.py):
- Removed the now-stale pytest.skip("Vocab tiling not supported on
NNX.") in test_vocab_tiling_bf16.
- Added test_vocab_tiling_bf16_nnx (cpu_only): AOT-compiles the train
step under pure_nnx=true + enable_nnx=true + pure_nnx_decoder=true
with num_vocab_tiling=4 and weight_dtype=bfloat16. Surfaced both the
models.py dead-code regression and the cotangent-axis-ordering issue
the explicit-zeros bwd fixes.
Tests pass: 18 in tiling + AOT (7 Linen UTs + 9 NNX UTs + 2 AOT, one
Linen and one NNX); 52 in adjacent NNX surfaces (train_nnx, dpo_nnx,
grpo_nnx, lora_utils_nnx, maxengine, qk_clip, aqt_serve_roundtrip_nnx)
— regression check for the nnx_decoders.py change.1 parent f9d94d1 commit a56510e
5 files changed
Lines changed: 507 additions & 48 deletions
File tree
- src/maxtext
- layers
- models
- utils
- tests/unit
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
984 | 984 | | |
985 | 985 | | |
986 | 986 | | |
987 | | - | |
| 987 | + | |
| 988 | + | |
| 989 | + | |
988 | 990 | | |
989 | 991 | | |
990 | 992 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
347 | 347 | | |
348 | 348 | | |
349 | 349 | | |
350 | | - | |
351 | 350 | | |
352 | 351 | | |
353 | 352 | | |
| |||
541 | 540 | | |
542 | 541 | | |
543 | 542 | | |
544 | | - | |
545 | | - | |
546 | | - | |
547 | | - | |
548 | 543 | | |
549 | 544 | | |
550 | 545 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
30 | 30 | | |
31 | 31 | | |
32 | 32 | | |
| 33 | + | |
| 34 | + | |
| 35 | + | |
| 36 | + | |
| 37 | + | |
| 38 | + | |
| 39 | + | |
| 40 | + | |
| 41 | + | |
| 42 | + | |
| 43 | + | |
| 44 | + | |
33 | 45 | | |
34 | 46 | | |
35 | 47 | | |
| |||
253 | 265 | | |
254 | 266 | | |
255 | 267 | | |
256 | | - | |
257 | | - | |
258 | | - | |
259 | | - | |
260 | | - | |
261 | | - | |
| 268 | + | |
| 269 | + | |
| 270 | + | |
| 271 | + | |
| 272 | + | |
| 273 | + | |
262 | 274 | | |
263 | 275 | | |
264 | 276 | | |
| |||
320 | 332 | | |
321 | 333 | | |
322 | 334 | | |
323 | | - | |
324 | | - | |
| 335 | + | |
| 336 | + | |
| 337 | + | |
| 338 | + | |
| 339 | + | |
325 | 340 | | |
326 | | - | |
327 | | - | |
328 | | - | |
329 | | - | |
330 | | - | |
331 | | - | |
332 | | - | |
333 | | - | |
334 | | - | |
335 | | - | |
336 | | - | |
337 | | - | |
338 | | - | |
339 | | - | |
340 | | - | |
341 | | - | |
342 | | - | |
343 | | - | |
344 | | - | |
345 | | - | |
346 | | - | |
347 | | - | |
348 | | - | |
349 | | - | |
| 341 | + | |
| 342 | + | |
| 343 | + | |
| 344 | + | |
| 345 | + | |
| 346 | + | |
| 347 | + | |
| 348 | + | |
| 349 | + | |
350 | 350 | | |
| 351 | + | |
351 | 352 | | |
352 | | - | |
353 | | - | |
| 353 | + | |
| 354 | + | |
| 355 | + | |
| 356 | + | |
| 357 | + | |
354 | 358 | | |
355 | | - | |
| 359 | + | |
| 360 | + | |
| 361 | + | |
| 362 | + | |
| 363 | + | |
356 | 364 | | |
357 | | - | |
358 | | - | |
359 | | - | |
| 365 | + | |
| 366 | + | |
| 367 | + | |
| 368 | + | |
| 369 | + | |
| 370 | + | |
| 371 | + | |
| 372 | + | |
| 373 | + | |
| 374 | + | |
| 375 | + | |
| 376 | + | |
| 377 | + | |
| 378 | + | |
| 379 | + | |
| 380 | + | |
| 381 | + | |
| 382 | + | |
| 383 | + | |
| 384 | + | |
| 385 | + | |
| 386 | + | |
| 387 | + | |
| 388 | + | |
| 389 | + | |
| 390 | + | |
| 391 | + | |
| 392 | + | |
| 393 | + | |
| 394 | + | |
| 395 | + | |
| 396 | + | |
| 397 | + | |
| 398 | + | |
| 399 | + | |
| 400 | + | |
| 401 | + | |
| 402 | + | |
| 403 | + | |
| 404 | + | |
| 405 | + | |
| 406 | + | |
| 407 | + | |
| 408 | + | |
| 409 | + | |
| 410 | + | |
| 411 | + | |
| 412 | + | |
| 413 | + | |
| 414 | + | |
| 415 | + | |
| 416 | + | |
| 417 | + | |
| 418 | + | |
| 419 | + | |
| 420 | + | |
| 421 | + | |
| 422 | + | |
| 423 | + | |
| 424 | + | |
| 425 | + | |
| 426 | + | |
| 427 | + | |
| 428 | + | |
| 429 | + | |
| 430 | + | |
| 431 | + | |
| 432 | + | |
| 433 | + | |
| 434 | + | |
| 435 | + | |
| 436 | + | |
| 437 | + | |
| 438 | + | |
| 439 | + | |
| 440 | + | |
| 441 | + | |
| 442 | + | |
| 443 | + | |
| 444 | + | |
| 445 | + | |
| 446 | + | |
| 447 | + | |
| 448 | + | |
| 449 | + | |
| 450 | + | |
| 451 | + | |
| 452 | + | |
| 453 | + | |
| 454 | + | |
| 455 | + | |
| 456 | + | |
| 457 | + | |
| 458 | + | |
| 459 | + | |
| 460 | + | |
| 461 | + | |
| 462 | + | |
| 463 | + | |
| 464 | + | |
| 465 | + | |
| 466 | + | |
360 | 467 | | |
361 | 468 | | |
0 commit comments