Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
1ad2e0b
tests: pin Linen-only vocab tiling and pipeline tests for upcoming NN…
ecnal-cienet May 8, 2026
7190513
NNX: flip pure_nnx/enable_nnx/pure_nnx_decoder defaults to True
ecnal-cienet May 8, 2026
0005ac1
Temp: tests/unit/train_compile_test.py::TrainCompile::test_qk_clip_do…
hsuan-lun-chiang May 19, 2026
8fbeb6f
fix cpu UT failure
May 19, 2026
e6cb5c2
fix gpu UT failures
May 20, 2026
550cb1c
Fix tests/unit/muon_utils_test.py::TestGetMuonWeightDimensionNumbersN…
hsuan-lun-chiang May 20, 2026
4f4a278
tests/unit/max_utils_test.py::UnscanTest::test_unscan_train_state_params
hsuan-lun-chiang May 20, 2026
56896fb
tests/unit/max_utils_test.py::UnscanTest::test_unscan_train_state_params
hsuan-lun-chiang May 20, 2026
a4199de
Fix test compatibility with pure_nnx=True defaults
hsuan-lun-chiang May 20, 2026
020fdd5
Fix diloco related unit tests
hsuan-lun-chiang May 21, 2026
6f638ae
fix nnx_wrapper.py comment
May 21, 2026
4d06f28
fix nnx_wrapper.py gpu UT failure
May 21, 2026
b025274
Fix integration test failures under NNX defaults
ecnal-cienet May 21, 2026
0d93e2d
Revert fix for fp8
hsuan-lun-chiang May 22, 2026
16ac426
test: skip NNX int8 parameter-only checkpoint generation for GPU dot …
hsuan-lun-chiang May 25, 2026
42e8a32
Fix sft_llama3_demo_tpu.ipynb
hsuan-lun-chiang May 26, 2026
c3f8e27
test: skip fp8 cases under NNX (b/509790223)
ecnal-cienet May 26, 2026
81c9883
test: make maxengine prefill/cache tests NNX-only
ecnal-cienet May 27, 2026
1d404cf
fix: update train_state_nnx import path after #3929 relocation
ecnal-cienet May 28, 2026
bdda935
fix(nnx): support Zero-1 input shardings on NNX flat state
ecnal-cienet May 28, 2026
7f1dca3
Revert the incorrect Fp8 fix in nnx_decoders.py
hsuan-lun-chiang May 29, 2026
069c46f
NNX: fix DiLoCo train loop + checkpoint under pure_nnx
ecnal-cienet Jun 10, 2026
d26eb7c
NNX: fail fast for pipeline parallelism under pure_nnx
ecnal-cienet Jun 10, 2026
00902f3
[NNX] Delete Linen (1/4): collapse pure_nnx/enable_nnx/isinstance dis…
ecnal-cienet Jun 4, 2026
ce1849c
[NNX] Delete Linen (2/4): remove the Linen decoder stack and dead *_a…
ecnal-cienet Jun 4, 2026
15d896a
[NNX] Delete Linen (3/4): drop obsolete Linen tests and flag references
ecnal-cienet Jun 4, 2026
40575ff
[NNX] Delete Linen (4/4): remove the pure_nnx/enable_nnx/pure_nnx_dec…
ecnal-cienet Jun 4, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
"""

import argparse
import functools
import gc
import os
import sys
Expand All @@ -47,11 +46,7 @@
from maxtext.configs import pyconfig
from maxtext.utils.globals import MAXTEXT_PKG_DIR
from maxtext.common import checkpointing
from maxtext.common.common_types import MODEL_MODE_TRAIN
from maxtext.layers import quantizations
from maxtext.common import train_state_nnx
from maxtext.models.models import transformer_as_linen
from maxtext.optimizers import optimizers
from maxtext.utils import max_logging
from maxtext.utils import max_utils
from maxtext.utils import maxtext_utils
Expand Down Expand Up @@ -92,23 +87,15 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name
devices_array = maxtext_utils.create_device_mesh(cfg)
mesh = Mesh(devices_array, cfg.mesh_axes)

if cfg.pure_nnx:
rngs = maxtext_utils_nnx.create_nnx_rngs(cfg, rng_key=init_rng)
model = model_creation_utils.from_config(cfg, mesh=mesh, rngs=rngs)
_, tx = train_utils.create_training_optimizer(cfg, model)
_create_model_partial, _ = model_creation_utils.create_nnx_abstract_model(cfg, mesh)
rngs = maxtext_utils_nnx.create_nnx_rngs(cfg, rng_key=init_rng)
model = model_creation_utils.from_config(cfg, mesh=mesh, rngs=rngs)
_, tx = train_utils.create_training_optimizer(cfg, model)
_create_model_partial, _ = model_creation_utils.create_nnx_abstract_model(cfg, mesh)

def init_state_fn():
nnx_model = _create_model_partial()
optimizer = nnx.Optimizer(nnx_model, tx, wrt=nnx.Param)
return train_state_nnx.TrainStateNNX(nnx_model, optimizer)

else:
quant = quantizations.configure_quantization(cfg)
model = transformer_as_linen(cfg, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN)
learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(cfg)
tx = optimizers.get_optimizer(cfg, learning_rate_schedule)
init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, cfg, True, init_rng)
def init_state_fn():
nnx_model = _create_model_partial()
optimizer = nnx.Optimizer(nnx_model, tx, wrt=nnx.Param)
return train_state_nnx.TrainStateNNX(nnx_model, optimizer)

checkpoint_manager = checkpointing.create_orbax_checkpoint_manager(
cfg.checkpoint_dir,
Expand Down Expand Up @@ -201,24 +188,18 @@ def init_state_fn():
"['decoder']['decoder_norm']['bias']": (".params.lm.final_ln.bias", None),
}

if cfg.pure_nnx:
# NNX state-tree paths after `nnx.split(TrainStateNNX)`. The state is a
# nested `nnx.State` (dict-like Mapping) with `nnx.Variable` leaves, so
# `jax.tree_util.keystr` produces dict-style entries (`['key']`) plus
# `.value` for the Variable leaf, plus `[idx]` for the optax tuple:
# model params -> ['model']<rest>.value
# adam mu / nu -> ['optimizer']['opt_state'][0]['mu' | 'nu']<rest>.value
# step -> ['optimizer']['step'].value
# opt count -> ['optimizer']['opt_state'][0]['count'].value
state_map = {
"['optimizer']['step'].value": ("step", None),
"['optimizer']['opt_state'][0]['count'].value": ("opt_states_0.no_prefix_0.count", None),
}
else:
state_map = {
".step": ("step", None),
".opt_state.count": ("opt_states_0.no_prefix_0.count", None),
}
# NNX state-tree paths after `nnx.split(TrainStateNNX)`. The state is a
# nested `nnx.State` (dict-like Mapping) with `nnx.Variable` leaves, so
# `jax.tree_util.keystr` produces dict-style entries (`['key']`) plus
# `.value` for the Variable leaf, plus `[idx]` for the optax tuple:
# model params -> ['model']<rest>.value
# adam mu / nu -> ['optimizer']['opt_state'][0]['mu' | 'nu']<rest>.value
# step -> ['optimizer']['step'].value
# opt count -> ['optimizer']['opt_state'][0]['count'].value
state_map = {
"['optimizer']['step'].value": ("step", None),
"['optimizer']['opt_state'][0]['count'].value": ("opt_states_0.no_prefix_0.count", None),
}

def get_layer_prefix(keystr_pax):
# different path format between decoder_layer variable
Expand All @@ -231,26 +212,15 @@ def get_layer_prefix(keystr_pax):

for keystr_maxtext, (keystr_pax, transform_fn) in keystr_map.items():
prefix_pax_opt_state = get_layer_prefix(keystr_pax)
if cfg.pure_nnx:
state_map[f"['model']{keystr_maxtext}.value"] = (f"mdl_vars{keystr_pax}", transform_fn)
state_map[f"['optimizer']['opt_state'][0]['mu']{keystr_maxtext}.value"] = (
f"opt_states_0.{prefix_pax_opt_state}.m{keystr_pax}",
transform_fn,
)
state_map[f"['optimizer']['opt_state'][0]['nu']{keystr_maxtext}.value"] = (
f"opt_states_0.{prefix_pax_opt_state}.v{keystr_pax}",
transform_fn,
)
else:
state_map[f".params['params']{keystr_maxtext}"] = (f"mdl_vars{keystr_pax}", transform_fn)
state_map[f".opt_state.mu['params']{keystr_maxtext}"] = (
f"opt_states_0.{prefix_pax_opt_state}.m{keystr_pax}",
transform_fn,
)
state_map[f".opt_state.nu['params']{keystr_maxtext}"] = (
f"opt_states_0.{prefix_pax_opt_state}.v{keystr_pax}",
transform_fn,
)
state_map[f"['model']{keystr_maxtext}.value"] = (f"mdl_vars{keystr_pax}", transform_fn)
state_map[f"['optimizer']['opt_state'][0]['mu']{keystr_maxtext}.value"] = (
f"opt_states_0.{prefix_pax_opt_state}.m{keystr_pax}",
transform_fn,
)
state_map[f"['optimizer']['opt_state'][0]['nu']{keystr_maxtext}.value"] = (
f"opt_states_0.{prefix_pax_opt_state}.v{keystr_pax}",
transform_fn,
)

def verify_fn(key_path, _):
keystr = jax.tree_util.keystr(key_path)
Expand Down Expand Up @@ -302,7 +272,7 @@ def map_fn(key_path, value):
max_logging.log("converted state finished")
max_utils.print_mem_stats("converted state finished")

step_value = int(converted_state.optimizer.step.value) if cfg.pure_nnx else converted_state.step
step_value = int(converted_state.optimizer.step.value)
if checkpointing.save_checkpoint(checkpoint_manager, step_value, converted_state):
max_logging.log(f"saved a checkpoint at step {step_value}")
# Upon preemption, exit when and only when all ongoing saves are complete.
Expand Down
31 changes: 22 additions & 9 deletions src/maxtext/checkpoint_conversion/to_maxtext.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,23 +319,36 @@ def get_maxtext_model_info(config):
# Get abstract model structure (name, shape) without materializing the weights to save memory
abstract_params_tree = maxtext_utils.get_abstract_param(maxtext_model_flax, config)["params"]

abstract_params_flat, _ = jax.tree_util.tree_flatten_with_path(abstract_params_tree)
# Standardize abstract tree for later unflattening
abstract_params_tree = jax.tree.map(
lambda _: 0,
abstract_params_tree,
is_leaf=lambda x: isinstance(x, nn.LogicallyPartitioned),
abstract_params_flat, abstract_params_treedef = jax.tree_util.tree_flatten_with_path(
abstract_params_tree, is_leaf=lambda x: isinstance(x, nn.LogicallyPartitioned)
)
abstract_params_treedef = jax.tree_util.tree_structure(abstract_params_tree)

max_logging.log("MaxText abstract model and state initialized.")

# preprocess state
maxtext_abstract_dict = {}
for mt_target_idx, (path_tuple, abstract_leaf_value) in enumerate(abstract_params_flat):
key_parts = [k.key for k in path_tuple if hasattr(k, "key")]
key_parts = []
for k in path_tuple:
# JAX path components can be DictKey(key), GetItemKey(key), or SequenceKey(idx).
# We prefer string keys. If we see an integer or digit-string index, we assume it's
# a layer/block index and join it with the previous part using '_', matching
# MaxText's Linen-style naming convention (e.g., layers_0).
val = getattr(k, "key", getattr(k, "idx", None))
if val is None:
val = str(k)

val_str = str(val)
if (isinstance(val, int) or val_str.isdigit()) and key_parts:
key_parts[-1] = f"{key_parts[-1]}_{val_str}"
else:
key_parts.append(val_str)

mt_param_key = "params-" + "-".join(key_parts)
mt_target_shape = abstract_leaf_value.shape
if isinstance(abstract_leaf_value, nn.LogicallyPartitioned):
mt_target_shape = abstract_leaf_value.value.shape
else:
mt_target_shape = abstract_leaf_value.shape
maxtext_abstract_dict[mt_param_key] = (mt_target_idx, mt_target_shape)

return maxtext_abstract_dict, abstract_params_treedef
Expand Down
26 changes: 23 additions & 3 deletions src/maxtext/checkpoint_conversion/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -903,9 +903,19 @@ def extract_nnx_weights(weights_dict: dict) -> dict[str, np.ndarray]:
result = {}
leaves_with_paths = jax.tree_util.tree_leaves_with_path(weights_dict)
for path_tuple, leaf_value in leaves_with_paths:
path_keys = [k.key for k in path_tuple]
path_keys = []
for k in path_tuple:
val = getattr(k, "key", getattr(k, "idx", None))
if val is None:
val = str(k)
val_str = str(val)
if (isinstance(val, int) or val_str.isdigit()) and path_keys:
path_keys[-1] = f"{path_keys[-1]}_{val_str}"
else:
path_keys.append(val_str)

# Skip NNX RNG state variables (not model weights)
if "to_nnx__rngs" in path_keys or any(k.endswith("_rngs") for k in path_keys):
if "to_nnx__rngs" in path_keys or any(k == "rngs" or k.endswith("_rngs") for k in path_keys):
continue
# Skip if this is the "value" key itself - we want the parent path
if path_keys[-1] == "value":
Expand All @@ -932,7 +942,17 @@ def extract_linen_weights(weights_dict: dict) -> dict[str, np.ndarray]:
result = {}
leaves_with_paths = jax.tree_util.tree_leaves_with_path(weights_dict)
for path_tuple, leaf_value in leaves_with_paths:
path_keys = [k.key for k in path_tuple]
path_keys = []
for k in path_tuple:
val = getattr(k, "key", getattr(k, "idx", None))
if val is None:
val = str(k)
val_str = str(val)
if (isinstance(val, int) or val_str.isdigit()) and path_keys:
path_keys[-1] = f"{path_keys[-1]}_{val_str}"
else:
path_keys.append(val_str)

# Construct maxtext_param_key from path_tuple
maxtext_param_key = "params-" + "-".join(path_keys)
if not isinstance(leaf_value, (jax.Array, np.ndarray)):
Expand Down
47 changes: 23 additions & 24 deletions src/maxtext/common/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.


"""Create an Orbax CheckpointManager with specified (Async or not) Checkpointer."""

import time
Expand Down Expand Up @@ -355,11 +356,17 @@ def combine_sharding(sds, shardings):
use_ocdbt=use_ocdbt,
use_zarr3=use_zarr3,
)
# NNX checkpoints are saved as a pure dict (see maybe_save_checkpoint), so the
# restore target must also be a pure dict. A boxed nnx.State would not match
# the on-disk tree.
restore_target = abstract_unboxed_pre_state
if isinstance(abstract_unboxed_pre_state, nnx.State):
restore_target = abstract_unboxed_pre_state.to_pure_dict()
# Provide sharding info to ensure restoration returns JAX arrays (not NumPy arrays).
restore_args = jax.tree_util.tree_map(
lambda x: ocp.type_handlers.ArrayRestoreArgs(sharding=x.sharding), abstract_unboxed_pre_state
lambda x: ocp.type_handlers.ArrayRestoreArgs(sharding=x.sharding), restore_target
)
return ocp.Checkpointer(handler).restore(p, abstract_unboxed_pre_state, restore_args=restore_args)
return ocp.Checkpointer(handler).restore(p, restore_target, restore_args=restore_args)


def create_orbax_checkpoint_manager(
Expand Down Expand Up @@ -838,9 +845,7 @@ def map_to_pspec(data):
(EmergencyCheckpointManager, EmergencyReplicatorCheckpointManager),
):
checkpoint_path = str(checkpoint_manager.directory / str(step) / "items")
with handle_checkpoint_mismatch(
"restore NNX checkpoint", checkpoint_path
):
with handle_checkpoint_mismatch("restore NNX checkpoint", checkpoint_path):
restored_nnx = _load_linen_checkpoint_into_nnx(
checkpoint_path,
abstract_unboxed_pre_state,
Expand Down Expand Up @@ -876,9 +881,7 @@ def map_to_pspec(data):
EmergencyReplicatorCheckpointManager,
),
):
restored = checkpoint_manager.restore(
step, args=Composite(state=checkpoint_args)
).state
restored = checkpoint_manager.restore(step, args=Composite(state=checkpoint_args)).state
_assert_no_shaped_dtype_struct(restored)
return (
restored,
Expand Down Expand Up @@ -906,9 +909,7 @@ def map_to_pspec(data):
# Case 3: Default/Fallback case.
# This case acts as a wildcard ('_') and matches if none of the preceding cases were met.
case _:
restored = checkpoint_manager.restore(
step, args=Composite(items=checkpoint_args)
)
restored = checkpoint_manager.restore(step, args=Composite(items=checkpoint_args))
_assert_no_shaped_dtype_struct(restored)
return (restored, None)

Expand All @@ -918,9 +919,7 @@ def map_to_pspec(data):
else:
params = abstract_unboxed_pre_state.params

with handle_checkpoint_mismatch(
"load parameters", load_parameters_from_path
):
with handle_checkpoint_mismatch("load parameters", load_parameters_from_path):
restored_params = load_params_from_path(
load_parameters_from_path,
params,
Expand All @@ -932,9 +931,7 @@ def map_to_pspec(data):
return None, restored_params
elif load_full_state_from_path != "":
max_logging.log(f"Loading full state from path: {load_full_state_from_path}")
with handle_checkpoint_mismatch(
"load full state", load_full_state_from_path
):
with handle_checkpoint_mismatch("load full state", load_full_state_from_path):
restored_state = _load_full_state_from_path(
path=load_full_state_from_path,
abstract_unboxed_pre_state=abstract_unboxed_pre_state,
Expand Down Expand Up @@ -1033,18 +1030,20 @@ def maybe_save_checkpoint(checkpoint_manager, state, config, data_iterator, step
if step is not None:
actual_step = int(step)
else:
if config.pure_nnx:
actual_step = int(state.optimizer.step) - 1
else:
# Linen TrainState has .step attribute
actual_step = int(state.step) - 1
# Under DiLoCo the step lives on the DiLoCoTrainState; otherwise on the optimizer.
actual_step = int(state.step if config.enable_diloco else state.optimizer.step) - 1

if checkpoint_manager.latest_step() == actual_step:
max_logging.log(f"Checkpoint for step {actual_step} already exists, skipping save.")
return

if config.pure_nnx:
# Save in the Linen on-disk layout so pure_nnx and Linen checkpoints are interchangeable.
# Save in the Linen on-disk layout so pure_nnx and Linen checkpoints are interchangeable.
if config.enable_diloco:
# DiLoCoTrainState: persist the synchronized global model (outer params).
# The per-replica inner optimizer / outer-momentum state is not checkpointed.
step_value = state.step.get_value() if hasattr(state.step, "get_value") else state.step
state = train_state_nnx.to_linen_checkpoint_dict({"model": state.params, "optimizer": {"step": step_value}})
else:
state = train_state_nnx.to_linen_checkpoint_dict(state.to_pure_dict())

# Determine if a checkpoint save should be forced, overriding the usual `config.checkpoint_period` logic.
Expand Down
5 changes: 0 additions & 5 deletions src/maxtext/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -1181,11 +1181,6 @@ position_id_per_seconds: 25
# Example: "8,8" to use a 8x8 subgrid (64 chips) of a full pod (16x16) of trillium.
subslice_shape: ""

# NNX
enable_nnx: false
pure_nnx_decoder: false
pure_nnx: false

################################## Qwen3-Next Specific Configs ##################################
# Kernel size for the 1D convolution in the Gated Delta Net
gdn_conv_kernel_dim: 4
Expand Down
2 changes: 0 additions & 2 deletions src/maxtext/configs/inference/vllm.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@ base_config: "base.yml"
attention: "vllm_rpa"
model_call_mode: "inference"

# NNX required for vLLM integration
enable_nnx: true
# Avoid re-initializing JAX distributed system when using vLLM
skip_jax_distributed_system: true
# Scanned layers are not supported with vLLM integration
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ distill_alpha: 0.5
distill_temperature: 1.0
distill_beta: 0
distill_layer_indices: []
enable_nnx: True
load_balance_loss_weight: 0.001

# Megablox grouped-matmul m-tile (batch_seq). The k/n dims already default to
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ distill_alpha: 0.6
distill_temperature: 1.0
distill_beta: 1.0
distill_layer_indices: [0,1,2,3,4,5,6,7]
enable_nnx: True
load_balance_loss_weight: 0.001

ici_fsdp_parallelism: -1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ distill_alpha: 0.6
distill_temperature: 1.0
distill_beta: 1.0
distill_layer_indices: [0,1,2,3,4,5,6,7]
enable_nnx: True
load_balance_loss_weight: 0.001

ici_fsdp_parallelism: -1
Expand Down
7 changes: 2 additions & 5 deletions src/maxtext/configs/pyconfig_deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,7 @@ def validate_expert_shard_attention_option(expert_shard_attention_option: str) -
)


def validate_vocab_tiling(num_vocab_tiling: int, per_device_batch_size: int, max_target_length: int, enable_nnx: bool):
del enable_nnx # NNX vocab tiling supported via vocab_tiling_nnx_loss in vocabulary_tiling.py
def validate_vocab_tiling(num_vocab_tiling: int, per_device_batch_size: int, max_target_length: int):
if (per_device_batch_size * max_target_length) % num_vocab_tiling != 0:
raise ValueError("Per device batch size times sequence length should be divisible by the number of vocab tiles.")

Expand Down Expand Up @@ -238,9 +237,7 @@ def validate_keys(keys):
validate_model_call_mode(keys["model_call_mode"])
validate_prefill_and_target_lengths(keys["max_prefill_predict_length"], keys["max_target_length"])
validate_rope_type(keys["rope_type"])
validate_vocab_tiling(
keys["num_vocab_tiling"], keys["per_device_batch_size"], keys["max_target_length"], keys["enable_nnx"]
)
validate_vocab_tiling(keys["num_vocab_tiling"], keys["per_device_batch_size"], keys["max_target_length"])
if keys["enable_rampup_batch_size"]:
validate_rampup_batch_size(
keys["per_device_batch_size_start"],
Expand Down
Loading
Loading