Commit c91a2a6
NNX: finish MaxEngine inference carve-outs (multisampling, concat, stacked prefill cache)

PR7 (NNX-native MaxEngine inference) made the core prefill/generate/insert
path work under pure_nnx=True but left three serving features raising
NotImplementedError on the NNX path. This promotes all three to NNX-native.
Linen is preserved byte-for-byte: the original model.apply(..., mutable=["cache"])
calls are unchanged, just moved into else: branches, and every NNX edit is
gated `if config.pure_nnx:`.

maxengine.py:
- _prefill_multisampling_jit: drops the NotImplementedError; adds a pure_nnx
branch that runs prefill through _nnx_run_model (MODEL_MODE_PREFILL, batch=1)
with a fresh _nnx_init_cache_dict. The loop that draws num_samples first
tokens from the shared logits is unchanged.
- prefill_concat: same swap; the packed positions and segment ids thread
through _nnx_run_model unchanged.
- stack_prefill_result_cache=True: now supported for both scan_layers values.
scan_layers=True already stacks the per-layer KV cache on axis 0 (the Linen
post-stack shape), so _maybe_stack/_maybe_unstack_prefill_result_cache are
no-ops and prefill_kv_cache_shardings stays the full tree. scan_layers=False
keeps unstacked per-layer subtrees under cache["decoder"]["layers"][i] (int
keys), so _maybe_stack stacks them into one subtree with a leading layer axis,
_maybe_unstack splits it back into the int-keyed per-layer dict that
bulk_insert/_insert_jit walk, and _load_params_nnx prepends a layer axis to
each prefix-sharding spec (the NNX analog of the Linen P(None, *spec) +
["decoder"]["layers_0"] reshape).
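The scan_layers=False stacking described above can be illustrated with a minimal JAX sketch. This is not the code from this commit: `stack_layers`, `unstack_layers`, and `prepend_layer_axis` are hypothetical stand-ins for `_maybe_stack_prefill_result_cache`, `_maybe_unstack_prefill_result_cache`, and the sharding adjustment in `_load_params_nnx`, and the toy cache layout (`"key"`/`"value"` leaves, the `"data"`/`"model"` axis names) is assumed purely for illustration.

```python
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P


def stack_layers(per_layer):
  """Stack an int-keyed {layer_idx: subtree} dict into one subtree with a leading layer axis."""
  ordered = [per_layer[i] for i in sorted(per_layer)]
  return jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves, axis=0), *ordered)


def unstack_layers(stacked):
  """Split the leading layer axis back into an int-keyed per-layer dict."""
  num_layers = jax.tree_util.tree_leaves(stacked)[0].shape[0]
  return {i: jax.tree_util.tree_map(lambda leaf: leaf[i], stacked) for i in range(num_layers)}


def prepend_layer_axis(spec):
  """Prepend an unsharded (None) layer axis to a per-layer PartitionSpec."""
  return P(None, *spec)


# Toy two-layer cache; the shapes and leaf names are assumptions, not MaxText's layout.
cache_layers = {
    0: {"key": jnp.zeros((1, 8, 2, 4)), "value": jnp.ones((1, 8, 2, 4))},
    1: {"key": jnp.full((1, 8, 2, 4), 2.0), "value": jnp.full((1, 8, 2, 4), 3.0)},
}

stacked = stack_layers(cache_layers)      # every leaf gains a leading axis of size 2
restored = unstack_layers(stacked)        # back to {0: {...}, 1: {...}}
stacked_spec = prepend_layer_axis(P("data", None, "model", None))
```

With scan_layers=True the cache would already arrive in the `stacked` form, which is why both helpers can be no-ops on that path.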
tests/integration/maxengine_test.py:
- New _build_linen_params helper and a shared _stack_prefill_roundtrip helper.
- test_prefill_multisampling_nnx, test_prefill_concat_nnx: NNX vs Linen
result-shape parity, finite logits + cache.
- test_stack_prefill_result_cache_nnx (scan_layers=True) and
test_stack_prefill_result_cache_scan_layers_false_nnx (scan_layers=False):
prefill -> insert -> generate round-trip, layer-stacked leaves, finite
logits, next_pos advances.
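The "finite logits + cache" and "layer-stacked leaves" checks listed above could be expressed roughly as follows. This is a sketch under assumptions, not the helpers added to maxengine_test.py: `all_leaves_finite` and `all_leaves_layer_stacked` are hypothetical names, and the toy trees stand in for a real prefill result.

```python
import jax
import jax.numpy as jnp


def all_leaves_finite(tree):
  """Return True if every array leaf in the pytree has no NaN or Inf entries."""
  return all(bool(jnp.all(jnp.isfinite(leaf))) for leaf in jax.tree_util.tree_leaves(tree))


def all_leaves_layer_stacked(tree, num_layers):
  """Return True if every leaf has a leading axis whose size equals num_layers."""
  return all(
      leaf.ndim >= 1 and leaf.shape[0] == num_layers
      for leaf in jax.tree_util.tree_leaves(tree)
  )


# Toy stand-ins for a stacked prefill cache (shapes are illustrative only).
good_cache = {"key": jnp.zeros((2, 4)), "value": jnp.ones((2, 4))}
bad_cache = {"key": jnp.array([[1.0, jnp.nan]])}
```

In a test, these predicates would be wrapped in assertions on the prefill result and on the logits returned by generate.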
Remaining NNX MaxEngine carve-outs are quantization (PR9) and LoRA (PR8),
which are covered by those PRs.

1 parent 2f7f039 commit c91a2a6
2 files changed
Lines changed: 187 additions & 33 deletions
File 1 of 2 (81 additions, 33 deletions; diff content not captured).
Added lines (new numbering): 420-428, 532-541, 555-564, 954-979, 1091-1116.
Removed lines (old numbering): 420-424, 921-923, 933-943, 1049-1051, 1058-1068.
File 2 of 2 (106 additions, 0 deletions; diff content not captured).
Added lines (new numbering): 181-192, 272-365.